Skip to content

Commit

Permalink
type hint parameters
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jun 19, 2024
1 parent 33bb428 commit 959e932
Show file tree
Hide file tree
Showing 2 changed files with 84 additions and 29 deletions.
44 changes: 32 additions & 12 deletions astropy/cosmology/parameter/_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,23 +3,43 @@
from __future__ import annotations

from collections.abc import Callable
from typing import Any
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, overload

import astropy.units as u

__all__ = []
if TYPE_CHECKING:
from astropy.cosmology import Cosmology, Parameter

FValidateCallable = Callable[[object, object, Any], Any]
_REGISTRY_FVALIDATORS: dict[str, FValidateCallable] = {}
__all__: list[str] = []

# Callable[[Cosmology, Parameter, Any], Any]
_FValidateCallable: TypeAlias = Callable[["Cosmology", "Parameter", Any], Any]
T = TypeVar("T")

def _register_validator(key, fvalidate=None):
_REGISTRY_FVALIDATORS: dict[str, _FValidateCallable] = {}


@overload
def _register_validator(
key: str, fvalidate: _FValidateCallable
) -> _FValidateCallable: ...


@overload
def _register_validator(
key: str, fvalidate: None = None
) -> Callable[[_FValidateCallable], _FValidateCallable]: ...


def _register_validator(
key: str, fvalidate: _FValidateCallable | None = None
) -> _FValidateCallable | Callable[[_FValidateCallable], _FValidateCallable]:
"""Decorator to register a new kind of validator function.
Parameters
----------
key : str
fvalidate : callable[[object, object, Any], Any] or None, optional
fvalidate : callable[[Cosmology, Parameter, Any], Any] or None, optional
Value validation function.
Returns
Expand All @@ -38,12 +58,12 @@ def _register_validator(key, fvalidate=None):
return fvalidate

# for use as a decorator
def register(fvalidate):
def register(fvalidate: _FValidateCallable) -> _FValidateCallable:
"""Register validator function.
Parameters
----------
fvalidate : callable[[object, object, Any], Any]
fvalidate : callable[[Cosmology, Parameter, Any], Any]
Validation function.
Returns
Expand All @@ -60,7 +80,7 @@ def register(fvalidate):


@_register_validator("default")
def _validate_with_unit(cosmology, param, value):
def _validate_with_unit(cosmology: Cosmology, param: Parameter, value: Any) -> Any:
"""Default Parameter value validator.
Adds/converts units if Parameter has a unit.
Expand All @@ -72,14 +92,14 @@ def _validate_with_unit(cosmology, param, value):


@_register_validator("float")
def _validate_to_float(cosmology, param, value):
def _validate_to_float(cosmology: Cosmology, param: Parameter, value: Any) -> Any:
"""Parameter value validator with units, and converted to float."""
value = _validate_with_unit(cosmology, param, value)
return float(value)


@_register_validator("scalar")
def _validate_to_scalar(cosmology, param, value):
def _validate_to_scalar(cosmology: Cosmology, param: Parameter, value: Any) -> Any:
""""""
value = _validate_with_unit(cosmology, param, value)
if not value.isscalar:
Expand All @@ -88,7 +108,7 @@ def _validate_to_scalar(cosmology, param, value):


@_register_validator("non-negative")
def _validate_non_negative(cosmology, param, value):
def _validate_non_negative(cosmology: Cosmology, param: Parameter, value: Any) -> Any:
"""Parameter value validator where value is a positive float."""
value = _validate_to_float(cosmology, param, value)
if value < 0.0:
Expand Down
69 changes: 52 additions & 17 deletions astropy/cosmology/parameter/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,34 @@
__all__ = ["Parameter"]

import copy
from collections.abc import Callable
from dataclasses import KW_ONLY, dataclass, field, fields, is_dataclass, replace
from enum import Enum, auto
from typing import TYPE_CHECKING, Any
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, overload

import astropy.units as u

from ._converter import _REGISTRY_FVALIDATORS, FValidateCallable, _register_validator
from ._converter import _REGISTRY_FVALIDATORS, _register_validator

if TYPE_CHECKING:
from collections.abc import Sequence

from typing_extensions import Self

from astropy.cosmology import Cosmology


_VT = TypeVar("_VT") # the type of the VParameter value
FValidateCallableT: TypeAlias = Callable[["Cosmology", "Parameter", Any], _VT]


class Sentinel(Enum):
"""Sentinel values for Parameter fields."""

MISSING = auto()
"""A sentinel value signifying a missing default."""

def __repr__(self):
def __repr__(self) -> str:
return f"<{self.name}>"


Expand All @@ -46,12 +55,12 @@ def __set__(self, obj: Parameter, value: Any) -> None:


@dataclass(frozen=True)
class _FValidateField:
default: FValidateCallable | str = "default"
class _FValidateField(Generic[_VT]):
default: FValidateCallableT[_VT] | str = "default"

def __get__(
self, obj: Parameter | None, objcls: type[Parameter] | None
) -> FValidateCallable | str:
) -> FValidateCallableT[_VT] | str:
if obj is None: # calling `Parameter.fvalidate` from the class
return self.default
return obj._fvalidate # calling `Parameter.fvalidate` from an instance
Expand All @@ -73,7 +82,7 @@ def __set__(self, obj: Parameter, value: Any) -> None:


@dataclass(frozen=True)
class Parameter:
class Parameter(Generic[_VT]):
r"""Cosmological parameter (descriptor).
Should only be used with a :class:`~astropy.cosmology.Cosmology` subclass.
Expand Down Expand Up @@ -130,7 +139,7 @@ class Parameter:
"""Unit equivalencies available when setting the parameter."""

