# Import modules

In [None]:
%matplotlib inline
import matplotlib.pyplot as plt
from matplotlib import rcdefaults
from matplotlib.gridspec import GridSpec
from cycler import cycler

import numpy as np
from regions import CircleSkyRegion, PointSkyRegion
from pathlib import Path
import os
import operator
import pickle
import mplhep as hep
from gammapy.stats import WStatCountsStatistic

In [None]:
from gammapy.data import DataStore
from gammapy.datasets import (
    Datasets,
    FluxPointsDataset,
    SpectrumDataset,
    SpectrumDatasetOnOff,
)
from gammapy.estimators import FluxPointsEstimator, LightCurveEstimator, FluxPoints
from gammapy.makers import (
    ReflectedRegionsBackgroundMaker,
    SafeMaskMaker,
    SpectrumDatasetMaker,
    WobbleRegionsFinder,
)
from gammapy.maps import MapAxis, RegionGeom, WcsGeom
from gammapy.maps import TimeMapAxis
# Class-level (global) setting on TimeMapAxis. Presumably it makes time axes, such as
# light-curve bins, format their times as MJD instead of gammapy's default.
# TODO confirm against the installed gammapy version.
TimeMapAxis.time_format = "mjd"

from gammapy.modeling import Fit
from gammapy.modeling.models import (
    PowerLawSpectralModel,
    LogParabolaSpectralModel,
    ExpCutoffPowerLawSpectralModel,
    CompoundSpectralModel,
    BrokenPowerLawSpectralModel,
    SmoothBrokenPowerLawSpectralModel,
    EBLAbsorptionNormSpectralModel,
    create_crab_spectral_model,
    SkyModel,
)
from gammapy.visualization import plot_spectrum_datasets_off_regions

from astropy.io import fits
from astropy.coordinates import SkyCoord
from astropy.table import Table, QTable
from astropy.time import Time
import astropy.units as u

from scipy.stats import chi2, norm

# Load the data

In [None]:
# Directory holding the DL3 data opened with DataStore.from_dir(dir_path) in the next
# cell; the plot output folders are created inside it. Placeholder: set before running.
dir_path = "path/to/dl3/dir"

In [None]:
# Open the DL3 data store (index tables + event/IRF files) in dir_path
total_datastore = DataStore.from_dir(dir_path)

# Build the output paths with pathlib's "/" operator instead of string concatenation:
# with dir_path = "path/to/dl3/dir" (no trailing slash), dir_path + "plots/" produced
# "path/to/dl3/dirplots/", a sibling of the DL3 directory rather than a subfolder.
#ogip_path = Path(dir_path) / "OGIP"
plot_path = Path(dir_path) / "plots"
plot_dataset_path = plot_path / "datasets"

# Create the output directories if they do not exist yet. parents=True also creates
# any missing intermediate directories.
#ogip_path.mkdir(parents=True, exist_ok=True)
plot_path.mkdir(parents=True, exist_ok=True)
plot_dataset_path.mkdir(parents=True, exist_ok=True)

In [None]:
# IDs of every observation listed in the data store's observation table
total_obs_list = total_datastore.obs_table["OBS_ID"].data
observations_total = total_datastore.get_observations(
    total_obs_list,
    # Point-like analysis: PSF and background IRFs are not needed, but RAD_MAX (the
    # energy-dependent ON-region cut) is. The gammapy default "all" means
    # ["aeff", "edisp", "bkg", "psf"]. Runs lacking a required IRF are not loaded.
    required_irf=["aeff", "edisp", "rad_max"],
    # False means a requested obs_id that cannot be found is NOT skipped; gammapy then
    # raises. Set this to True to drop missing observations instead.
    skip_missing=False,
)

# Check the energy-dependent RAD_MAX (theta) cut stored in the DL3 files

