Skip to content

Commit

Permalink
Merge pull request #2036 from adonath/skymodel_improvements
Browse files Browse the repository at this point in the history
Change SkyModel addition to return a SkyModels object
  • Loading branch information
adonath committed Feb 18, 2019
2 parents 8d2498c + 2aeb751 commit 2c2265b
Show file tree
Hide file tree
Showing 3 changed files with 73 additions and 55 deletions.
84 changes: 42 additions & 42 deletions gammapy/cube/models.py
@@ -1,6 +1,5 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
import copy
import operator
import numpy as np
from astropy.utils.decorators import lazyproperty
import astropy.units as u
Expand All @@ -12,17 +11,22 @@
"SkyModelBase",
"SkyModels",
"SkyModel",
"CompoundSkyModel",
"SkyDiffuseCube",
"BackgroundModel",
]


class SkyModelBase(Model):
"""Sky model base class"""

def __add__(self, skymodel):
return CompoundSkyModel(self, skymodel, operator.add)
skymodels = [self]
if isinstance(skymodel, SkyModels):
skymodels += skymodel.skymodels
elif isinstance(skymodel, (SkyModel, SkyDiffuseCube)):
skymodels += [skymodel]
else:
raise NotImplementedError
return SkyModels(skymodels)

def __radd__(self, model):
return self.__add__(model)
Expand All @@ -31,7 +35,7 @@ def __call__(self, lon, lat, energy):
return self.evaluate(lon, lat, energy)


class SkyModels(Model):
class SkyModels(SkyModelBase):
"""Collection of `~gammapy.cube.models.SkyModel`
Parameters
Expand Down Expand Up @@ -66,14 +70,6 @@ def parameters(self):
"""
return self._parameters

@parameters.setter
def parameters(self, parameters):
idx = 0
for skymodel in self.skymodels:
n_par = len(skymodel.parameters.parameters)
skymodel.parameters.parameters = parameters.parameters[idx : idx + n_par]
idx += n_par

@classmethod
def from_xml(cls, xml):
"""Read from XML string."""
Expand Down Expand Up @@ -108,16 +104,44 @@ def to_xml(self, filename):
with filename.open("w") as output:
output.write(xml)

def to_compound_model(self):
"""Return `~gammapy.cube.models.CompoundSkyModel`"""
return np.sum([m for m in self.skymodels])

def evaluate(self, lon, lat, energy):
out = self.skymodels[0].evaluate(lon, lat, energy)
for skymodel in self.skymodels[1:]:
out += skymodel.evaluate(lon, lat, energy)
return out

def __str__(self):
str_ = self.__class__.__name__ + "\n\n"

for idx, skymodel in enumerate(self.skymodels):
str_ += "Component {idx}: {skymodel}\n\n\t".format(idx=idx, skymodel=skymodel)
str_ += "\n\n"

if self.parameters.covariance is not None:
str_ += "\n\nCovariance: \n\n\t"
covariance = self.parameters.covariance_to_table()
str_ += "\n\t".join(covariance.pformat())
return str_

def __iadd__(self, skymodel):
if isinstance(skymodel, SkyModels):
self.skymodels += skymodel.skymodels
elif isinstance(skymodel, (SkyModel, SkyDiffuseCube)):
self.skymodels += [skymodel]
else:
raise NotImplementedError
return self

def __add__(self, skymodel):
skymodels = self.skymodels.copy()
if isinstance(skymodel, SkyModels):
skymodels += skymodel.skymodels
elif isinstance(skymodel, (SkyModel, SkyDiffuseCube)):
skymodels += [skymodel]
else:
raise NotImplementedError
return SkyModels(skymodels)


class SkyModel(SkyModelBase):
"""Sky model component.
Expand All @@ -137,7 +161,6 @@ class SkyModel(SkyModelBase):
name : str
Model identifier
"""

def __init__(self, spatial_model, spectral_model, name="SkyModel"):
self.name = name
self._spatial_model = spatial_model
Expand All @@ -149,37 +172,18 @@ def __init__(self, spatial_model, spectral_model, name="SkyModel"):
@property
def spatial_model(self):
"""`~gammapy.image.models.SkySpatialModel`"""
# propagate sub-covariance
if self.parameters.covariance is not None:
idx = len(self._spatial_model.parameters.parameters)
self._spatial_model.parameters.covariance = self.parameters.covariance[
:idx, :idx
]
return self._spatial_model

