Skip to content

Commit

Permalink
feat: improve type annotations
Browse files Browse the repository at this point in the history
Co-authored-by: Eero Vaher <eero.vaher@fysik.lu.se>
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman and eerovaher committed Jun 29, 2024
1 parent 231786b commit ed54b93
Show file tree
Hide file tree
Showing 2 changed files with 21 additions and 18 deletions.
20 changes: 12 additions & 8 deletions astropy/cosmology/parameter/_converter.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,21 @@

from __future__ import annotations

from collections.abc import Callable
from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, overload
from typing import TYPE_CHECKING, Any, TypeAlias, overload

import astropy.units as u

if TYPE_CHECKING:
from collections.abc import Callable

from numpy.typing import NDArray

from astropy.cosmology import Cosmology, Parameter

_FValidateCallable: TypeAlias = Callable[[Cosmology, Parameter, Any], Any]

__all__: list[str] = []

# Callable[[Cosmology, Parameter, Any], Any]
_FValidateCallable: TypeAlias = Callable[["Cosmology", "Parameter", Any], Any]
T = TypeVar("T")

_REGISTRY_FVALIDATORS: dict[str, _FValidateCallable] = {}

Expand Down Expand Up @@ -92,14 +94,16 @@ def _validate_with_unit(cosmology: Cosmology, param: Parameter, value: Any) -> A


@_register_validator("float")
def _validate_to_float(cosmology: Cosmology, param: Parameter, value: Any) -> Any:
def _validate_to_float(cosmology: Cosmology, param: Parameter, value: Any) -> float:
"""Parameter value validator with units, and converted to float."""
value = _validate_with_unit(cosmology, param, value)
return float(value)


@_register_validator("scalar")
def _validate_to_scalar(cosmology: Cosmology, param: Parameter, value: Any) -> Any:
def _validate_to_scalar(
cosmology: Cosmology, param: Parameter, value: Any
) -> NDArray[Any]:
""""""
value = _validate_with_unit(cosmology, param, value)
if not value.isscalar:
Expand All @@ -108,7 +112,7 @@ def _validate_to_scalar(cosmology: Cosmology, param: Parameter, value: Any) -> A


@_register_validator("non-negative")
def _validate_non_negative(cosmology: Cosmology, param: Parameter, value: Any) -> Any:
def _validate_non_negative(cosmology: Cosmology, param: Parameter, value: Any) -> float:
"""Parameter value validator where value is a positive float."""
value = _validate_to_float(cosmology, param, value)
if value < 0.0:
Expand Down
19 changes: 9 additions & 10 deletions astropy/cosmology/parameter/_core.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,25 +5,26 @@
__all__ = ["Parameter"]

import copy
from collections.abc import Callable
from dataclasses import KW_ONLY, dataclass, field, fields, is_dataclass, replace
from enum import Enum, auto
from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, overload
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload

import astropy.units as u

from ._converter import _REGISTRY_FVALIDATORS, _register_validator

if TYPE_CHECKING:
from collections.abc import Sequence
from collections.abc import Callable, Sequence
from typing import TypeAlias

from typing_extensions import Self

from astropy.cosmology import Cosmology

_FValidateCallable: TypeAlias = Callable[[Cosmology, "Parameter", Any], "_VT"]

_VT = TypeVar("_VT") # the type of the VParameter value
_FValidateCallable: TypeAlias = Callable[["Cosmology", "Parameter", Any], _VT]

_VT = TypeVar("_VT") # the type of the VParameter value. Required at runtime.


class Sentinel(Enum):
Expand Down Expand Up @@ -171,14 +172,12 @@ def __set_name__(self, cosmo_cls: type, name: str | None) -> None:
# descriptor and property-like methods

@overload
def __get__(self, cosmology: None, cosmo_cls: Any) -> Parameter: ...
def __get__(self, cosmology: None, cosmo_cls: Any) -> Self: ...

@overload
def __get__(self, cosmology: Cosmology, cosmo_cls: Any) -> _VT: ...

def __get__(
self, cosmology: Cosmology | None, cosmo_cls: Any = None
) -> Parameter | _VT:
def __get__(self, cosmology: Cosmology | None, cosmo_cls: Any = None) -> Self | _VT:
# Get from class
if cosmology is None:
# If the Parameter is being set as part of a dataclass constructor, then we
Expand Down Expand Up @@ -301,7 +300,7 @@ def clone(self, **kw: Any) -> Self:
Parameters
----------
**kw : Any
**kw : dict, optional
Passed to constructor. The current values, eg. ``fvalidate`` are
used as the default values, so an empty ``**kw`` is an exact copy.
Expand Down

0 comments on commit ed54b93

Please sign in to comment.