In [None]:
# Overlay the RAD_MAX (energy-dependent theta cut) curve of every observation.
# The first call returns the axes. Later calls are assumed to draw on the current
# matplotlib axes, as the original cell relied on (TODO confirm in gammapy).
# The loop starts at the second observation so the first one is not drawn twice.
obs_for_rad_max = list(observations_wob)
ax = obs_for_rad_max[0].rad_max.plot_rad_max_vs_energy()
ax.grid()

for o in obs_for_rad_max[1:]:
    o.rad_max.plot_rad_max_vs_energy()

# One legend entry per curve would clutter the plot, so remove the legend
ax.legend().remove()
# Zoom on the range of the cuts and show two decimals on the y ticks
ax.set_ylim(0.1, 0.4)
ax.yaxis.set_major_formatter("{x:.2f}")

# Define target position

In [None]:
# ICRS sky position of the source, looked up from its name. astropy resolves names
# through an online service, so this presumably needs network access.
# NOTE(review): `obj_name` is not defined in any visible cell. Set it
# (e.g. obj_name = "Crab Nebula") before running this cell.
target_position = SkyCoord.from_name(obj_name, frame='icrs')

# Energy ranges

In [None]:
# Energy bounds in TeV: reconstructed (dataset), fit, and true (IRF) energy
e_reco_min, e_reco_max = 0.01, 100
e_fit_min, e_fit_max = 0.04, 10
e_true_min, e_true_max = 0.01, 100

# Number of logarithmic bins per decade for each axis
e_reco_bin_p_dec = 5
e_fit_bin_p_dec = e_reco_bin_p_dec
e_true_bin_p_dec = 5

# Reconstructed-energy axis of the datasets
energy_axis = MapAxis.from_energy_bounds(
    e_reco_min,
    e_reco_max,
    nbin=e_reco_bin_p_dec,
    per_decade=True,
    unit="TeV",
    name="energy",
)

# True-energy axis used for the IRF maps
energy_axis_true = MapAxis.from_energy_bounds(
    e_true_min,
    e_true_max,
    nbin=e_true_bin_p_dec,
    per_decade=True,
    unit="TeV",
    name="energy_true",
)

# For the fit energy range only the bin edges are kept
energy_fit_edges = MapAxis.from_energy_bounds(
    e_fit_min,
    e_fit_max,
    nbin=e_fit_bin_p_dec,
    per_decade=True,
    unit="TeV",
).edges

# Energy range for light-curve estimation: from 100 GeV (an edge of energy_axis with
# the current binning) up to the last edge of the dataset's reconstructed-energy axis
e_lc_min = 100*u.GeV
e_lc_max = energy_axis.edges[-1]

print(energy_axis.edges)
print(energy_fit_edges)
print(e_lc_min, e_lc_max)

# Create night-wise time intervals for the light curves

In [None]:
%%time
# Get the GTI parameters of each observations to create time intervals for plotting LC
t_start = []
t_stop = []
tot_time = []

t_start_short = []
t_stop_short = []

for obs in observations_wob:
    gti = obs.gti
    t_start.append(gti.time_start[0])
    t_stop.append(gti.time_stop[0])
    tot_time.append(gti.time_sum.value)
        
t_start = np.sort(np.array(t_start))
t_stop = np.sort(np.array(t_stop))
tot_time = np.array(tot_time)

t_start = Time(t_start)
t_stop = Time(t_stop)

t_day = np.unique(np.rint(t_start.mjd))

# To make the range night-wise, keep the MJD range in half integral values
t_range = [Time([t-0.5, t+0.5], format="mjd", scale="utc") for t in t_day]

# Base map geometries

In [None]:
# The ON region must be a PointSkyRegion (not a circle), so that the energy-dependent
# theta cut (RAD_MAX) from the IRFs sets the ON region size.
on_region = PointSkyRegion(target_position)

# Base geometry for binning events by reconstructed position, with the
# reconstructed-energy axis attached
on_geom = RegionGeom.create(on_region, axes=[energy_axis])

