From 959e9320893e1de61e892280b59917c07a0d3115 Mon Sep 17 00:00:00 2001 From: nstarman Date: Thu, 11 Jan 2024 20:31:10 -0500 Subject: [PATCH] type hint parameters Signed-off-by: nstarman --- astropy/cosmology/parameter/_converter.py | 44 +++++++++++---- astropy/cosmology/parameter/_core.py | 69 +++++++++++++++++------ 2 files changed, 84 insertions(+), 29 deletions(-) diff --git a/astropy/cosmology/parameter/_converter.py b/astropy/cosmology/parameter/_converter.py index bc594ef4c1a..f91b0bc881e 100644 --- a/astropy/cosmology/parameter/_converter.py +++ b/astropy/cosmology/parameter/_converter.py @@ -3,23 +3,43 @@ from __future__ import annotations from collections.abc import Callable -from typing import Any +from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, overload import astropy.units as u -__all__ = [] +if TYPE_CHECKING: + from astropy.cosmology import Cosmology, Parameter -FValidateCallable = Callable[[object, object, Any], Any] -_REGISTRY_FVALIDATORS: dict[str, FValidateCallable] = {} +__all__: list[str] = [] +# Callable[[Cosmology, Parameter, Any], Any] +_FValidateCallable: TypeAlias = Callable[["Cosmology", "Parameter", Any], Any] +T = TypeVar("T") -def _register_validator(key, fvalidate=None): +_REGISTRY_FVALIDATORS: dict[str, _FValidateCallable] = {} + + +@overload +def _register_validator( + key: str, fvalidate: _FValidateCallable +) -> _FValidateCallable: ... + + +@overload +def _register_validator( + key: str, fvalidate: None = None +) -> Callable[[_FValidateCallable], _FValidateCallable]: ... + + +def _register_validator( + key: str, fvalidate: _FValidateCallable | None = None +) -> _FValidateCallable | Callable[[_FValidateCallable], _FValidateCallable]: """Decorator to register a new kind of validator function. Parameters ---------- key : str - fvalidate : callable[[object, object, Any], Any] or None, optional + fvalidate : callable[[Cosmology, Parameter, Any], Any] or None, optional Value validation function. Returns @@ -38,12 +58,12 @@ def _register_validator(key, fvalidate=None): return fvalidate # for use as a decorator - def register(fvalidate): + def register(fvalidate: _FValidateCallable) -> _FValidateCallable: """Register validator function. Parameters ---------- - fvalidate : callable[[object, object, Any], Any] + fvalidate : callable[[Cosmology, Parameter, Any], Any] Validation function. Returns @@ -60,7 +80,7 @@ def register(fvalidate): @_register_validator("default") -def _validate_with_unit(cosmology, param, value): +def _validate_with_unit(cosmology: Cosmology, param: Parameter, value: Any) -> Any: """Default Parameter value validator. Adds/converts units if Parameter has a unit. @@ -72,14 +92,14 @@ def _validate_with_unit(cosmology, param, value): @_register_validator("float") -def _validate_to_float(cosmology, param, value): +def _validate_to_float(cosmology: Cosmology, param: Parameter, value: Any) -> Any: """Parameter value validator with units, and converted to float.""" value = _validate_with_unit(cosmology, param, value) return float(value) @_register_validator("scalar") -def _validate_to_scalar(cosmology, param, value): +def _validate_to_scalar(cosmology: Cosmology, param: Parameter, value: Any) -> Any: """""" value = _validate_with_unit(cosmology, param, value) if not value.isscalar: @@ -88,7 +108,7 @@ def _validate_to_scalar(cosmology, param, value): @_register_validator("non-negative") -def _validate_non_negative(cosmology, param, value): +def _validate_non_negative(cosmology: Cosmology, param: Parameter, value: Any) -> Any: """Parameter value validator where value is a positive float.""" value = _validate_to_float(cosmology, param, value) if value < 0.0: diff --git a/astropy/cosmology/parameter/_core.py b/astropy/cosmology/parameter/_core.py index 852dd775226..8d227fc9b71 100644 --- a/astropy/cosmology/parameter/_core.py +++ b/astropy/cosmology/parameter/_core.py @@ -5,17 +5,26 @@ __all__ = ["Parameter"] import copy +from collections.abc import Callable from dataclasses import KW_ONLY, dataclass, field, fields, is_dataclass, replace from enum import Enum, auto -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, overload import astropy.units as u -from ._converter import _REGISTRY_FVALIDATORS, FValidateCallable, _register_validator +from ._converter import _REGISTRY_FVALIDATORS, _register_validator if TYPE_CHECKING: from collections.abc import Sequence + from typing_extensions import Self + + from astropy.cosmology import Cosmology + + +_VT = TypeVar("_VT") # the type of the VParameter value +FValidateCallableT: TypeAlias = Callable[["Cosmology", "Parameter", Any], _VT] + class Sentinel(Enum): """Sentinel values for Parameter fields.""" @@ -23,7 +32,7 @@ class Sentinel(Enum): MISSING = auto() """A sentinel value signifying a missing default.""" - def __repr__(self): + def __repr__(self) -> str: return f"<{self.name}>" @@ -46,12 +55,12 @@ def __set__(self, obj: Parameter, value: Any) -> None: @dataclass(frozen=True) -class _FValidateField: - default: FValidateCallable | str = "default" +class _FValidateField(Generic[_VT]): + default: FValidateCallableT[_VT] | str = "default" def __get__( self, obj: Parameter | None, objcls: type[Parameter] | None - ) -> FValidateCallable | str: + ) -> FValidateCallableT[_VT] | str: if obj is None: # calling `Parameter.fvalidate` from the class return self.default return obj._fvalidate # calling `Parameter.fvalidate` from an instance @@ -73,7 +82,7 @@ def __set__(self, obj: Parameter, value: Any) -> None: @dataclass(frozen=True) -class Parameter: +class Parameter(Generic[_VT]): r"""Cosmological parameter (descriptor). Should only be used with a :class:`~astropy.cosmology.Cosmology` subclass. @@ -130,7 +139,7 @@ class Parameter: """Unit equivalencies available when setting the parameter.""" # Setting - fvalidate: _FValidateField = _FValidateField(default="default") + fvalidate: _FValidateField[_VT] = _FValidateField(default="default") """Function to validate/convert values when setting the Parameter.""" # Info @@ -144,8 +153,8 @@ class Parameter: """ def __post_init__(self) -> None: - self._fvalidate_in: FValidateCallable | str - self._fvalidate: FValidateCallable + self._fvalidate_in: FValidateCallableT[_VT] | str + self._fvalidate: FValidateCallableT[_VT] object.__setattr__(self, "__doc__", self.doc) # Now setting a dummy attribute name. The cosmology class will call # `__set_name__`, passing the real attribute name. However, if Parameter is not @@ -161,7 +170,15 @@ def __set_name__(self, cosmo_cls: type, name: str | None) -> None: # ------------------------------------------- # descriptor and property-like methods - def __get__(self, cosmology, cosmo_cls=None): + @overload + def __get__(self, cosmology: None, cosmo_cls: Any) -> Parameter: ... + + @overload + def __get__(self, cosmology: Cosmology, cosmo_cls: Any) -> _VT: ... + + def __get__( + self, cosmology: Cosmology | None, cosmo_cls: Any = None + ) -> Parameter | _VT: # Get from class if cosmology is None: # If the Parameter is being set as part of a dataclass constructor, then we @@ -174,10 +191,11 @@ def __get__(self, cosmology, cosmo_cls=None): ): raise AttributeError return self + # Get from instance return getattr(cosmology, self._attr_name) - def __set__(self, cosmology, value): + def __set__(self, cosmology: Cosmology, value: Any) -> None: """Allows attribute setting once. Raises AttributeError subsequently. @@ -207,7 +225,7 @@ def __set__(self, cosmology, value): # ------------------------------------------- # validate value - def validator(self, fvalidate): + def validator(self, fvalidate: FValidateCallableT[_VT]) -> Self: """Make new Parameter with custom ``fvalidate``. Note: ``Parameter.fvalidator`` must be the top-most descriptor decorator. @@ -223,7 +241,7 @@ def validator(self, fvalidate): """ return self.clone(fvalidate=fvalidate) - def validate(self, cosmology, value): + def validate(self, cosmology: Cosmology, value: Any) -> _VT: """Run the validator on this Parameter. Parameters @@ -240,8 +258,25 @@ def validate(self, cosmology, value): """ return self._fvalidate(cosmology, self, value) + @overload + @staticmethod + def register_validator( + key: str, fvalidate: FValidateCallableT[_VT] + ) -> FValidateCallableT[_VT]: ... + + @overload + @staticmethod + def register_validator( + key: str, fvalidate: None = None + ) -> Callable[[FValidateCallableT[_VT]], FValidateCallableT[_VT]]: ... + @staticmethod - def register_validator(key, fvalidate=None): + def register_validator( + key: str, fvalidate: FValidateCallableT[_VT] | None = None + ) -> ( + FValidateCallableT[_VT] + | Callable[[FValidateCallableT[_VT]], FValidateCallableT[_VT]] + ): """Decorator to register a new kind of validator function. Parameters @@ -261,12 +296,12 @@ def register_validator(key, fvalidate=None): # ------------------------------------------- - def clone(self, **kw): + def clone(self, **kw: Any) -> Self: """Clone this `Parameter`, changing any constructor argument. Parameters ---------- - **kw + **kw : Any Passed to constructor. The current values, eg. ``fvalidate`` are used as the default values, so an empty ``**kw`` is an exact copy.