Skip to content

Commit

Permalink
Merge pull request #122 from steven-murray/release
Browse files Browse the repository at this point in the history
3.3.4
  • Loading branch information
steven-murray committed Jan 8, 2021
2 parents 0d5874b + 96798a7 commit ce6520b
Show file tree
Hide file tree
Showing 11 changed files with 101 additions and 25 deletions.
2 changes: 1 addition & 1 deletion .bumpversion.cfg
@@ -1,5 +1,5 @@
[bumpversion]
current_version = 3.3.3
current_version = 3.3.4
commit = False
tag = False

Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/tag-release.yaml
Expand Up @@ -14,4 +14,4 @@ jobs:
uses: tvdias/github-tagger@v0.0.2
with:
repo-token: "${{ secrets.BUMP_VERSION }}"
tag: "${{ version }}"
tag: "${{ env.version }}"
11 changes: 11 additions & 0 deletions CHANGELOG.rst
Expand Up @@ -4,6 +4,17 @@ Releases
dev-version
----------------------

v3.3.4 [08 Jan 2021]
----------------------

**Bugfixes**

- Added ``validate()`` method that is automatically called after ``__init__()`` and
``update()`` and also (optionally) on every ``parameter`` update. This can cross-validate different
inputs. **NOTE**: if you currently have code that sets parameters directly, eg ``mf.z=2.0``,
you should update to using the ``update()`` method, as this allows multiple parameters
to be set, and then a call to ``validate()``.

v3.3.3 [21 Dec 2020]
----------------------
**Bugfixes**
Expand Down
2 changes: 1 addition & 1 deletion VERSION
@@ -1 +1 @@
3.3.3
3.3.4
14 changes: 13 additions & 1 deletion src/hmf/_internals/_cache.py
Expand Up @@ -229,7 +229,7 @@ def _set_property(self, val):
and not isinstance(val, dict)
and val is not None
):
raise ValueError("%s must be a dictionary" % name)
raise ValueError(f"{name} must be a dictionary")

# Locations of indexes
recalc = hidden_loc(self, "recalc")
Expand Down Expand Up @@ -280,6 +280,18 @@ def _set_property(self, val):
for pr in getattr(self, recalc_papr)[name]:
delattr(self, pr)

if not doset and self._validate:
if self._validate_every_param_set:
self.validate()
else:
warnings.warn(
f"You are setting {name} directly. This is unstable, as less "
f"validation is performed. You can turn on extra validation "
f"for directly set parameters by setting framework._validate_every_param_set=True."
f"However, this can be brittle, since intermediate states may not be valid.",
category=DeprecationWarning,
)

update_wrapper(_set_property, f)

def _get_property(self):
Expand Down
45 changes: 33 additions & 12 deletions src/hmf/_internals/_framework.py
Expand Up @@ -182,7 +182,15 @@ def get_model(name, mod, **kwargs):
return get_model_(name, mod)(**kwargs)


class Framework:
class _Validator(type):
def __call__(cls, *args, **kwargs):
"""Called when you call MyNewClass() """
obj = type.__call__(cls, *args, **kwargs)
obj.validate()
return obj


class Framework(metaclass=_Validator):
"""
Class representing a coherent framework of component models.
Expand All @@ -198,21 +206,33 @@ class Framework:
defined as a ``parameter`` within the class so it may be set properly.
"""

def __init__(self):
super(Framework, self).__init__()
_validate = True
_validate_every_param_set = False

def validate(self):
pass

def update(self, **kwargs):
"""
Update parameters of the framework with kwargs.
"""
for k, v in list(kwargs.items()):
# If key is just a parameter to the class, just update it.
if hasattr(self, k):
setattr(self, k, kwargs.pop(k))

# If key is a dictionary of parameters to a sub-framework, update the sub-framework
elif k.endswith("_params") and isinstance(getattr(self, k[:-7]), Framework):
getattr(self, k[:-7]).update(**kwargs.pop(k))
self._validate = False
try:
for k, v in list(kwargs.items()):
# If key is just a parameter to the class, just update it.
if hasattr(self, k):
setattr(self, k, kwargs.pop(k))

# If key is a dictionary of parameters to a sub-framework, update the sub-framework
elif k.endswith("_params") and isinstance(
getattr(self, k[:-7]), Framework
):
getattr(self, k[:-7]).update(**kwargs.pop(k))
self._validate = True
self.validate()
except Exception:
self._validate = True
raise

