In [None]:
import logging
import numpy as np
import astropy.units as u
from astropy.coordinates import Angle, SkyCoord
from astropy.time import Time
from astropy.table import Table
from regions import CircleSkyRegion

# %matplotlib inline
import matplotlib.pyplot as plt

log = logging.getLogger(__name__)

from gammapy.data import DataStore, GTI
from gammapy.datasets import Datasets, SpectrumDataset
from gammapy.datasets.actors import DatasetsActor

from gammapy.estimators import LightCurveEstimator, Estimator
from gammapy.estimators.utils import get_rebinned_axis
from gammapy.makers import (
    ReflectedRegionsBackgroundMaker,
    SafeMaskMaker,
    SpectrumDatasetMaker,
)
from gammapy.maps import MapAxis, RegionGeom, TimeMapAxis
from gammapy.modeling import Fit
from gammapy.modeling.models import PowerLawSpectralModel, SkyModel, BrokenPowerLawSpectralModel, Models, SpectralModel

In [None]:
def time_resolved_spectroscopy(datasets, model, time_intervals, fit=None):
    """Fit ``model`` independently in each time interval.

    For every ``(t_min, t_max)`` pair, the datasets in that interval are
    selected with ``datasets.select_time``. Intervals without any dataset are
    logged and skipped. Each non-empty interval is fitted with its own copy of
    ``model``, named ``"Model_bin_<k>"`` where ``k`` counts the intervals
    fitted so far, so the bins do not share parameters.

    Parameters
    ----------
    datasets : `~gammapy.datasets.Datasets`
        Datasets to split in time and fit.
    model : `~gammapy.modeling.models.SkyModel`
        Model template. It is only copied here, never assigned to a dataset.
    time_intervals : iterable of (t_min, t_max)
        Start/stop pairs, e.g. ``astropy.time.Time`` objects of length 2.
    fit : `~gammapy.modeling.Fit`, optional
        Fit instance reused for every interval. When omitted a default
        ``Fit()`` is created (the previous behaviour).

    Returns
    -------
    valid_intervals : list of [t_min, t_max]
        The intervals that contained data and were fitted.
    fit_results : list
        The ``fit.run`` result for each entry of ``valid_intervals``.
    """
    if fit is None:
        fit = Fit()

    valid_intervals = []
    fit_results = []
    for t_min, t_max in time_intervals:
        datasets_to_fit = datasets.select_time(time_min=t_min, time_max=t_max)

        if len(datasets_to_fit) == 0:
            log.info(
                f"No Dataset for the time interval {t_min} to {t_max}. Skipping interval."
            )
            continue

        # Name bins by the number already fitted (not by loop position), so
        # skipped intervals leave no gaps in the model names.
        model_in_bin = model.copy(name=f"Model_bin_{len(fit_results)}")
        datasets_to_fit.models = model_in_bin
        fit_results.append(fit.run(datasets_to_fit))
        valid_intervals.append([t_min, t_max])

    return valid_intervals, fit_results

def create_table(time_intervals, fit_result):
    """Tabulate the best-fit free parameters of each time-resolved fit.

    Parameters
    ----------
    time_intervals : sequence of [t_min, t_max]
        Fitted intervals, e.g. the first output of
        ``time_resolved_spectroscopy``. Each bound must expose ``.mjd``.
    fit_result : sequence
        Fit results aligned one-to-one with ``time_intervals``. Each exposes
        ``.models.parameters``.

    Returns
    -------
    `~astropy.table.Table`
        Columns ``time_start`` / ``time_stop`` (MJD), followed by one
        ``<name>`` / ``<name>_err`` column pair per free parameter of the first
        result. If ``fit_result`` is empty, only the two time columns are
        present and the table has no rows.

    Raises
    ------
    ValueError
        If ``time_intervals`` and ``fit_result`` have different lengths.
    """
    if len(time_intervals) != len(fit_result):
        raise ValueError(
            f"time_intervals ({len(time_intervals)}) and fit_result "
            f"({len(fit_result)}) must have the same length."
        )

    col_names = ["time_start", "time_stop"]
    col_units = ["MJD", "MJD"]

    # Columns come from the first fit. Every bin fits a copy of the same
    # template model, so the free parameter names are expected to match.
    par_names = []
    if fit_result:
        for par in fit_result[0].models.parameters.free_parameters:
            unit = par.unit
            # Leave dimensionless parameters (e.g. ``index``) without a unit.
            # Compare by equality: an identity check against a freshly built
            # Unit is fragile.
            if unit == u.dimensionless_unscaled:
                unit = ""
            par_names.append(par.name)
            col_names += [par.name, par.name + "_err"]
            col_units += [unit, unit]

    table = Table(names=col_names, units=col_units)

    # Use the intervals passed in, not a notebook-global variable.
    for (t_min, t_max), result in zip(time_intervals, fit_result):
        parameters = result.models.parameters
        row = [t_min.mjd, t_max.mjd]
        for name in par_names:
            row += [parameters[name].value, parameters[name].error]
        table.add_row(row)

    return table


In [None]:
# Target coordinates (ICRS) and the H.E.S.S. DL3 DR1 data store. Observations
# are picked with a 2 deg "sky_circle" selection around the target.
target_position = SkyCoord(329.71693826 * u.deg, -30.2255890 * u.deg, frame="icrs")
data_store = DataStore.from_dir("$GAMMAPY_DATA/hess-dl3-dr1/")

