Skip to content

Commit

Permalink
normalised model properties
Browse files Browse the repository at this point in the history
  • Loading branch information
steven-murray committed Oct 31, 2017
1 parent 2546888 commit 72a35cb
Show file tree
Hide file tree
Showing 3 changed files with 30 additions and 22 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.rst
Original file line number Diff line number Diff line change
@@ -1,6 +1,13 @@
Releases
========

v3.0.1 [31st Oct 2017]
----------------------
**Enhancement**

- Normalised all <>_model properties to be actual classes, rather than either class or string.
- Added consistent checking of dictionaries for <>_params parameters.

v3.0.0 [7th June 2017]
----------------------
**Features**
Expand Down
9 changes: 7 additions & 2 deletions hmf/_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,9 +47,9 @@ def cached_quantity(f):
value on all subsequent calls. If `a_param` is modified, the
calculation of either `a_quantity` and `a_child_quantity` will be re-performed when requested.
"""

name = f.__name__


def _get_property(self):
# Location of the property to be accessed
prop = hidden_loc(self, name)
Expand Down Expand Up @@ -127,7 +127,6 @@ def _del_property(self):

return property(_get_property, None, _del_property)


def obj_eq(ob1, ob2):
try:
if ob1 == ob2:
Expand Down Expand Up @@ -182,6 +181,12 @@ def param(f):
name = f.__name__

def _set_property(self, val):
# Here put any custom code that should be run, dependent on the type of parameter
if name.endswith("_params"):
if not isinstance(val, dict):
raise ValueError("%s must be a dictionary"%name)


prop = hidden_loc(self, name)

# The following does any complex setting that is written into the code
Expand Down
36 changes: 16 additions & 20 deletions hmf/hmf.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from numpy import issubclass_
logger = logging.getLogger('hmf')
from .filters import TopHat, Filter
from ._framework import get_model
from ._framework import get_model, get_model_
from scipy.optimize import minimize
from scipy.interpolate import InterpolatedUnivariateSpline as spline
import warnings
Expand Down Expand Up @@ -121,11 +121,14 @@ def filter_model(self, val):
"""
A model for the window/filter function.
:type: str or :class:`hmf.filters.Filter` subclass
:type: :class:`hmf.filters.Filter` subclass
"""
if not issubclass_(val, Filter) and not isinstance(val, str):
raise ValueError("filter must be a Filter or string, got %s" % type(val))
return val
elif isinstance(val, str):
return get_model_(val, "hmf.filters")
else:
return val

@parameter("param")
def filter_params(self, val):
Expand Down Expand Up @@ -164,7 +167,10 @@ def hmf_model(self, val):
"""
if not issubclass_(val, ff.FittingFunction) and not isinstance(val, str):
raise ValueError("hmf_model must be a ff.FittingFunction or string, got %s" % type(val))
return val
elif isinstance(val, str):
return get_model_(val, "hmf.fitting_functions")
else:
return val

@parameter("param")
def hmf_params(self, val):
Expand All @@ -173,8 +179,6 @@ def hmf_params(self, val):
:type: dict
"""
if not isinstance(val, dict):
raise ValueError("hmf_params must be a dictionary")
return val

@parameter("param")
Expand Down Expand Up @@ -258,33 +262,24 @@ def mean_density(self):
"""
return self.mean_density0 * (1 + self.z) ** 3


@cached_quantity
def hmf(self):
"""
Instantiated model for the hmf fitting function.
"""
if issubclass_(self.hmf_model, ff.FittingFunction):
return self.hmf_model(m=self.m, nu2=self.nu, z=self.z,
return self.hmf_model(m=self.m, nu2=self.nu, z=self.z,
delta_halo=self.delta_halo, omegam_z=self.cosmo.Om(self.z),
delta_c=self.delta_c, n_eff=self.n_eff,
** self.hmf_params)
elif isinstance(self.hmf_model, str):
return get_model(self.hmf_model, "hmf.fitting_functions",
m=self.m, nu2=self.nu, z=self.z,
delta_halo=self.delta_halo, omegam_z=self.cosmo.Om(self.z),
delta_c=self.delta_c, n_eff=self.n_eff,
** self.hmf_params)


@cached_quantity
def filter(self):
"""
Instantiated model for filter/window functions.
"""
if issubclass_(self.filter_model, Filter):
return self.filter_model(self.k,self._unnormalised_power, **self.filter_params)
elif isinstance(self.filter_model, str):
return get_model(self.filter_model, "hmf.filters", k=self.k,
power=self._unnormalised_power, **self.filter_params)
return self.filter_model(self.k,self._unnormalised_power, **self.filter_params)

@cached_quantity
def m(self):
Expand All @@ -294,7 +289,7 @@ def m(self):
@cached_quantity
def M(self):
"Masses (alias of m, deprecated)"
return self.m
raise AttributeError("Use of M has been deprecated for a while, and is now removed. Use m.")

@cached_quantity
def delta_halo(self):
Expand All @@ -312,6 +307,7 @@ def _unn_sigma0(self):
"""
return self.filter.sigma(self.radii)


@cached_quantity
def _sigma_0(self):
"""
Expand Down

0 comments on commit 72a35cb

Please sign in to comment.