Skip to content

Commit

Permalink
refactor: as cached_property
Browse files Browse the repository at this point in the history
Signed-off-by: nstarman <nstarman@users.noreply.github.com>
  • Loading branch information
nstarman committed Jul 11, 2024
1 parent 04bc2fd commit 963440c
Show file tree
Hide file tree
Showing 3 changed files with 42 additions and 95 deletions.
55 changes: 1 addition & 54 deletions astropy/cosmology/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from collections.abc import Callable
from dataclasses import Field
from numbers import Number
from typing import TYPE_CHECKING, Any, Generic, TypeVar, overload
from typing import TYPE_CHECKING, Any, TypeVar

import numpy as np

Expand All @@ -20,7 +20,6 @@

if TYPE_CHECKING:
from astropy.cosmology import Parameter
from astropy.cosmology.core import Cosmology


_F = TypeVar("_F", bound=Callable[..., Any])
Expand Down Expand Up @@ -143,55 +142,3 @@ def _depr_kws(func: _F, /, kws: tuple[str, ...], since: str) -> _F:
wrapper = _depr_kws_wrap(func, kws, since)
functools.update_wrapper(wrapper, func)
return wrapper


class CachedInDictPropertyDescriptor(Generic[R]):
"""Descriptor for a property that is cached in the instance's dictionary.
Note that this is a non-data descriptor, not NOT a data descriptor, so
after the property is accessed and cached with the same key as the
property, the instance's dictionary will have the property value
directly, and the descriptor will not be called again, until the
property is deleted from the instance's dictionary. See
https://docs.python.org/3/howto/descriptor.html#descriptor-protocol.
"""

# __slots__ = ("fget", "name") # TODO: when __doc__ is supported by __slots__

def __init__(self, fget: Callable[[Cosmology], R]) -> None:
self.fget = fget
self.__doc__ = fget.__doc__

def __set_name__(self, cosmo_cls: type[Cosmology], name: str) -> None:
self.name: str = name

@overload
def __get__(
self, cosmo: None, cosmo_cls: Any
) -> CachedInDictPropertyDescriptor[R]: ...

@overload
def __get__(self, cosmo: Cosmology, cosmo_cls: Any) -> R: ...

def __get__(
self, cosmo: Cosmology | None, cosmo_cls: type[Cosmology] | None
) -> R | CachedInDictPropertyDescriptor[R]:
# Accessed from the class, return the descriptor itself
if cosmo is None:
return self

# If the property is not in the instance's dictionary, calculate and store it.
if self.name not in cosmo.__dict__:
cosmo.__dict__[self.name] = self.fget(cosmo)

# Return the property value from the instance's dictionary
# This is only called once, thereafter the property value is accessed directly
# from the instance's dictionary.
return cosmo.__dict__[self.name]


def cached_on_dict_property(
fget: Callable[[Cosmology], R],
) -> CachedInDictPropertyDescriptor[R]:
"""Descriptor for a property that is cached in the instance's dictionary."""
return CachedInDictPropertyDescriptor(fget)
44 changes: 22 additions & 22 deletions astropy/cosmology/flrw/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import warnings
from abc import abstractmethod
from dataclasses import field
from functools import cached_property
from inspect import signature
from math import exp, floor, log, pi, sqrt
from numbers import Number
Expand All @@ -21,7 +22,6 @@
import astropy.units as u
from astropy.cosmology._utils import (
aszarr,
cached_on_dict_property,
deprecated_keywords,
vectorize_redshift_method,
)
Expand Down Expand Up @@ -312,69 +312,69 @@ def m_nu(self, param, value):
# properties

@property
def is_flat(self):
def is_flat(self) -> bool:
"""Return bool; `True` if the cosmology is flat."""
return bool((self.Ok0 == 0.0) and (self.Otot0 == 1.0))

@property
def Otot0(self):
def Otot0(self) -> float:
"""Omega total; the total density/critical density at z=0."""
return self._Om0 + self.Ogamma0 + self.Onu0 + self._Ode0 + self.Ok0

@cached_on_dict_property
def Odm0(self):
@cached_property
def Odm0(self) -> float | None:
"""Omega dark matter; dark matter density/critical density at z=0."""
return None if self.Ob0 is None else (self.Om0 - self.Ob0)

