Skip to content

Commit

Permalink
Merge pull request #2616 from fabiopintore/MapDatasetEventSampler_bkg…
Browse files Browse the repository at this point in the history
…sample

Add MapDatasetEventSampler with .sample_background() method
  • Loading branch information
adonath committed Nov 28, 2019
2 parents 7b6318d + 04fecb2 commit 90e00d0
Show file tree
Hide file tree
Showing 2 changed files with 103 additions and 2 deletions.
58 changes: 57 additions & 1 deletion gammapy/cube/simulate.py
@@ -1,17 +1,22 @@
# Licensed under a 3-clause BSD style license - see LICENSE.rst
"""Simulate observations"""
import numpy as np
from astropy.table import Table
import astropy.units as u
from gammapy.cube import (
MapDataset,
PSFKernel,
make_map_background_irf,
make_map_exposure_true_energy,
)
from gammapy.data import EventList
from gammapy.maps import WcsNDMap
from gammapy.modeling.models import BackgroundModel
from gammapy.modeling.models import ConstantTemporalModel

from gammapy.utils.random import get_random_state

__all__ = ["simulate_dataset"]
__all__ = ["simulate_dataset", "MapDatasetEventSampler"]


def simulate_dataset(
Expand Down Expand Up @@ -89,3 +94,54 @@ def simulate_dataset(
dataset.counts = WcsNDMap(geom, counts)

return dataset


class MapDatasetEventSampler:
"""Sample events from a map dataset
Parameters
----------
random_state : {int, 'random-seed', 'global-rng', `~numpy.random.RandomState`}
Defines random number generator initialisation.
Passed to `~gammapy.utils.random.get_random_state`.
"""

def __init__(self, random_state="random-seed"):
self.random_state = get_random_state(random_state)

def sample_background(self, dataset):
"""Sample background
Parameters
----------
dataset : `MapDataset`
Map dataset.
Returns
-------
events : `EventList`
Background events
"""
table = Table()

background = dataset.background_model.evaluate()
n_events = self.random_state.poisson(np.sum(background.data))

# sample position
coords = background.sample_coord(n_events, self.random_state)
table["ENERGY"] = coords["energy"]
table["RA"] = coords.skycoord.icrs.ra.deg
table["DEC"] = coords.skycoord.icrs.dec.deg
table["MC_ID"] = 0

# sample time
time_start, time_stop, time_ref = (
dataset.gti.time_start,
dataset.gti.time_stop,
dataset.gti.time_ref,
)
model = ConstantTemporalModel()
time = model.sample_time(n_events, time_start, time_stop, self.random_state)
table["TIME"] = u.Quantity(((time.mjd - time_ref.mjd) * u.day).to(u.s)).value

return EventList(table)
47 changes: 46 additions & 1 deletion gammapy/cube/tests/test_simulate.py
Expand Up @@ -3,7 +3,9 @@
from numpy.testing import assert_allclose
import astropy.units as u
from astropy.coordinates import SkyCoord
from gammapy.cube import MapDataset, simulate_dataset
from gammapy.cube import MapDataset, simulate_dataset, MapDatasetEventSampler
from gammapy.cube.tests.test_fit import get_map_dataset
from gammapy.data import GTI
from gammapy.irf import load_cta_irfs
from gammapy.maps import MapAxis, WcsGeom
from gammapy.modeling.models import (
Expand Down Expand Up @@ -58,3 +60,46 @@ def test_simulate():
)
assert_allclose(dataset.psf.data[5, 32, 32], 0.04203219)
assert_allclose(dataset.edisp.data.data[10, 10], 0.85944298, rtol=1e-5)


def dataset_maker():
position = SkyCoord(0.0, 0.0, frame="galactic", unit="deg")
energy_axis = MapAxis.from_bounds(
1, 100, nbin=30, unit="TeV", name="energy", interp="log"
)

spatial_model = GaussianSpatialModel(
lon_0="0 deg", lat_0="0 deg", sigma="0.2 deg", frame="galactic"
)

spectral_model = PowerLawSpectralModel(amplitude="1e-11 cm-2 s-1 TeV-1")
skymodel = SkyModel(spatial_model=spatial_model, spectral_model=spectral_model)

geom = WcsGeom.create(
skydir=position, binsz=0.02, width="5 deg", coordsys="GAL", axes=[energy_axis]
)

t_min = 0 * u.s
t_max = 30000 * u.s

gti = GTI.create(start=t_min, stop=t_max)

dataset = get_map_dataset(
sky_model=skymodel, geom=geom, geom_etrue=geom, edisp=True
)
dataset.gti = gti

return dataset


@requires_data()
def test_MDE_sample_background():
dataset = dataset_maker()
sampler = MapDatasetEventSampler(random_state=0)
bkg_evt = sampler.sample_background(dataset=dataset)

assert len(bkg_evt.table["ENERGY"]) == 375084
assert_allclose(bkg_evt.table["ENERGY"][0], 2.1613281656472028, rtol=1e-5)
assert_allclose(bkg_evt.table["RA"][0], 265.7253792887848, rtol=1e-5)
assert_allclose(bkg_evt.table["DEC"][0], -27.727581635186304, rtol=1e-5)
assert_allclose(bkg_evt.table["MC_ID"][0], 0, rtol=1e-5)

0 comments on commit 90e00d0

Please sign in to comment.