Skip to content

Commit

Permalink
update to pydantic 2.0
Browse files Browse the repository at this point in the history
Signed-off-by: Atreyee Sinha <asinha@ucm.es>
  • Loading branch information
AtreyeeS committed Jan 23, 2024
1 parent ba07ad1 commit 9ac0487
Show file tree
Hide file tree
Showing 3 changed files with 131 additions and 134 deletions.
5 changes: 4 additions & 1 deletion gammapy/datasets/map.py
Original file line number Diff line number Diff line change
Expand Up @@ -616,7 +616,7 @@ def energy_range_total(self):
def meta(self):
"""Return metadata container."""
if self._meta is None:
self._meta = MapDatasetMetaData.from_default()
self._meta = MapDatasetMetaData()
return self._meta

def npred(self):
Expand Down Expand Up @@ -1033,6 +1033,9 @@ def stack(self, other, nan_to_num=True):
elif other.meta_table:
self.meta_table = other.meta_table.copy()

if self.meta and other.meta:
self.meta.stack(other.meta)

def stat_array(self):
"""Statistic function value per bin given the current model parameters."""
return cash(n_on=self.counts.data, mu_on=self.npred().data)
Expand Down
187 changes: 78 additions & 109 deletions gammapy/datasets/metadata.py
Original file line number Diff line number Diff line change
@@ -1,111 +1,65 @@
import logging
from typing import Optional, Union
import numpy as np
from astropy.coordinates import SkyCoord
from pydantic import ValidationError, validator
from gammapy.utils.metadata import CreatorMetaData, MetaData
from typing import ClassVar, Literal, Optional, Union
from pydantic import ConfigDict
from gammapy.utils.metadata import (
METADATA_FITS_KEYS,
CreatorMetaData,
MetaData,
PointingInfoMetaData,
)

__all__ = ["MapDatasetMetaData"]

MapDataset_METADATA_FITS_KEYS = {
"MapDataset": {
"creation": "CREATION",
"instrument": "INSTRUM",
"telescope": "TELESCOP",
"observation_mode": "OBS_MODE",
"pointing": "POINTING",
"obs_ids": "OBS_IDS",
"event_types": "EVT_TYPE",
"optional": "OPTIONAL",
},
}

METADATA_FITS_KEYS.update(MapDataset_METADATA_FITS_KEYS)


class MapDatasetMetaData(MetaData):
"""Metadata containing information about the GTI.
Parameters
----------
creation : `~gammapy.utils.CreatorMetaData`
the creation metadata
creation : `~gammapy.utils.CreatorMetaData`, optional
The creation metadata.
instrument : str
the instrument used during observation
the instrument used during observation.
telescope : str
The specific telescope subarray
The specific telescope subarray.
observation_mode : str
observing mode
observing mode.
pointing : ~astropy.coordinates.SkyCoord
Telescope pointing direction
Telescope pointing direction.
obs_ids : int
Observation ids stacked in the dataset
Observation ids stacked in the dataset.
event_types : int
Event types used in analysis
Event types used in analysis.
optional : dict
Any other meta information
Additional optional metadata.
"""

creation: Optional[CreatorMetaData]
instrument: Optional[str]
telescope: Optional[Union[str, list[str]]]
observation_mode: Optional[Union[str, list]]
pointing: Optional[Union[SkyCoord, list[SkyCoord]]]
obs_ids: Optional[Union[int, list[int]]]
event_type: Optional[Union[int, list[int]]]
optional: Optional[dict]

@validator("creation")
def validate_creation(cls, v):
if v is None:
return CreatorMetaData.from_default()
elif isinstance(v, CreatorMetaData):
return v
else:
raise ValidationError(
f"Incorrect pointing. Expect CreatorMetaData got {type(v)} instead."
)

@validator("instrument")
def validate_instrument(cls, v):
if isinstance(v, str):
return v
elif v is None:
return v
else:
raise ValidationError(
f"Incorrect instrument. Expect str got {type(v)} instead."
)

@validator("telescope")
def validate_telescope(cls, v):
if isinstance(v, str):
return v
elif v is None:
return v
elif all(isinstance(_, str) for _ in v):
return v
else:
raise ValidationError(
f"Incorrect telescope type. Expect str got {type(v)} instead."
)

@validator("pointing")
def validate_pointing(cls, v):
if v is None:
return SkyCoord(np.nan, np.nan, unit="deg", frame="icrs")
elif isinstance(v, SkyCoord):
return v
elif all(isinstance(_, SkyCoord) for _ in v):
return v
else:
raise ValidationError(
f"Incorrect pointing. Expect SkyCoord got {type(v)} instead."
)