@cached_on_dict_property
def Ok0(self):
@cached_property
def Ok0(self) -> float:
"""Omega curvature; the effective curvature density/critical density at z=0."""
return 1.0 - self.Om0 - self.Ode0 - self.Ogamma0 - self.Onu0

@cached_on_dict_property
def Tnu0(self):
@cached_property
def Tnu0(self) -> u.Quantity:
"""Temperature of the neutrino background as |Quantity| at z=0."""
# The constant in front is (4/11)^1/3 -- see any cosmology book for an
# explanation -- for example, Weinberg 'Cosmology' p 154 eq (3.1.21).
return 0.7137658555036082 * self.Tcmb0

@property
def has_massive_nu(self):
def has_massive_nu(self) -> bool:
"""Does this cosmology have at least one massive neutrino species?"""
if self.Tnu0.value == 0:
return False
return self._massivenu

@cached_on_dict_property
def h(self):
@cached_property
def h(self) -> float:
"""Dimensionless Hubble constant: h = H_0 / 100 [km/sec/Mpc]."""
return self.H0.value / 100.0

@cached_on_dict_property
def hubble_time(self):
@cached_property
def hubble_time(self) -> u.Quantity:
"""Hubble time as `~astropy.units.Quantity`."""
return (_sec_to_Gyr / (self.H0.value * _H0units_to_invs)) << u.Gyr

@cached_on_dict_property
def hubble_distance(self):
@cached_property
def hubble_distance(self) -> u.Quantity:
"""Hubble distance as `~astropy.units.Quantity`."""
return (const.c / self.H0).to(u.Mpc)

@cached_on_dict_property
def critical_density0(self):
@cached_property
def critical_density0(self) -> u.Quantity:
"""Critical density as `~astropy.units.Quantity` at z=0."""
return (
_critdens_const * (self.H0.value * _H0units_to_invs) ** 2
) << u.g / u.cm**3

@cached_on_dict_property
def Ogamma0(self):
@cached_property
def Ogamma0(self) -> float:
"""Omega gamma; the density/critical density of photons at z=0."""
# photon density from Tcmb
return _a_B_c2 * self.Tcmb0.value**4 / self.critical_density0.value

@cached_on_dict_property
def Onu0(self):
@cached_property
def Onu0(self) -> float:
"""Omega nu; the density/critical density of neutrinos at z=0."""
if self._massivenu: # (`_massivenu` set in `m_nu`)
return self.Ogamma0 * self.nu_relative_density(0)
Expand Down
38 changes: 19 additions & 19 deletions astropy/cosmology/flrw/tests/test_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

import abc
import copy
from functools import cached_property

import numpy as np
import pytest
Expand All @@ -18,7 +19,6 @@
Parameter,
Planck18,
)
from astropy.cosmology._utils import CachedInDictPropertyDescriptor
from astropy.cosmology.core import _COSMOLOGY_CLASSES, dataclass_decorator
from astropy.cosmology.flrw.base import _a_B_c2, _critdens_const, _H0units_to_invs, quad
from astropy.cosmology.parameter._core import MISSING
Expand Down Expand Up @@ -555,9 +555,9 @@ def test_init_Tcmb0_zeroing(self, cosmo_cls, ba):
# Properties

def test_Odm0(self, cosmo_cls, cosmo):
"""Test CachedInDictPropertyDescriptor ``Odm0``."""
"""Test ``cached_property`` ``Odm0``."""
# on the class
assert isinstance(cosmo_cls.Odm0, CachedInDictPropertyDescriptor)
assert isinstance(cosmo_cls.Odm0, cached_property)

# on the instance
# Odm0 can be None, if Ob0 is None. Otherwise DM = matter - baryons.
Expand All @@ -567,9 +567,9 @@ def test_Odm0(self, cosmo_cls, cosmo):
assert np.allclose(cosmo.Odm0, cosmo.Om0 - cosmo.Ob0)