# Data Reduction

In [None]:
# Data-reduction maker: fills counts, exposure and energy dispersion.
# No PSF IRF is used, so the containment correction has to stay disabled.
dataset_maker = SpectrumDatasetMaker(
    selection=["counts", "exposure", "edisp"],
    containment_correction=False,
)

# Empty template dataset, copied once per observation:
#   geom             -> reconstructed-energy geometry for counts/background
#   energy_axis_true -> true-energy axis for the IRF maps
dataset_empty = SpectrumDataset.create(
    energy_axis_true=energy_axis_true,
    geom=on_geom,
)

In [None]:
# Background estimation, tunable: choose how many OFF regions the wobble region
# finder places. Check the resulting datasets before using them further.
wobble_off_regions = 3

region_finder = WobbleRegionsFinder(
    n_off_regions=wobble_off_regions,
)
# NOTE(review): no exclusion mask is passed here, although `exclusion_mask` is
# plotted further down. Confirm that this is intended.
bkg_maker = ReflectedRegionsBackgroundMaker(region_finder=region_finder)

In [None]:
%%time
# The final object will be stored as a Datasets object
datasets = Datasets()
i=0
for obs_id, observation in zip(wob_obs_list.data, observations_wob):
    dataset = dataset_maker.run(
        dataset_empty.copy(name=str(obs_id)), 
        observation
    )
    print(i, 'obs_id:', obs_id)
    
    i+=1
    dataset_on_off = bkg_maker.run(
        dataset=dataset, 
        observation=observation
    )
    
    datasets.append(dataset_on_off)

In [None]:
# Sky map of the exclusion mask, with the ON position (black star) and the OFF
# regions of every dataset overlaid.
# NOTE(review): `exclusion_mask` is not defined in any visible cell; confirm where
# it is created.
plt.figure()
ax = exclusion_mask.plot()
on_geom.plot_region(
    ax=ax,
    kwargs_point={"color": "k", "marker": "*"},
)
plot_spectrum_datasets_off_regions(datasets=datasets, ax=ax)

# Plots

In [None]:
# Three panels vs. the square root of the livetime: counts/background, excess,
# and significance (sqrt_ts).
# NOTE(review): `info_table` is not defined in any visible cell. Presumably it is
# a per-run (possibly cumulative) table from `datasets.info_table(...)`; confirm.
plt.figure(figsize=(18,5))
sqrt_livetime = np.sqrt(info_table["livetime"].to("h"))

# Left panel: ON counts and background, log scale
plt.subplot(131)
plt.plot(sqrt_livetime, info_table["counts"], marker="o", ls="none", label="counts")
plt.plot(
    sqrt_livetime, info_table["background"],
    marker="o", ls="none", alpha=0.4, label="background",
)
plt.xlabel("Sqrt Livetime h^(1/2)")
plt.ylabel("Counts")
plt.grid()
plt.legend()
plt.yscale("log")
plt.title('Counts/Background vs Square root of Livetime')

# Middle panel: excess counts
plt.subplot(132)
plt.plot(sqrt_livetime, info_table["excess"], marker="o", ls="none")
plt.xlabel("Sqrt Livetime h^(1/2)")
plt.ylabel("Excess")
plt.grid()
plt.title('Excess vs Square root of Livetime')

# Right panel: significance
plt.subplot(133)
plt.plot(sqrt_livetime, info_table["sqrt_ts"], marker="o", ls="none")
plt.grid()
plt.xlabel("Sqrt Livetime h^(1/2)")
plt.ylabel("sqrt_ts")
plt.title('Significance vs Square root of Livetime')

plt.subplots_adjust(wspace=0.5)

In [None]:
## TODO: check the Li & Ma significance and the signal-to-background ratio here

# Further Analysis

In [None]:
# Stack all per-observation ON/OFF datasets collected above into a single dataset
stacked_dataset = datasets.stack_reduce()