@validator("obs_ids", "event_type")
def validate_obs_ids(cls, v):
if v is None:
return -999
elif isinstance(v, int):
return v
elif all(isinstance(_, int) for _ in v):
return v
else:
raise ValidationError(
f"Incorrect pointing. Expect int got {type(v)} instead."
)
model_config = ConfigDict(coerce_numbers_to_str=True)

@classmethod
def from_default(cls):
"""Creation metadata containing Gammapy version."""
creation = CreatorMetaData.from_default()
return cls(creation=creation)
_tag: ClassVar[Literal["MapDataset"]] = "MapDataset"
creation: Optional[CreatorMetaData] = CreatorMetaData()
instrument: Optional[str] = None
telescope: Optional[Union[str, list[str]]] = None
observation_mode: Optional[Union[str, list]] = None
pointing: Optional[Union[PointingInfoMetaData, list[PointingInfoMetaData]]] = None
obs_ids: Optional[Union[str, list[str]]] = None
event_type: Optional[Union[str, list[str]]] = None
optional: Optional[dict] = None

def stack(self, other):
kwargs = {}
Expand All @@ -115,35 +69,50 @@ def stack(self, other):
logging.warning(
f"Stacking data from different instruments {self.instrument} and {other.instrument}"
)
tel = self.telescope
if isinstance(tel, str):
tel = [tel]
if other.telescope not in tel:
tel.append(other.telescope)
if self.telescope is not None:
tel = self.telescope
if isinstance(tel, str):
tel = [tel]
if other.telescope not in tel:
tel.append(other.telescope)
else:
tel = other.telescope
kwargs["telescope"] = tel

observation_mode = self.observation_mode
if isinstance(observation_mode, str):
observation_mode = [observation_mode]
observation_mode.append(other.observation_mode)
if self.observation_mode is not None:
observation_mode = self.observation_mode
if isinstance(observation_mode, str):
observation_mode = [observation_mode]
observation_mode.append(other.observation_mode)
else:
observation_mode = other.observation_mode
kwargs["observation_mode"] = observation_mode

pointing = self.pointing
if isinstance(pointing, SkyCoord):
pointing = [pointing]
pointing.append(other.pointing)
if self.pointing is not None:
pointing = self.pointing
if isinstance(pointing, PointingInfoMetaData):
pointing = [pointing]
pointing.append(other.pointing)
else:
pointing = other.pointing
kwargs["pointing"] = pointing

obs_ids = self.obs_ids
if isinstance(obs_ids, int):
obs_ids = [obs_ids]
obs_ids.append(other.obs_ids)
if self.obs_ids is not None:
obs_ids = self.obs_ids
if isinstance(obs_ids, str):
obs_ids = [obs_ids]
obs_ids.append(other.obs_ids)
else:
obs_ids = other.obs_ids
kwargs["obs_ids"] = obs_ids

event_type = self.event_type
if not isinstance(event_type, list):
event_type = [event_type]
event_type.append(other.event_type)
if self.event_type is not None:
event_type = self.event_type
if isinstance(event_type, str):
event_type = [event_type]
event_type.append(other.event_type)
else:
event_type = other.event_type
kwargs["event_type"] = event_type

if self.optional:
Expand Down
73 changes: 49 additions & 24 deletions gammapy/datasets/tests/test_metadata.py
Original file line number Diff line number Diff line change
@@ -1,23 +1,25 @@
import pytest
import numpy as np
from numpy.testing import assert_allclose
from astropy.coordinates import SkyCoord
from pydantic import ValidationError
from gammapy.datasets import MapDatasetMetaData
from gammapy.utils.metadata import PointingInfoMetaData


def test_mapdataset_meta_from_default():
meta = MapDatasetMetaData.from_default()

def test_meta_default():
meta = MapDatasetMetaData()
assert meta.creation.creator.split()[0] == "Gammapy"
assert meta.instrument is None
assert meta.event_type is None


def test_mapdataset_metadata():
position = SkyCoord(83.6287, 22.0147, unit="deg", frame="icrs")
input = {
"telescope": "cta-north",
"instrument": "lst",
"observation_mode": "wobble",
"pointing": SkyCoord(83.6287, 22.0147, unit="deg", frame="icrs"),
"pointing": PointingInfoMetaData(radec_mean=position),
"obs_ids": 112,
"optional": dict(test=0.5, other=True),
}
Expand All @@ -26,24 +28,21 @@ def test_mapdataset_metadata():
assert meta.telescope == "cta-north"
assert meta.instrument == "lst"
assert meta.observation_mode == "wobble"
assert_allclose(meta.pointing.dec.value, 22.0147)
assert_allclose(meta.pointing.ra.deg, 83.6287)
assert meta.obs_ids == 112
assert_allclose(meta.pointing.radec_mean.dec.value, 22.0147)
assert_allclose(meta.pointing.radec_mean.ra.deg, 83.6287)
assert meta.obs_ids == "112"
assert meta.optional["other"] is True
assert meta.creation.creator.split()[0] == "Gammapy"
assert meta.event_type is None