@property
def spectral_model(self):
"""`~gammapy.spectrum.models.SpectralModel`"""
# propagate sub-covariance
if self.parameters.covariance is not None:
idx = len(self._spatial_model.parameters.parameters)
self._spectral_model.parameters.covariance = self.parameters.covariance[
idx:, idx:
]
return self._spectral_model

@property
def parameters(self):
"""Parameters (`~gammapy.utils.modeling.Parameters`)"""
return self._parameters

@parameters.setter
def parameters(self, parameters):
self._parameters = parameters
idx = len(self.spatial_model.parameters.parameters)
self._spatial_model.parameters.parameters = parameters.parameters[:idx]
self._spectral_model.parameters.parameters = parameters.parameters[idx:]

def __repr__(self):
fmt = "{}(spatial_model={!r}, spectral_model={!r})"
return fmt.format(
Expand Down Expand Up @@ -208,10 +212,7 @@ def evaluate(self, lon, lat, energy):
"""
val_spatial = self.spatial_model(lon, lat) # pylint:disable=not-callable
val_spectral = self.spectral_model(energy) # pylint:disable=not-callable
val = val_spatial * val_spectral
# TODO: shall remove hard coded return units? If really needed users can
# always do this themselves. For fitting this also adds a performance penalty...
return val.to("cm-2 s-1 TeV-1 deg-2")
return val_spatial * val_spectral


class CompoundSkyModel(SkyModelBase):
Expand All @@ -225,7 +226,6 @@ class CompoundSkyModel(SkyModelBase):
operator : callable
Binary operator to combine the models
"""

def __init__(self, model1, model2, operator):
self.model1 = model1
self.model2 = model2
Expand Down
39 changes: 29 additions & 10 deletions gammapy/cube/tests/test_models.py
Expand Up @@ -88,7 +88,30 @@ def sky_models(sky_model):

@pytest.fixture(scope="session")
def compound_model(sky_model):
return sky_model + sky_model
return CompoundSkyModel(sky_model, sky_model, np.add)


def test_skymodel_addition(sky_model, sky_models, diffuse_model):
result = sky_model + sky_model.copy()
assert isinstance(result, SkyModels)
assert len(result.skymodels) == 2

result = sky_model + sky_models
assert isinstance(result, SkyModels)
assert len(result.skymodels) == 3

result = sky_models + sky_model
assert isinstance(result, SkyModels)
assert len(result.skymodels) == 3

result = sky_models + diffuse_model
assert isinstance(result, SkyModels)
assert len(result.skymodels) == 3

result = sky_models + sky_models
assert isinstance(result, SkyModels)
assert len(result.skymodels) == 4



def test_background_model(background):
Expand All @@ -104,15 +127,6 @@ def test_background_model(background):


class TestSkyModels:
@staticmethod
def test_to_compound_model(sky_models):
model = sky_models.to_compound_model()
assert isinstance(model, CompoundSkyModel)
pars = model.parameters.parameters
assert len(pars) == 12
assert pars[0].name == "lon_0"
assert pars[-1].name == "reference"

@staticmethod
def test_parameters(sky_models):
parnames = ["lon_0", "lat_0", "sigma", "index", "amplitude", "reference"] * 2
Expand All @@ -135,6 +149,11 @@ def test_evaluate(sky_models):
assert q.shape == (5, 3, 4)
assert_allclose(q.value, 3.53758465e-13)

@staticmethod
def test_str(sky_models):
assert "Component 0" in str(sky_models)
assert "Component 1" in str(sky_models)


class TestSkyModel:
@staticmethod
Expand Down
5 changes: 2 additions & 3 deletions tutorials/analysis_3d.ipynb
Expand Up @@ -803,8 +803,7 @@
"outputs": [],
"source": [
"# Checking normalization value (the closer to 1 the better)\n",
"print(\"Model 1: {}\\n\".format(model_combined.model1))\n",
"print(\"Model 2: {}\\n\".format(model_combined.model2))\n",
"print(model_combined)\n",
"print(\"Background model: {}\\n\".format(background_model))"
]
},
Expand Down Expand Up @@ -930,7 +929,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.6.0"
"version": "3.7.0"
}
},
"nbformat": 4,
Expand Down

0 comments on commit 2c2265b

Please sign in to comment.