Skip to content

Commit

Permalink
refactor(src/openscm_calibration/type_hinting.py): remove :mod:`opens…
Browse files Browse the repository at this point in the history
…cm_calibration.type_hinting`

This was a hack, it is much better using https://numpy.org/devdocs/reference/arrays.scalars.html#numpy.number
  • Loading branch information
znicholls committed May 6, 2023
1 parent 5b841d8 commit 870b6b0
Show file tree
Hide file tree
Showing 9 changed files with 35 additions and 87 deletions.
1 change: 0 additions & 1 deletion docs/source/api/openscm_calibration.rst
Original file line number Diff line number Diff line change
Expand Up @@ -37,5 +37,4 @@
openscm_calibration.scipy_plotting
openscm_calibration.scmdata_utils
openscm_calibration.store
openscm_calibration.type_hints

31 changes: 0 additions & 31 deletions docs/source/api/openscm_calibration.type_hints.rst

This file was deleted.

10 changes: 5 additions & 5 deletions src/openscm_calibration/emcee_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import emcee
import matplotlib

from openscm_calibration.type_hints import NPArrayFloatOrInt

try:
import corner

Expand All @@ -42,7 +40,9 @@ def plot_chains( # noqa: PLR0913
parameter_order: list[str],
neg_log_likelihood_name: str,
axes_d: dict[str, matplotlib.axes.Axes],
get_neg_log_likelihood_ylim: Callable[[NPArrayFloatOrInt], tuple[float, float]]
get_neg_log_likelihood_ylim: Callable[
[np.typing.NDArray[np.floating[Any] | np.integer[Any]]], tuple[float, float]
]
| None = None,
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -106,7 +106,7 @@ def plot_chains( # noqa: PLR0913

def plot_parameter_chains( # noqa: PLR0913
ax: matplotlib.Axes.axes,
chain_values: NPArrayFloatOrInt,
chain_values: np.typing.NDArray[np.number[Any]],
burnin: int,
alpha_chain: float = 0.3,
linewidth: float = 0.5,
Expand Down Expand Up @@ -173,7 +173,7 @@ def plot_parameter_chains( # noqa: PLR0913


def get_neg_log_likelihood_ylim_default(
neg_ll_values: NPArrayFloatOrInt,
neg_ll_values: np.typing.NDArray[np.floating[Any] | np.integer[Any]],
median_scaling: float = 1.5,
max_scaling: float = 2.0,
) -> tuple[float, float]:
Expand Down
13 changes: 9 additions & 4 deletions src/openscm_calibration/emcee_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,11 +10,11 @@
if TYPE_CHECKING:
# See here for explanation of this pattern and why we don't need quotes
# below https://docs.python.org/3/library/typing.html#constant
from typing import Any

import emcee.backends
import numpy.typing as nptype

from openscm_calibration.type_hints import NPAnyFloat, NPArrayFloatOrInt


def get_acceptance_fractions(
chains: nptype.NDArray[np.float_],
Expand Down Expand Up @@ -53,7 +53,7 @@ def get_autocorrelation_info(
thin: int = 1,
autocorr_tol: int = 0,
convergence_ratio: float = 50,
) -> dict[str, float | int | bool | nptype.NDArray[NPAnyFloat]]:
) -> dict[str, float | int | bool | nptype.NDArray[np.floating[Any]]]:
"""
Get info about autocorrelation in chains
Expand Down Expand Up @@ -114,7 +114,7 @@ def get_labelled_chain_data(
neg_log_likelihood_name: str | None = None,
burnin: int = 0,
thin: int = 0,
) -> dict[str, NPArrayFloatOrInt]:
) -> dict[str, np.typing.NDArray[np.floating[Any] | np.integer[Any]]]:
"""
Get labelled chain data
Expand All @@ -135,6 +135,11 @@ def get_labelled_chain_data(
thin
Thinning to use when sampling the chains
Returns
-------
Chain data, labelled with parameter names and, if requested,
``neg_log_likelihood_name``
"""
all_samples = inp.get_chain(discard=burnin, thin=thin)

Expand Down
6 changes: 3 additions & 3 deletions src/openscm_calibration/minimize.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@
from openscm_calibration.store import OptResStore

if TYPE_CHECKING:
from openscm_calibration.type_hints import NPArrayFloatOrInt
from typing import Any


class SupportsCostCalculation(Protocol):
Expand Down Expand Up @@ -41,7 +41,7 @@ class SupportsModelRun(Protocol):

def run_model(
self,
x: NPArrayFloatOrInt,
x: np.typing.NDArray[np.number[Any]],
) -> scmdata.run.BaseScmRun:
"""
Calculate cost function
Expand All @@ -58,7 +58,7 @@ def run_model(


def to_minimize_full(
x: NPArrayFloatOrInt,
x: np.typing.NDArray[np.number[Any]],
cost_calculator: SupportsCostCalculation,
model_runner: SupportsModelRun,
store: OptResStore | None = None,
Expand Down
10 changes: 4 additions & 6 deletions src/openscm_calibration/model_runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,6 @@
import pint
import scmdata.run

from openscm_calibration.type_hints import NPAnyFloat, NPArrayFloatOrInt


class XToNamedPintConvertor(Protocol):
"""
Expand All @@ -30,7 +28,7 @@ class XToNamedPintConvertor(Protocol):

def __call__(
self,
x: NPArrayFloatOrInt,
x: np.typing.NDArray[np.number[Any]],
) -> dict[str, pint.Quantity[np.float_] | np.float_]:
"""
Convert x to pint quantities
Expand Down Expand Up @@ -136,7 +134,7 @@ def from_parameters(

def run_model(
self,
x: NPArrayFloatOrInt,
x: np.typing.NDArray[np.number[Any]],
) -> scmdata.run.BaseScmRun:
"""
Run the model
Expand All @@ -160,10 +158,10 @@ def run_model(


def x_and_parameters_to_named_with_units(
x: NPArrayFloatOrInt,
x: np.typing.NDArray[np.number[Any]],
params: Iterable[tuple[str, str | pint.Unit | None]],
get_unit_registry: Callable[[], pint.UnitRegistry] | None = None,
) -> dict[str, pint.Quantity[NPAnyFloat] | NPAnyFloat]:
) -> dict[str, pint.Quantity[np.floating[Any]] | np.floating[Any]]:
"""
Convert x array and parameters to a dictionary and add units
Expand Down
15 changes: 7 additions & 8 deletions src/openscm_calibration/scipy_plotting.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,6 @@
import tqdm

from openscm_calibration.store import OptResStore
from openscm_calibration.type_hints import NPArrayFloatOrInt


logger: logging.Logger = logging.getLogger(__name__)
Expand All @@ -43,7 +42,7 @@ class SupportsScipyOptCallback(Protocol):

def callback_minimize(
self,
xk: NPArrayFloatOrInt,
xk: np.typing.NDArray[np.number[Any]],
) -> None:
"""
Get cost of parameter vector
Expand All @@ -58,7 +57,7 @@ def callback_minimize(

def callback_differential_evolution(
self,
xk: NPArrayFloatOrInt,
xk: np.typing.NDArray[np.number[Any]],
convergence: float | None = None,
) -> None:
"""
Expand Down Expand Up @@ -290,7 +289,7 @@ class OptPlotter:

def callback_minimize(
self,
xk: NPArrayFloatOrInt,
xk: np.typing.NDArray[np.number[Any]],
) -> None:
"""
Update the plots
Expand All @@ -307,7 +306,7 @@ def callback_minimize(

def callback_differential_evolution(
self,
xk: NPArrayFloatOrInt,
xk: np.typing.NDArray[np.number[Any]],
convergence: float | None = None,
) -> None:
"""
Expand Down Expand Up @@ -607,7 +606,7 @@ def get_ymax_default(

def plot_parameters(
axes: dict[str, matplotlib.axes.Axes],
para_vals: dict[str, NPArrayFloatOrInt],
para_vals: dict[str, np.typing.NDArray[np.number[Any]]],
alpha: float = 0.7,
**kwargs: Any,
) -> None:
Expand Down Expand Up @@ -883,7 +882,7 @@ class CallbackProxy:

def callback_minimize(
self,
xk: NPArrayFloatOrInt,
xk: np.typing.NDArray[np.number[Any]],
) -> None:
"""
Update the plots
Expand All @@ -901,7 +900,7 @@ def callback_minimize(

def callback_differential_evolution(
self,
xk: NPArrayFloatOrInt,
xk: np.typing.NDArray[np.number[Any]],
convergence: float | None = None,
) -> None:
"""
Expand Down
16 changes: 7 additions & 9 deletions src/openscm_calibration/store.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,6 @@
import attr
import scmdata.run

from openscm_calibration.type_hints import NPArrayFloatOrInt


class SupportsListLikeHandling(Protocol):
"""
Expand Down Expand Up @@ -146,7 +144,7 @@ class OptResStore:
)
"""Costs of runs"""

x_samples: MutableSequence[None | NPArrayFloatOrInt] = field(
x_samples: MutableSequence[None | np.typing.NDArray[np.number[Any]]] = field(
validator=[_all_none_to_start, _same_length_as_res]
)
"""x vectors sampled"""
Expand Down Expand Up @@ -240,7 +238,7 @@ def set_result_cost_x(
self,
res: None | scmdata.run.BaseScmRun,
cost: float,
x: NPArrayFloatOrInt,
x: np.typing.NDArray[np.number[Any]],
idx: int,
) -> None:
"""
Expand Down Expand Up @@ -278,7 +276,7 @@ def append_result_cost_x(
self,
res: scmdata.run.BaseScmRun,
cost: float,
x: NPArrayFloatOrInt,
x: np.typing.NDArray[np.number[Any]],
) -> None:
"""
Append result, cost and x from a successful run to the results
Expand Down Expand Up @@ -308,7 +306,7 @@ def append_result_cost_x(
def note_failed_run(
self,
cost: float,
x: NPArrayFloatOrInt,
x: np.typing.NDArray[np.number[Any]],
) -> None:
"""
Note that a run failed
Expand Down Expand Up @@ -338,7 +336,7 @@ def get_costs_xsamples_res(
self,
) -> tuple[
tuple[float, ...],
tuple[NPArrayFloatOrInt, ...],
tuple[np.typing.NDArray[np.number[Any]], ...],
tuple[scmdata.run.BaseScmRun, ...],
]:
"""
Expand Down Expand Up @@ -370,7 +368,7 @@ def get_costs_xsamples_res(

# Help out type hinting
costs: tuple[float, ...] = tmp[0]
xs_out: tuple[NPArrayFloatOrInt, ...] = tmp[1]
xs_out: tuple[np.typing.NDArray[np.number[Any]], ...] = tmp[1]
ress: tuple[scmdata.run.BaseScmRun, ...] = tmp[2]

out = (costs, xs_out, ress)
Expand All @@ -381,7 +379,7 @@ def get_costs_labelled_xsamples_res(
self,
) -> tuple[
tuple[float, ...],
dict[str, NPArrayFloatOrInt],
dict[str, np.typing.NDArray[np.number[Any]]],
tuple[scmdata.run.BaseScmRun, ...],
]:
"""
Expand Down
20 changes: 0 additions & 20 deletions src/openscm_calibration/type_hints.py

This file was deleted.

0 comments on commit 870b6b0

Please sign in to comment.