with pytest.raises(ValidationError):
meta.pointing = 2.0

with pytest.raises(ValidationError):
meta.instrument = ["cta", "hess"]

meta.pointing = None
assert isinstance(meta.pointing, SkyCoord)
assert np.isnan(meta.pointing.ra.deg)

input_bad = input.copy()
input_bad["obs_ids"] = "bad"
input_bad["bad"] = position

with pytest.raises(ValueError):
MapDatasetMetaData(**input_bad)
Expand All @@ -55,28 +54,34 @@ def test_mapdataset_metadata_lists():
"instrument": "lst",
"observation_mode": "wobble",
"pointing": [
SkyCoord(83.6287, 22.0147, unit="deg", frame="icrs"),
SkyCoord(83.1287, 22.5147, unit="deg", frame="icrs"),
PointingInfoMetaData(
radec_mean=SkyCoord(83.6287, 22.0147, unit="deg", frame="icrs")
),
PointingInfoMetaData(
radec_mean=SkyCoord(83.1287, 22.5147, unit="deg", frame="icrs")
),
],
"obs_ids": [111, 222],
}
meta = MapDatasetMetaData(**input)
assert meta.telescope == "cta-north"
assert meta.instrument == "lst"
assert meta.observation_mode == "wobble"
assert_allclose(meta.pointing[0].dec.value, 22.0147)
assert_allclose(meta.pointing[1].ra.deg, 83.1287)
assert meta.obs_ids == [111, 222]
assert_allclose(meta.pointing[0].radec_mean.dec.value, 22.0147)
assert_allclose(meta.pointing[1].radec_mean.ra.deg, 83.1287)
assert meta.obs_ids == ["111", "222"]
assert meta.optional is None
assert meta.event_type == -999
assert meta.event_type is None


def test_mapdataset_metadata_stack():
input1 = {
"telescope": "a",
"instrument": "H.E.S.S.",
"observation_mode": "wobble",
"pointing": SkyCoord(83.6287, 22.5147, unit="deg", frame="icrs"),
"pointing": PointingInfoMetaData(
radec_mean=SkyCoord(83.6287, 22.5147, unit="deg", frame="icrs")
),
"obs_ids": 111,
"optional": dict(test=0.5, other=True),
}
Expand All @@ -85,7 +90,9 @@ def test_mapdataset_metadata_stack():
"telescope": "b",
"instrument": "H.E.S.S.",
"observation_mode": "wobble",
"pointing": SkyCoord(83.6287, 22.0147, unit="deg", frame="icrs"),
"pointing": PointingInfoMetaData(
radec_mean=SkyCoord(83.6287, 22.0147, unit="deg", frame="icrs")
),
"obs_ids": 112,
"optional": dict(test=0.1, other=False),
}
Expand All @@ -97,7 +104,25 @@ def test_mapdataset_metadata_stack():
assert meta.telescope == ["a", "b"]
assert meta.instrument == "H.E.S.S."
assert meta.observation_mode == ["wobble", "wobble"]
assert_allclose(meta.pointing[1].dec.deg, 22.0147)
assert meta.obs_ids == [111, 112]
assert_allclose(meta.pointing[1].radec_mean.dec.deg, 22.0147)
assert meta.obs_ids == ["111", "112"]
assert meta.optional["other"] == [True, False]
assert len(meta.event_type) == 2
assert meta.event_type is None
assert meta.creation.creator.split()[0] == "Gammapy"


def test_to_header():
input1 = {
"telescope": "a",
"instrument": "H.E.S.S.",
"observation_mode": "wobble",
"pointing": PointingInfoMetaData(
radec_mean=SkyCoord(83.6287, 22.5147, unit="deg", frame="icrs")
),
"obs_ids": 111,
"optional": dict(test=0.5, other=True),
}
meta1 = MapDatasetMetaData(**input1)
hdr = meta1.to_header()
assert hdr["INSTRUM"] == "H.E.S.S."
assert hdr["OBS_IDS"] == "111"

0 comments on commit 9ac0487

Please sign in to comment.