# Setting
fvalidate: _FValidateField = _FValidateField(default="default")
fvalidate: _FValidateField[_VT] = _FValidateField(default="default")
"""Function to validate/convert values when setting the Parameter."""

# Info
Expand All @@ -144,8 +153,8 @@ class Parameter:
"""

def __post_init__(self) -> None:
self._fvalidate_in: FValidateCallable | str
self._fvalidate: FValidateCallable
self._fvalidate_in: FValidateCallableT[_VT] | str
self._fvalidate: FValidateCallableT[_VT]
object.__setattr__(self, "__doc__", self.doc)
# Now setting a dummy attribute name. The cosmology class will call
# `__set_name__`, passing the real attribute name. However, if Parameter is not
Expand All @@ -161,7 +170,15 @@ def __set_name__(self, cosmo_cls: type, name: str | None) -> None:
# -------------------------------------------
# descriptor and property-like methods

def __get__(self, cosmology, cosmo_cls=None):
@overload
def __get__(self, cosmology: None, cosmo_cls: Any) -> Parameter: ...

@overload
def __get__(self, cosmology: Cosmology, cosmo_cls: Any) -> _VT: ...

def __get__(
self, cosmology: Cosmology | None, cosmo_cls: Any = None
) -> Parameter | _VT:
# Get from class
if cosmology is None:
# If the Parameter is being set as part of a dataclass constructor, then we
Expand All @@ -174,10 +191,11 @@ def __get__(self, cosmology, cosmo_cls=None):
):
raise AttributeError
return self

# Get from instance
return getattr(cosmology, self._attr_name)

def __set__(self, cosmology, value):
def __set__(self, cosmology: Cosmology, value: Any) -> None:
"""Allows attribute setting once.
Raises AttributeError subsequently.
Expand Down Expand Up @@ -207,7 +225,7 @@ def __set__(self, cosmology, value):
# -------------------------------------------
# validate value

def validator(self, fvalidate):
def validator(self, fvalidate: FValidateCallableT[_VT]) -> Self:
"""Make new Parameter with custom ``fvalidate``.
Note: ``Parameter.fvalidator`` must be the top-most descriptor decorator.
Expand All @@ -223,7 +241,7 @@ def validator(self, fvalidate):
"""
return self.clone(fvalidate=fvalidate)

def validate(self, cosmology, value):
def validate(self, cosmology: Cosmology, value: Any) -> _VT:
"""Run the validator on this Parameter.
Parameters
Expand All @@ -240,8 +258,25 @@ def validate(self, cosmology, value):
"""
return self._fvalidate(cosmology, self, value)

@overload
@staticmethod
def register_validator(
key: str, fvalidate: FValidateCallableT[_VT]
) -> FValidateCallableT[_VT]: ...

@overload
@staticmethod
def register_validator(
key: str, fvalidate: None = None
) -> Callable[[FValidateCallableT[_VT]], FValidateCallableT[_VT]]: ...

@staticmethod
def register_validator(key, fvalidate=None):
def register_validator(
key: str, fvalidate: FValidateCallableT[_VT] | None = None
) -> (
FValidateCallableT[_VT]
| Callable[[FValidateCallableT[_VT]], FValidateCallableT[_VT]]
):
"""Decorator to register a new kind of validator function.
Parameters
Expand All @@ -261,12 +296,12 @@ def register_validator(key, fvalidate=None):

# -------------------------------------------

def clone(self, **kw):
def clone(self, **kw: Any) -> Self:
"""Clone this `Parameter`, changing any constructor argument.
Parameters
----------
**kw
**kw : Any
Passed to constructor. The current values, eg. ``fvalidate`` are
used as the default values, so an empty ``**kw`` is an exact copy.
Expand Down

0 comments on commit 959e932

Please sign in to comment.