Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add datasets_names attribute to cube models #2782

Merged
merged 5 commits into from Feb 20, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
52 changes: 31 additions & 21 deletions gammapy/cube/fit.py
Expand Up @@ -14,7 +14,7 @@
from gammapy.irf import EDispKernel, EffectiveAreaTable
from gammapy.maps import Map, MapAxis
from gammapy.modeling import Dataset, Parameters
from gammapy.modeling.models import BackgroundModel, Models
from gammapy.modeling.models import BackgroundModel, Models, SkyModel
from gammapy.spectrum import SpectrumDataset, SpectrumDatasetOnOff
from gammapy.stats import cash, cash_sum_cython, wstat
from gammapy.utils.random import get_random_state
Expand Down Expand Up @@ -191,11 +191,22 @@ def models(self):
return self._models

@models.setter
def models(self, value):
if value is not None:
self._models = Models(value)
else:
def models(self, models):
if models is None:
self._models = None
else:
if isinstance(models, (BackgroundModel, SkyModel)):
models = [models]
            elif isinstance(models, (Models, list)):
models = list(models)
else:
raise TypeError("Invalid models")
models_list = [
model
for model in models
if self.name in model.datasets_names or model.datasets_names == "all"
]
self._models = Models(models_list)

if self.models is not None:
for model in self.models:
Expand Down Expand Up @@ -300,12 +311,13 @@ def from_geoms(
empty_maps : `MapDataset`
A MapDataset containing zero filled maps
"""
name = make_name(name)
kwargs = kwargs.copy()
kwargs["name"] = name
kwargs["counts"] = Map.from_geom(geom, unit="")

background = Map.from_geom(geom, unit="")
kwargs["models"] = Models([BackgroundModel(background)])
kwargs["models"] = Models([BackgroundModel(background, datasets_names=[name])])
kwargs["exposure"] = Map.from_geom(geom_exposure, unit="m2 s")
kwargs["edisp"] = EDispMap.from_geom(geom_edisp)
kwargs["psf"] = PSFMap.from_geom(geom_psf)
Expand Down Expand Up @@ -660,6 +672,7 @@ def from_hdulist(cls, hdulist, name=None):
dataset : `MapDataset`
Map dataset.
"""
name = make_name(name)
kwargs = {"name": name}

if "COUNTS" in hdulist:
Expand All @@ -670,7 +683,9 @@ def from_hdulist(cls, hdulist, name=None):

if "BACKGROUND" in hdulist:
background_map = Map.from_hdulist(hdulist, hdu="background")
kwargs["models"] = Models([BackgroundModel(background_map)])
kwargs["models"] = Models(
[BackgroundModel(background_map, datasets_names=[name])]
)

if "EDISP_MATRIX" in hdulist:
kwargs["edisp"] = EDispKernel.from_hdulist(
Expand Down Expand Up @@ -740,14 +755,13 @@ def read(cls, filename, name=None):
def from_dict(cls, data, components, models):
"""Create from dicts and models list generated from YAML serialization."""
dataset = cls.read(data["filename"], name=data["name"])
bkg_name = data["background"]
model_names = data["models"]
models_list = [model for model in models if model.name in model_names]
models = Models(models_list)

for component in components["components"]:
if component["type"] == "BackgroundModel":
if component["name"] == bkg_name:
if (
dataset.name in component.get("datasets_names", [])
or "datasets_names" not in component
):
if "filename" not in component:
component["map"] = dataset.background_model.map
background_model = BackgroundModel.from_dict(component)
Expand All @@ -758,16 +772,9 @@ def from_dict(cls, data, components, models):

def to_dict(self, filename=""):
"""Convert to dict for YAML serialization."""
if self.models is None:
models = []
else:
models = [_.name for _ in self.models]

return {
"name": self.name,
"type": self.tag,
"models": models,
"background": self.background_model.name,
"filename": str(filename),
}

Expand Down Expand Up @@ -869,8 +876,9 @@ def to_image(self, spectrum=None, name=None):
dataset : `MapDataset`
Map dataset containing images.
"""
name = make_name(name)
kwargs = {}
kwargs["name"] = make_name(name)
kwargs["name"] = name
kwargs["gti"] = self.gti

if self.mask_safe is not None:
Expand All @@ -892,7 +900,9 @@ def to_image(self, spectrum=None, name=None):
if self.background_model is not None:
background = self.background_model.evaluate() * mask_safe
background = background.sum_over_axes(keepdims=True)
kwargs["models"] = Models([BackgroundModel(background)])
kwargs["models"] = Models(
[BackgroundModel(background, datasets_names=[name])]
)

if self.psf is not None:
kwargs["psf"] = self.psf.to_image(spectrum=spectrum, keepdims=True)
Expand Down
2 changes: 1 addition & 1 deletion gammapy/cube/make.py
Expand Up @@ -231,7 +231,7 @@ def run(self, dataset, observation):

if "background" in self.selection:
background_map = self.make_background(dataset.counts.geom, observation)
kwargs["models"] = BackgroundModel(background_map)
kwargs["models"] = BackgroundModel(background_map, datasets_names=[dataset.name])

if "psf" in self.selection:
psf = self.make_psf(dataset.psf.psf_map.geom, observation)
Expand Down
24 changes: 20 additions & 4 deletions gammapy/modeling/models/cube.py
Expand Up @@ -61,6 +61,7 @@ def __init__(
temporal_model=None,
name=None,
apply_irf=None,
datasets_names="all",
):
self.spatial_model = spatial_model
self.spectral_model = spectral_model
Expand All @@ -73,6 +74,7 @@ def __init__(
self.apply_irf = {"exposure": True, "psf": True, "edisp": True}
if apply_irf is not None:
self.apply_irf.update(apply_irf)
self.datasets_names = datasets_names

@property
def name(self):
Expand Down Expand Up @@ -205,6 +207,7 @@ def copy(self, name=None, **kwargs):
kwargs.setdefault("spatial_model", spatial_model)
kwargs.setdefault("temporal_model", temporal_model)
kwargs.setdefault("apply_irf", self.apply_irf)
kwargs.setdefault("datasets_names", self.datasets_names)

return self.__class__(**kwargs)

Expand All @@ -224,6 +227,9 @@ def to_dict(self):
if self.apply_irf != {"exposure": True, "psf": True, "edisp": True}:
data["apply_irf"] = self.apply_irf

if self.datasets_names != "all":
data["datasets_names"] = self.datasets_names

return data

@classmethod
Expand Down Expand Up @@ -260,6 +266,7 @@ def from_dict(cls, data):
apply_irf=data.get(
"apply_irf", {"exposure": True, "psf": True, "edisp": True}
),
datasets_names=data.get("datasets_names", "all"),
)

def __str__(self):
Expand Down Expand Up @@ -328,6 +335,7 @@ def __init__(
name=None,
filename=None,
apply_irf=None,
datasets_names="all",
):

self._name = make_name(name)
Expand All @@ -348,7 +356,7 @@ def __init__(
self.apply_irf = {"exposure": True, "psf": True, "edisp": True}
if apply_irf is not None:
self.apply_irf.update(apply_irf)

self.datasets_names = datasets_names
super().__init__(norm=norm, tilt=tilt, reference=reference)

@property
Expand Down Expand Up @@ -435,6 +443,8 @@ def from_dict(cls, data):
"apply_irf", {"exposure": True, "psf": True, "edisp": True}
)
model.apply_irf.update(apply_irf)
model.datasets_names = data.get("datasets_names", "all")

return model

def to_dict(self):
Expand All @@ -447,6 +457,8 @@ def to_dict(self):
data["parameters"] = data.pop("parameters")
if self.apply_irf != {"exposure": True, "psf": True, "edisp": True}:
data["apply_irf"] = self.apply_irf
if self.datasets_names != "all":
data["datasets_names"] = self.datasets_names

return data

Expand Down Expand Up @@ -492,6 +504,7 @@ def __init__(
reference=reference.quantity,
name=None,
filename=None,
datasets_names="all",
):
axis = map.geom.get_axis_by_name("energy")
if axis.node_type != "edges":
Expand All @@ -501,7 +514,7 @@ def __init__(

self._name = make_name(name)
self.filename = filename

self.datasets_names = datasets_names
super().__init__(norm=norm, tilt=tilt, reference=reference)

@property
Expand Down Expand Up @@ -537,6 +550,8 @@ def to_dict(self):
if self.filename is not None:
data["filename"] = self.filename
data["parameters"] = data.pop("parameters")
if self.datasets_names != "all":
data["datasets_names"] = self.datasets_names
return data

@classmethod
Expand All @@ -547,8 +562,9 @@ def from_dict(cls, data):
map = data["map"]
else:
raise ValueError("Requires either filename or `Map` object")

model = cls(map=map, name=data["name"])
model = cls(
map=map, name=data["name"], datasets_names=data.get("datasets_names", "all")
)
model._update_from_dict(data)
return model

Expand Down
12 changes: 12 additions & 0 deletions gammapy/modeling/models/tests/test_cube.py
Expand Up @@ -365,6 +365,18 @@ def test_processing(diffuse_model):
assert out["apply_irf"] == {"exposure": True, "psf": True, "edisp": False}
diffuse_model.apply_irf["edisp"] = True

@staticmethod
def test_datasets_name(diffuse_model):
assert diffuse_model.datasets_names == "all"

diffuse_model.datasets_names = ["1", "2"]
out = diffuse_model.to_dict()
assert out["datasets_names"] == ["1", "2"]

diffuse_model.datasets_names = "all"
out = diffuse_model.to_dict()
assert "datasets_names" not in out


class TestSkyDiffuseCubeMapEvaluator:
@staticmethod
Expand Down
6 changes: 5 additions & 1 deletion gammapy/modeling/tests/test_serialize_yaml.py
Expand Up @@ -115,9 +115,9 @@ def test_datasets_to_io(tmp_path):
datasets = Datasets.read(filedata, filemodel)

assert len(datasets) == 2
assert len(datasets.parameters) == 22

dataset0 = datasets[0]
assert dataset0.name == "gc"
assert dataset0.counts.data.sum() == 6824
assert_allclose(dataset0.exposure.data.sum(), 2072125400000.0, atol=0.1)
assert dataset0.psf is not None
Expand All @@ -128,6 +128,7 @@ def test_datasets_to_io(tmp_path):
assert dataset0.background_model.name == "background_irf_gc"

dataset1 = datasets[1]
assert dataset1.name == "g09"
assert dataset1.background_model.name == "background_irf_g09"

assert (
Expand All @@ -150,6 +151,9 @@ def test_datasets_to_io(tmp_path):
datasets_read = Datasets.read(
tmp_path / "written_datasets.yaml", tmp_path / "written_models.yaml"
)

assert len(datasets.parameters) == 22

assert len(datasets_read) == 2
dataset0 = datasets_read[0]
assert dataset0.counts.data.sum() == 6824
Expand Down