selection = {
    "type": "sky_circle",
    "frame": "icrs",
    "lon": target_position.ra,
    "lat": target_position.dec,
    "radius": 2 * u.deg,
}
obs_ids = data_store.obs_table.select_observations(selection)["OBS_ID"]
observations = data_store.get_observations(obs_ids)
print(f"Number of selected observations : {len(observations)}")

In [None]:
# Regularly spaced time edges starting at t0, 15 min apart. Consecutive edge
# pairs form the analysis intervals.
t0 = Time("2006-07-29T20:30")
duration = 15 * u.min
# NOTE(review): this is the number of *edges*, so it produces
# n_time_bins - 1 = 24 intervals. Confirm this is intended.
n_time_bins = 25
times = t0 + np.arange(n_time_bins) * duration
time_intervals = [Time(list(edges)) for edges in zip(times[:-1], times[1:])]
print(time_intervals[-1].mjd)

In [None]:
# Restrict the selected observations to the analysis time intervals, then
# check that the filtering took effect.
short_observations = observations.select_time(time_intervals)
print(
    f"Number of observations after time filtering: {len(short_observations)}\n"
)
print(short_observations[1].gti)

In [None]:
# Target definition: a circular ON region of radius 0.11 deg on the target,
# plus reconstructed (0.4-20 TeV, 10 bins) and true (0.1-40 TeV, 20 bins)
# energy axes.
on_region_radius = Angle("0.11 deg")
on_region = CircleSkyRegion(center=target_position, radius=on_region_radius)

energy_axis = MapAxis.from_energy_bounds("0.4 TeV", "20 TeV", nbin=10)
energy_axis_true = MapAxis.from_energy_bounds(
    "0.1 TeV", "40 TeV", nbin=20, name="energy_true"
)

# Region geometry binned in reconstructed energy.
geom = RegionGeom.create(region=on_region, axes=[energy_axis])

In [None]:
# Data reduction makers:
# - dataset_maker: fills counts, exposure and energy dispersion, applying
#   containment correction.
# - bkg_maker: background from reflected OFF regions.
# - safe_mask_masker: keeps energies where the effective area is above 10%
#   of its maximum ("aeff-max").
bkg_maker = ReflectedRegionsBackgroundMaker()
safe_mask_masker = SafeMaskMaker(methods=["aeff-max"], aeff_percent=10)
dataset_maker = SpectrumDatasetMaker(
    selection=["counts", "exposure", "edisp"],
    containment_correction=True,
)

In [None]:
# One ON/OFF spectrum dataset per observation. Each one starts from a copy of
# an empty template, gets reflected-region background, then the safe mask.
dataset_empty = SpectrumDataset.create(geom=geom, energy_axis_true=energy_axis_true)
datasets = Datasets()

for obs in short_observations:
    dataset = dataset_maker.run(dataset_empty.copy(), obs)
    dataset_on_off = safe_mask_masker.run(bkg_maker.run(dataset, obs), obs)
    datasets.append(dataset_on_off)

In [None]:
# Power-law template used for every time bin. The index is explicitly left
# free for the fit.
spectral_model = PowerLawSpectralModel(
    amplitude=2e-11 * u.Unit("1 / (cm2 s TeV)"),
    index=3.4,
    reference=1 * u.TeV,
)
spectral_model.parameters["index"].frozen = False

sky_model = SkyModel(
    spectral_model=spectral_model,
    spatial_model=None,
    name="pks2155",
)

In [None]:
# Fit the power-law template separately in each non-empty time interval.
times, results = time_resolved_spectroscopy(
    datasets=datasets, model=sky_model, time_intervals=time_intervals
)

In [None]:
# Collect the per-interval best-fit parameters into a table.
table = create_table(time_intervals=times, fit_result=results)

In [None]:
# Columns needed for the amplitude-vs-index plot below.
amp, indexes, times = table["amplitude"], table["index"], table["time_start"]

In [None]:
# Best-fit amplitude vs. index per time bin, coloured by bin start time. The
# thin line joins the bins in table order (i.e. interval order) to show the
# spectral evolution.
fig, ax = plt.subplots()
points = ax.scatter(amp, indexes, c=times)
ax.plot(amp, indexes, linewidth=0.5)
ax.set_xlabel(f"Amplitude [{amp.unit}]")
ax.set_ylabel("Index")
ax.set_title("Time-resolved power-law fit (pks2155)")
fig.colorbar(points, ax=ax, label="Time start [MJD]")
plt.show()

In [None]:
# Broken-power-law alternative for the same time-resolved analysis. The
# SkyModel must wrap `spectral_model_bpl`. Wrapping `spectral_model` (the
# plain power law) would silently refit the power law instead.
spectral_model_bpl = BrokenPowerLawSpectralModel(
    index1=3,
    index2=3.8,
    amplitude=2e-11 * u.Unit("1 / (cm2 s TeV)"),
    ebreak=1 * u.TeV,
)
sky_model_bpl = SkyModel(
    spatial_model=None, spectral_model=spectral_model_bpl, name="pks2155"
)

In [None]:
# Repeat the per-interval fits with the broken-power-law template.
times, results_bpl = time_resolved_spectroscopy(
    datasets=datasets, model=sky_model_bpl, time_intervals=time_intervals
)

In [None]:
# Per-interval best-fit parameters of the broken-power-law fits.
table_bpl = create_table(time_intervals=times, fit_result=results_bpl)