From f47a9615b8aa6ad571dd6d53ec3551e8411c01c5 Mon Sep 17 00:00:00 2001 From: adonath Date: Thu, 14 Feb 2019 15:00:38 +0100 Subject: [PATCH 1/5] Change SkyModel addition to return a SkyModels object --- gammapy/cube/models.py | 67 ++++++++++++------------------- gammapy/cube/tests/test_models.py | 11 +---- 2 files changed, 26 insertions(+), 52 deletions(-) diff --git a/gammapy/cube/models.py b/gammapy/cube/models.py index 48f8aa0e64..90190761ef 100644 --- a/gammapy/cube/models.py +++ b/gammapy/cube/models.py @@ -1,6 +1,5 @@ # Licensed under a 3-clause BSD style license - see LICENSE.rst import copy -import operator import numpy as np from astropy.utils.decorators import lazyproperty import astropy.units as u @@ -12,7 +11,6 @@ "SkyModelBase", "SkyModels", "SkyModel", - "CompoundSkyModel", "SkyDiffuseCube", "BackgroundModel", ] @@ -20,9 +18,15 @@ class SkyModelBase(Model): """Sky model base class""" - def __add__(self, skymodel): - return CompoundSkyModel(self, skymodel, operator.add) + skymodels = [] + if isinstance(self, SkyModels): + skymodels += self.skymodels + elif isinstance(self, SkyModel): + skymodels += [self, ] + else: + raise NotImplementedError + return SkyModels(skymodels) def __radd__(self, model): return self.__add__(model) @@ -31,7 +35,7 @@ def __call__(self, lon, lat, energy): return self.evaluate(lon, lat, energy) -class SkyModels(Model): +class SkyModels(SkyModelBase): """Collection of `~gammapy.cube.models.SkyModel` Parameters @@ -66,14 +70,6 @@ def parameters(self): """ return self._parameters - @parameters.setter - def parameters(self, parameters): - idx = 0 - for skymodel in self.skymodels: - n_par = len(skymodel.parameters.parameters) - skymodel.parameters.parameters = parameters.parameters[idx : idx + n_par] - idx += n_par - @classmethod def from_xml(cls, xml): """Read from XML string.""" @@ -108,16 +104,27 @@ def to_xml(self, filename): with filename.open("w") as output: output.write(xml) - def to_compound_model(self): - """Return `~gammapy.cube.models.CompoundSkyModel`""" - return np.sum([m for m in self.skymodels]) - def evaluate(self, lon, lat, energy): out = self.skymodels[0].evaluate(lon, lat, energy) for skymodel in self.skymodels[1:]: out += skymodel.evaluate(lon, lat, energy) return out + def __str__(self): + str_ = self.__class__.__name__ + "\n\n" + + for idx, skymodel in enumerate(self.skymodels): + str_ += "Component: {idx}\n\n\t".format(idx=idx) + table = skymodel.parameters.to_table() + str_ += "\n\t".join(table.pformat()) + str_ += "\n\n" + + if self.parameters.covariance is not None: + str_ += "\n\nCovariance: \n\n\t" + covariance = self.parameters.covariance_to_table() + str_ += "\n\t".join(covariance.pformat()) + return str_ + class SkyModel(SkyModelBase): """Sky model component. @@ -137,7 +144,6 @@ class SkyModel(SkyModelBase): name : str Model identifier """ - def __init__(self, spatial_model, spectral_model, name="SkyModel"): self.name = name self._spatial_model = spatial_model @@ -149,23 +155,11 @@ def __init__(self, spatial_model, spectral_model, name="SkyModel"): @property def spatial_model(self): """`~gammapy.image.models.SkySpatialModel`""" - # propagate sub-covariance - if self.parameters.covariance is not None: - idx = len(self._spatial_model.parameters.parameters) - self._spatial_model.parameters.covariance = self.parameters.covariance[ - :idx, :idx - ] return self._spatial_model @property def spectral_model(self): """`~gammapy.spectrum.models.SpectralModel`""" - # propagate sub-covariance - if self.parameters.covariance is not None: - idx = len(self._spatial_model.parameters.parameters) - self._spectral_model.parameters.covariance = self.parameters.covariance[ - idx:, idx: - ] return self._spectral_model @property @@ -173,13 +167,6 @@ def parameters(self): """Parameters (`~gammapy.utils.modeling.Parameters`)""" return self._parameters - @parameters.setter - def parameters(self, parameters): - self._parameters = parameters - idx = len(self.spatial_model.parameters.parameters) - self._spatial_model.parameters.parameters = parameters.parameters[:idx] - self._spectral_model.parameters.parameters = parameters.parameters[idx:] - def __repr__(self): fmt = "{}(spatial_model={!r}, spectral_model={!r})" return fmt.format( @@ -208,10 +195,7 @@ def evaluate(self, lon, lat, energy): """ val_spatial = self.spatial_model(lon, lat) # pylint:disable=not-callable val_spectral = self.spectral_model(energy) # pylint:disable=not-callable - val = val_spatial * val_spectral - # TODO: shall remove hard coded return units? If really needed users can - # always do this themselves. For fitting this also adds a performance penalty... - return val.to("cm-2 s-1 TeV-1 deg-2") + return val_spatial * val_spectral class CompoundSkyModel(SkyModelBase): @@ -225,7 +209,6 @@ class CompoundSkyModel(SkyModelBase): operator : callable Binary operator to combine the models """ - def __init__(self, model1, model2, operator): self.model1 = model1 self.model2 = model2 diff --git a/gammapy/cube/tests/test_models.py b/gammapy/cube/tests/test_models.py index 04b516c7f3..b71799c69d 100644 --- a/gammapy/cube/tests/test_models.py +++ b/gammapy/cube/tests/test_models.py @@ -88,7 +88,7 @@ def sky_models(sky_model): @pytest.fixture(scope="session") def compound_model(sky_model): - return sky_model + sky_model + return CompoundSkyModel(sky_model, sky_model, np.add) def test_background_model(background): @@ -104,15 +104,6 @@ def test_background_model(background): class TestSkyModels: - @staticmethod - def test_to_compound_model(sky_models): - model = sky_models.to_compound_model() - assert isinstance(model, CompoundSkyModel) - pars = model.parameters.parameters - assert len(pars) == 12 - assert pars[0].name == "lon_0" - assert pars[-1].name == "reference" - @staticmethod def test_parameters(sky_models): parnames = ["lon_0", "lat_0", "sigma", "index", "amplitude", "reference"] * 2 From 0950eb58c8ba37cce4323c7a11bed9eaa8cc3c0a Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Sun, 17 Feb 2019 12:14:42 +0100 Subject: [PATCH 2/5] Add unit tests --- gammapy/cube/tests/test_models.py | 23 +++++++++++++++++++++++ 1 file changed, 23 insertions(+) diff --git a/gammapy/cube/tests/test_models.py b/gammapy/cube/tests/test_models.py index b71799c69d..1ae49f1585 100644 --- a/gammapy/cube/tests/test_models.py +++ b/gammapy/cube/tests/test_models.py @@ -91,6 +91,29 @@ def compound_model(sky_model): return CompoundSkyModel(sky_model, sky_model, np.add) +def test_skymodel_addition(sky_model, sky_models, diffuse_model): + result = sky_model + sky_model.copy() + assert isinstance(result, SkyModels) + assert len(result.skymodels) == 2 + + result = sky_model + sky_models + assert isinstance(result, SkyModels) + assert len(result.skymodels) == 3 + + result = sky_models + sky_model + assert isinstance(result, SkyModels) + assert len(result.skymodels) == 3 + + result = sky_models + diffuse_model + assert isinstance(result, SkyModels) + assert len(result.skymodels) == 3 + + result = sky_models + sky_models + assert isinstance(result, SkyModels) + assert len(result.skymodels) == 4 + + + def test_background_model(background): bkg1 = BackgroundModel(background, norm=2.0).evaluate() assert_allclose(bkg1.data[0][0][0], background.data[0][0][0] * 2.0, rtol=1e-3) From 3f95fafdb597bad7546aaa994711c078481fae09 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Sun, 17 Feb 2019 12:15:17 +0100 Subject: [PATCH 3/5] Fix SkyModel.__add__ copy behaviour --- gammapy/cube/models.py | 33 +++++++++++++++++++++++++-------- 1 file changed, 25 insertions(+), 8 deletions(-) diff --git a/gammapy/cube/models.py b/gammapy/cube/models.py index 90190761ef..f6a7d8a881 100644 --- a/gammapy/cube/models.py +++ b/gammapy/cube/models.py @@ -19,11 +19,11 @@ class SkyModelBase(Model): """Sky model base class""" def __add__(self, skymodel): - skymodels = [] - if isinstance(self, SkyModels): - skymodels += self.skymodels - elif isinstance(self, SkyModel): - skymodels += [self, ] + skymodels = [self] + if isinstance(skymodel, SkyModels): + skymodels += skymodel.skymodels + elif isinstance(skymodel, (SkyModel, SkyDiffuseCube)): + skymodels += [skymodel] else: raise NotImplementedError return SkyModels(skymodels) @@ -114,9 +114,7 @@ def __str__(self): str_ = self.__class__.__name__ + "\n\n" for idx, skymodel in enumerate(self.skymodels): - str_ += "Component: {idx}\n\n\t".format(idx=idx) - table = skymodel.parameters.to_table() - str_ += "\n\t".join(table.pformat()) + str_ += "Component {idx}: {skymodel}\n\n\t".format(idx=idx, skymodel=skymodel) str_ += "\n\n" if self.parameters.covariance is not None: @@ -125,6 +123,25 @@ def __str__(self): str_ += "\n\t".join(covariance.pformat()) return str_ + def __iadd__(self, skymodel): + if isinstance(skymodel, SkyModels): + self.skymodels += skymodel.skymodels + elif isinstance(skymodel, (SkyModel, SkyDiffuseCube)): + self.skymodels += [skymodel] + else: + raise NotImplementedError + return self + + def __add__(self, skymodel): + skymodels = self.skymodels.copy() + if isinstance(skymodel, SkyModels): + skymodels += skymodel.skymodels + elif isinstance(skymodel, (SkyModel, SkyDiffuseCube)): + skymodels += [skymodel] + else: + raise NotImplementedError + return SkyModels(skymodels) + class SkyModel(SkyModelBase): """Sky model component. From 7fa78726aa073a0c6d62a62d34d2e206790bb730 Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Sun, 17 Feb 2019 12:15:40 +0100 Subject: [PATCH 4/5] Adapt analysis_3d notebook --- tutorials/analysis_3d.ipynb | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/tutorials/analysis_3d.ipynb b/tutorials/analysis_3d.ipynb index 3a0f676658..c2a34075b5 100644 --- a/tutorials/analysis_3d.ipynb +++ b/tutorials/analysis_3d.ipynb @@ -803,8 +803,7 @@ "outputs": [], "source": [ "# Checking normalization value (the closer to 1 the better)\n", - "print(\"Model 1: {}\\n\".format(model_combined.model1))\n", - "print(\"Model 2: {}\\n\".format(model_combined.model2))\n", + "print(model_combined)\n", "print(\"Background model: {}\\n\".format(background_model))" ] }, @@ -930,7 +929,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.6.0" + "version": "3.7.0" } }, "nbformat": 4, From 2aeb7518d7a9cbaaa9a3e4ea8e21fe89dc449f2a Mon Sep 17 00:00:00 2001 From: Axel Donath Date: Sun, 17 Feb 2019 12:20:47 +0100 Subject: [PATCH 5/5] Add skymodel string test --- gammapy/cube/tests/test_models.py | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/gammapy/cube/tests/test_models.py b/gammapy/cube/tests/test_models.py index 1ae49f1585..e185883583 100644 --- a/gammapy/cube/tests/test_models.py +++ b/gammapy/cube/tests/test_models.py @@ -149,6 +149,11 @@ def test_evaluate(sky_models): assert q.shape == (5, 3, 4) assert_allclose(q.value, 3.536776513153229e-13) + @staticmethod + def test_str(sky_models): + assert "Component 0" in str(sky_models) + assert "Component 1" in str(sky_models) + class TestSkyModel: @staticmethod