if kwargs:
raise ValueError("Invalid arguments: %s" % kwargs)
Expand Down Expand Up @@ -258,10 +278,11 @@ def parameter_values(self):

@classmethod
def quantities_available(cls):
all_names = cls.get_all_parameter_names()
return [
name
for name in dir(cls)
if name not in cls.get_all_parameter_names()
if name not in all_names
and not name.startswith("__")
and name not in dir(Framework)
]
Expand Down
2 changes: 1 addition & 1 deletion src/hmf/cosmology/cosmo.py
Expand Up @@ -63,7 +63,7 @@ class Cosmology(_framework.Framework):

def __init__(self, cosmo_model=Planck15, cosmo_params=None):
# Call Framework init
super(Cosmology, self).__init__()
super().__init__()

# Set all given parameters
self.cosmo_model = cosmo_model
Expand Down
10 changes: 8 additions & 2 deletions src/hmf/density_field/transfer.py
Expand Up @@ -57,11 +57,11 @@ def __init__(
growth_model=None,
growth_params=None,
use_splined_growth=False,
**kwargs
**kwargs,
):

# Call Cosmology init
super(Transfer, self).__init__(**kwargs)
super().__init__(**kwargs)

# Set all given parameters
self.n = n
Expand Down Expand Up @@ -90,6 +90,12 @@ def __init__(
# ===========================================================================
# Parameters
# ===========================================================================
def validate(self):
super().validate()
assert (
self.lnk_min < self.lnk_max
), f"lnk_min >= lnk_max: {self.lnk_min}, {self.lnk_max}"
assert len(self.k) > 1, f"len(k) < 2: {len(self.k)}"

@parameter("model")
def growth_model(self, val):
Expand Down
10 changes: 9 additions & 1 deletion src/hmf/mass_function/hmf.py
Expand Up @@ -88,7 +88,7 @@ def __init__(
**transfer_kwargs,
):
# Call super init MUST BE DONE FIRST.
super(MassFunction, self).__init__(**transfer_kwargs)
super().__init__(**transfer_kwargs)

# Set all given parameters.
self.hmf_model = hmf_model
Expand All @@ -106,6 +106,14 @@ def __init__(
# ===========================================================================
# PARAMETERS
# ===========================================================================
def validate(self):
super().validate()
assert self.Mmin < self.Mmax, f"Mmin > Mmax: {self.Mmin}, {self.Mmax}"
assert len(self.m) > 0, "mass vector has length zero!"

# Check whether the hmf component validates.
self.hmf

@parameter("res")
def Mmin(self, val):
r"""
Expand Down
20 changes: 20 additions & 0 deletions tests/test_framework.py
Expand Up @@ -6,6 +6,7 @@
from hmf._internals._framework import get_model_
from deprecation import fail_if_not_removed
from hmf import GrowthFactor
from hmf import MassFunction


def test_incorrect_argument():
Expand Down Expand Up @@ -89,3 +90,22 @@ def test_get_model():

def test_growth_plugins():
assert "GenMFGrowth" in GrowthFactor._plugins


def test_validate_inputs():
with pytest.raises(AssertionError):
MassFunction(Mmin=10, Mmax=9)

m = MassFunction(Mmin=10, Mmax=11)
with pytest.raises(AssertionError):
m.update(Mmax=9)

# Without checking on, we can still manually set it, but it will warn us
with pytest.warns(DeprecationWarning):
m.Mmax = 9
m.Mmin = 8

# But with checking on, we can't
m._validate_every_param_set = True
with pytest.raises(AssertionError):
m.Mmax = 7
8 changes: 3 additions & 5 deletions tests/test_mdef.py
Expand Up @@ -96,12 +96,10 @@ def test_from_colossus_name(colossus_cosmo):


def test_change_dndm(colossus_cosmo):
h = MassFunction(
mdef_model="SOVirial", hmf_model="Warren", disable_mass_conversion=False
)

with pytest.warns(UserWarning):
h.mdef
h = MassFunction(
mdef_model="SOVirial", hmf_model="Warren", disable_mass_conversion=False
)

dndm = h.dndm

Expand Down

0 comments on commit ce6520b

Please sign in to comment.