def test_Ok0(self, cosmo_cls, cosmo):
"""Test CachedInDictPropertyDescriptor ``Ok0``."""
"""Test ``cached_property`` ``Ok0``."""
# on the class
assert isinstance(cosmo_cls.Ok0, CachedInDictPropertyDescriptor)
assert isinstance(cosmo_cls.Ok0, cached_property)

# on the instance
assert np.allclose(
Expand All @@ -587,9 +587,9 @@ def test_is_flat(self, cosmo_cls, cosmo):
assert cosmo.is_flat is bool((cosmo.Ok0 == 0.0) and (cosmo.Otot0 == 1.0))

def test_Tnu0(self, cosmo_cls, cosmo):
"""Test CachedInDictPropertyDescriptor ``Tnu0``."""
"""Test ``cached_property`` ``Tnu0``."""
# on the class
assert isinstance(cosmo_cls.Tnu0, CachedInDictPropertyDescriptor)
assert isinstance(cosmo_cls.Tnu0, cached_property)

# on the instance
assert cosmo.Tnu0.unit == u.K
Expand All @@ -608,33 +608,33 @@ def test_has_massive_nu(self, cosmo_cls, cosmo):
assert cosmo.has_massive_nu is cosmo._massivenu

def test_h(self, cosmo_cls, cosmo):
"""Test property ``h``."""
"""Test ``cached_property`` ``h``."""
# on the class
assert isinstance(cosmo_cls.h, CachedInDictPropertyDescriptor)
assert isinstance(cosmo_cls.h, cached_property)

# on the instance
assert np.allclose(cosmo.h, cosmo.H0.value / 100.0)

def test_hubble_time(self, cosmo_cls, cosmo):
"""Test property ``hubble_time``."""
"""Test ``cached_property`` ``hubble_time``."""
# on the class
assert isinstance(cosmo_cls.hubble_time, CachedInDictPropertyDescriptor)
assert isinstance(cosmo_cls.hubble_time, cached_property)

# on the instance
assert u.allclose(cosmo.hubble_time, (1 / cosmo.H0) << u.Gyr)

def test_hubble_distance(self, cosmo_cls, cosmo):
"""Test CachedInDictPropertyDescriptor ``hubble_distance``."""
"""Test ``cached_property`` ``hubble_distance``."""
# on the class
assert isinstance(cosmo_cls.hubble_distance, CachedInDictPropertyDescriptor)
assert isinstance(cosmo_cls.hubble_distance, cached_property)

# on the instance
assert cosmo.hubble_distance == (const.c / cosmo._H0).to(u.Mpc)

def test_critical_density0(self, cosmo_cls, cosmo):
"""Test CachedInDictPropertyDescriptor ``critical_density0``."""
"""Test ``cached_property`` ``critical_density0``."""
# on the class
assert isinstance(cosmo_cls.critical_density0, CachedInDictPropertyDescriptor)
assert isinstance(cosmo_cls.critical_density0, cached_property)

# on the instance
assert cosmo.critical_density0.unit == u.g / u.cm**3
Expand All @@ -643,9 +643,9 @@ def test_critical_density0(self, cosmo_cls, cosmo):
assert cosmo.critical_density0.value == cd0value

def test_Ogamma0(self, cosmo_cls, cosmo):
"""Test CachedInDictPropertyDescriptor ``Ogamma0``."""
"""Test ``cached_property`` ``Ogamma0``."""
# on the class
assert isinstance(cosmo_cls.Ogamma0, CachedInDictPropertyDescriptor)
assert isinstance(cosmo_cls.Ogamma0, cached_property)

# on the instance
# Ogamma cor \propto T^4/rhocrit
Expand All @@ -656,9 +656,9 @@ def test_Ogamma0(self, cosmo_cls, cosmo):
assert cosmo.Ogamma0 == 0

def test_Onu0(self, cosmo_cls, cosmo):
"""Test CachedInDictPropertyDescriptor ``Onu0``."""
"""Test ``cached_property`` ``Onu0``."""
# on the class
assert isinstance(cosmo_cls.Onu0, CachedInDictPropertyDescriptor)
assert isinstance(cosmo_cls.Onu0, cached_property)

# on the instance
# neutrino temperature <= photon temperature since the neutrinos
Expand Down

0 comments on commit 963440c

Please sign in to comment.