diff --git a/astropy/cosmology/parameter/_converter.py b/astropy/cosmology/parameter/_converter.py index f91b0bc881e..59a170e6bd3 100644 --- a/astropy/cosmology/parameter/_converter.py +++ b/astropy/cosmology/parameter/_converter.py @@ -2,19 +2,21 @@ from __future__ import annotations -from collections.abc import Callable -from typing import TYPE_CHECKING, Any, TypeAlias, TypeVar, overload +from typing import TYPE_CHECKING, Any, TypeAlias, overload import astropy.units as u if TYPE_CHECKING: + from collections.abc import Callable + + from numpy.typing import NDArray + from astropy.cosmology import Cosmology, Parameter + _FValidateCallable: TypeAlias = Callable[[Cosmology, Parameter, Any], Any] + __all__: list[str] = [] -# Callable[[Cosmology, Parameter, Any], Any] -_FValidateCallable: TypeAlias = Callable[["Cosmology", "Parameter", Any], Any] -T = TypeVar("T") _REGISTRY_FVALIDATORS: dict[str, _FValidateCallable] = {} @@ -92,14 +94,16 @@ def _validate_with_unit(cosmology: Cosmology, param: Parameter, value: Any) -> A @_register_validator("float") -def _validate_to_float(cosmology: Cosmology, param: Parameter, value: Any) -> Any: +def _validate_to_float(cosmology: Cosmology, param: Parameter, value: Any) -> float: """Parameter value validator with units, and converted to float.""" value = _validate_with_unit(cosmology, param, value) return float(value) @_register_validator("scalar") -def _validate_to_scalar(cosmology: Cosmology, param: Parameter, value: Any) -> Any: +def _validate_to_scalar( + cosmology: Cosmology, param: Parameter, value: Any +) -> NDArray[Any]: """""" value = _validate_with_unit(cosmology, param, value) if not value.isscalar: @@ -108,7 +112,7 @@ def _validate_to_scalar(cosmology: Cosmology, param: Parameter, value: Any) -> A @_register_validator("non-negative") -def _validate_non_negative(cosmology: Cosmology, param: Parameter, value: Any) -> Any: +def _validate_non_negative(cosmology: Cosmology, param: Parameter, value: Any) -> float: """Parameter value validator where value is a positive float.""" value = _validate_to_float(cosmology, param, value) if value < 0.0: diff --git a/astropy/cosmology/parameter/_core.py b/astropy/cosmology/parameter/_core.py index 872fab69c12..2c1d800e2af 100644 --- a/astropy/cosmology/parameter/_core.py +++ b/astropy/cosmology/parameter/_core.py @@ -5,25 +5,26 @@ __all__ = ["Parameter"] import copy -from collections.abc import Callable from dataclasses import KW_ONLY, dataclass, field, fields, is_dataclass, replace from enum import Enum, auto -from typing import TYPE_CHECKING, Any, Generic, TypeAlias, TypeVar, overload +from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload import astropy.units as u from ._converter import _REGISTRY_FVALIDATORS, _register_validator if TYPE_CHECKING: - from collections.abc import Sequence + from collections.abc import Callable, Sequence + from typing import TypeAlias from typing_extensions import Self from astropy.cosmology import Cosmology + _FValidateCallable: TypeAlias = Callable[[Cosmology, "Parameter", Any], "_VT"] -_VT = TypeVar("_VT") # the type of the VParameter value -_FValidateCallable: TypeAlias = Callable[["Cosmology", "Parameter", Any], _VT] + +_VT = TypeVar("_VT") # the type of the VParameter value. Required at runtime. class Sentinel(Enum): @@ -171,14 +172,12 @@ def __set_name__(self, cosmo_cls: type, name: str | None) -> None: # descriptor and property-like methods @overload - def __get__(self, cosmology: None, cosmo_cls: Any) -> Parameter: ... + def __get__(self, cosmology: None, cosmo_cls: Any) -> Self: ... @overload def __get__(self, cosmology: Cosmology, cosmo_cls: Any) -> _VT: ... - def __get__( - self, cosmology: Cosmology | None, cosmo_cls: Any = None - ) -> Parameter | _VT: + def __get__(self, cosmology: Cosmology | None, cosmo_cls: Any = None) -> Self | _VT: # Get from class if cosmology is None: # If the Parameter is being set as part of a dataclass constructor, then we @@ -301,7 +300,7 @@ def clone(self, **kw: Any) -> Self: Parameters ---------- - **kw : Any + **kw : dict, optional Passed to constructor. The current values, eg. ``fvalidate`` are used as the default values, so an empty ``**kw`` is an exact copy.