# Jolideco Analysis of Fermi-LAT Data of Vela Junior

In [None]:
import numpy as np
from astropy.io import fits
from astropy.coordinates import SkyCoord
from pathlib import Path
from matplotlib import pyplot as plt
from astropy.visualization import simple_norm
from astropy.wcs import WCS
from astropy import units as u
from jolideco.core import MAPDeconvolver
from jolideco.models import (
    SpatialFluxComponent,
    FluxComponents,
    SparseSpatialFluxComponent,
    NPredModel,
    NPredModels,
    NPredCalibration,
    NPredCalibrations
)
from jolideco.priors import GMMPatchPrior, GaussianMixtureModel, UniformPrior
from jolideco.utils.norms import FixedMaxImageNorm, ASinhImageNorm
from scipy.ndimage import gaussian_filter
from itertools import zip_longest
from gammapy.maps import Map, RegionGeom
from gammapy.datasets import Datasets
from gammapy.modeling.models import Models, SkyModel
from gammapy.modeling import Fit
from gammapy.estimators import TSMapEstimator
import torch
from gammapy.catalog import SourceCatalog3FHL
from regions import CircleSkyRegion

In [None]:
# Load the Fermi-LAT 3FHL source catalog (SourceCatalog3FHL called without arguments).
fermi_3fhl = SourceCatalog3FHL()

In [None]:
# 3FHL catalog entry of Vela Junior, looked up by its source name. Its
# `.position` is used below as the reference sky position.
vela_junior_3fhl = fermi_3fhl["RX J0852.0-4622"]

In [None]:
# Base directory of the prepared inputs (datasets and model YAML files),
# relative to the notebook's working directory.
PATH_BASE = Path("../../results/vela-junior-above-10GeV-data/")

In [None]:
# Read the Fermi-LAT datasets referenced by the datasets YAML index file.
datasets_input = Datasets.read(PATH_BASE / "datasets/vela-junior-above-10GeV-data-datasets-all.yaml")

In [None]:
# Collect the input datasets in a new container. Each dataset is prepared in
# place: the safe mask is dropped and the PSF map values are cast to float.
datasets = Datasets()

for dataset in datasets_input:
    dataset.mask_safe = None
    dataset.psf.psf_map.data = dataset.psf.psf_map.data.astype(float)
    datasets.append(dataset)

In [None]:
# Read all sky models and keep only the diffuse emission component
# ("diffuse-iem"); it is the only model attached to the datasets below.
models = Models.read(PATH_BASE / "model/vela-junior-above-10GeV-data-model.yaml")["diffuse-iem"]

In [None]:
# Attach the diffuse model to the datasets; it is used by the npred plots and
# by the background fit below.
datasets.models = models

In [None]:
# Summary of the prepared datasets.
print(datasets)

## Counts

In [None]:
# Counts of every dataset summed over the non-spatial axes, one panel per
# dataset on a 2x2 grid that uses the WCS of the first dataset.
wcs = datasets[0].counts.geom.wcs

fig, axes = plt.subplots(
    nrows=2,
    ncols=2,
    figsize=(12, 12),
    subplot_kw={"projection": wcs},
)

for dataset, ax in zip(datasets, axes.ravel()):
    counts = dataset.counts.sum_over_axes()
    counts.plot(ax=ax, cmap="viridis", add_cbar=True)
    ax.set_title(f"{dataset.name}")

## Background

In [None]:
# Predicted counts (npred) of every dataset with the models attached above,
# summed over the non-spatial axes and shown with a linear stretch.
wcs = datasets[0].counts.geom.wcs

fig, axes = plt.subplots(
    nrows=2,
    ncols=2,
    figsize=(12, 12),
    subplot_kw={"projection": wcs},
)

for dataset, ax in zip(datasets, axes.ravel()):
    npred = dataset.npred().sum_over_axes()
    npred.plot(ax=ax, stretch="linear", cmap="viridis", add_cbar=True)
    ax.set_title(f"{dataset.name}")

## PSF

In [None]:
ax = plt.subplot()

# 68% PSF containment radius at the Vela Junior position versus true energy,
# evaluated on 100 log-spaced energies from 5 GeV to 2 TeV, one curve per dataset.
energy_true = np.geomspace(5 * u.GeV, 2000 * u.GeV, 100)

for dataset in datasets:
    radius = dataset.psf.containment_radius(
        fraction=0.68,
        energy_true=energy_true,
        position=vela_junior_3fhl.position,
    )
    ax.plot(energy_true, radius, label=dataset.name)

ax.set_xscale("log")
ax.set_xlabel("Energy / GeV")
ax.set_ylabel("PSF 68% containment radius / deg")
ax.legend();

## Fit Background Level

In [None]:
# Circle of radius 1 deg around the 3FHL position of Vela Junior; it is
# excluded from the background fit below.
exclusion_region = CircleSkyRegion(center=vela_junior_3fhl.position, radius=1.0 * u.deg)

In [None]:
# Fit mask: the inverse of the exclusion-region mask, i.e. only pixels outside
# the circle around Vela Junior enter the background fit.
# Use the first dataset's geometry explicitly (like the plotting cells above)
# instead of relying on the loop variable `dataset` leaked from an earlier cell.
mask = ~datasets[0].counts.geom.region_mask(exclusion_region)

In [None]:
# Restrict the fit of every dataset to the pixels outside the exclusion region.
for ds in datasets:
    ds.mask_fit = mask

In [None]:
# Gammapy fit object, created without arguments (default settings).
fit = Fit()

In [None]:
%%time
# Fit the free parameters of the attached model(s) to the masked datasets;
# `%%time` (must stay the first line of the cell) reports the run time.
fit.run(datasets)

In [None]:
# Best-fit norm of the diffuse model; it is reused below as the fixed
# background normalisation of the Jolideco calibrations.
spectral_model_iem = datasets.models["diffuse-iem"].spectral_model
background_norm = spectral_model_iem.norm.value
print(f"Bkg. norm: {background_norm}")

## Input Datasets

In [None]:
model = SkyModel.create("pl", "point")

# Power-law point-source model with hand-set amplitude and spectral index.
pl = model.spectral_model
pl.amplitude.quantity = "1e-14 cm-2 s-1 TeV-1"
pl.index.value = 1.7655

print(model)

In [None]:
# TS-map estimator built on the point-source model above. Only its
# `estimate_fit_input_maps` method is used below, to obtain the counts,
# exposure, background and kernel images passed to Jolideco.
est = TSMapEstimator(model=model, sum_over_energy_groups=True)

In [None]:
# Build the Jolideco input: for every dataset, 2D numpy cutouts of the counts,
# exposure, PSF kernel and background images.
datasets_jolideco = {}

# Centre of the cutouts.
# NOTE(review): this resolves the name with an online lookup, while the rest of
# the notebook uses the 3FHL position (`vela_junior_3fhl.position`); consider
# switching for consistency and offline reproducibility.
position = SkyCoord.from_name("Vela Junior")

# The cutout specification is the same for every dataset and map, so it is
# defined once instead of on every loop iteration.
cutout_kwargs = {"position": position, "width": 3. * u.deg}


def _cutout_image(m):
    """Cut ``m`` out with ``cutout_kwargs`` and return its first plane (``data[0]``)."""
    return m.cutout(**cutout_kwargs).data[0]


for dataset in datasets:
    # Clear the background-fit mask set earlier (this modifies the shared
    # dataset objects) before converting to an image dataset.
    dataset.mask_fit = None
    dataset = dataset.to_image(name=dataset.name)
    dataset.models = models

    maps = est.estimate_fit_input_maps(dataset=dataset)

    datasets_jolideco[dataset.name] = {
        "counts": _cutout_image(maps["counts"]),
        "exposure": _cutout_image(maps["exposure"]),
        "psf": {"vela-junior": _cutout_image(maps["kernel"])},
        "background": _cutout_image(maps["background"]),
    }

In [None]:
# Models of the last image dataset built in the loop above (the loop variable
# `dataset` stays bound after the loop).
print(dataset.models)

## Run Jolideco

In [None]:
# Two candidate image norms on [0, max_value], plotted for comparison. Only
# `max_norm` is passed to the patch prior below.
max_value = 1
norm = ASinhImageNorm(alpha=0.02, beta=max_value)
max_norm = FixedMaxImageNorm(max_value=max_value)

for image_norm in (norm, max_norm):
    image_norm.plot(xrange=(0, max_value))

In [None]:
# Gaussian mixture model for image patches, loaded from the jolideco registry
# entry "gleam-v0.2".
# NOTE(review): the stride set here (4) differs from the `stride=2` passed to
# GMMPatchPrior in the next cell; confirm which value the prior actually uses.
gmm = GaussianMixtureModel.from_registry("gleam-v0.2")
gmm.stride = 4
print(gmm)

In [None]:
# Patch prior based on the GMM above, using the fixed-max image norm.
patch_prior = GMMPatchPrior(
    gmm=gmm,
    norm=max_norm,
    stride=2,
    cycle_spin=True,
    jitter=False,
)

# Flat float32 starting image (value 0.1) with the shape of the psf0 counts cutout.
counts_shape = datasets_jolideco["vela-junior-above-10GeV-data-psf0"]["counts"].shape
flux_init = np.full(counts_shape, 0.1, dtype=np.float32)

component = SpatialFluxComponent.from_numpy(
    flux=flux_init,
    prior=patch_prior,
    use_log_flux=True,
    upsampling_factor=2,
)

components = FluxComponents()
components["vela-junior"] = component

print(components)

In [None]:
# One calibration per dataset, created with the fitted background norm and
# frozen=True; the positional shift is additionally excluded from gradients.
calibrations = NPredCalibrations()

for name in datasets.names:
    calibration = NPredCalibration(
        background_norm=background_norm,
        frozen=True,
    )
    calibration.shift_xy.requires_grad = False
    calibrations[name] = calibration

print(calibrations)

In [None]:
# MAP deconvolver running for 500 epochs.
# NOTE(review): `beta` presumably weights the prior term relative to the
# likelihood; confirm against the MAPDeconvolver documentation.
deconvolve = MAPDeconvolver(n_epochs=500, beta=0.2)
print(deconvolve)

In [None]:
# Run the deconvolution on the Jolideco input datasets with the flux
# component(s) and calibrations defined above.
result = deconvolve.run(datasets=datasets_jolideco, components=components, calibrations=calibrations)

In [None]:
# Plot the loss trace of the deconvolution run.
result.plot_trace_loss()

In [None]:
# Total counts in the cutout, summed over all Jolideco input datasets.
counts = np.sum([data["counts"] for data in datasets_jolideco.values()], axis=0)

# The images below are cutouts, so the axes need the WCS of the cutout
# geometry. The full-dataset `wcs` would label the pixels with wrong sky
# coordinates. `cutout_kwargs` is the cutout specification used to build
# `datasets_jolideco`.
# NOTE(review): assumes `flux_total_numpy` is on the same pixel grid as the
# counts cutout (not the upsampled grid); confirm.
wcs_cutout = datasets[0].counts.geom.to_image().cutout(**cutout_kwargs).wcs

fig, axes = plt.subplots(
    ncols=2,
    subplot_kw={"projection": wcs_cutout},
    figsize=(14, 6)
)

# Asinh stretch with fixed cuts (0 to 10), applied to the deconvolved flux panel.
norm_asinh = simple_norm(
    counts,
    min_cut=0,
    max_cut=10,
    stretch="asinh",
    asinh_a=0.01
)

# Counts smoothed with a Gaussian of sigma = 3 pixels for display.
axes[0].imshow(gaussian_filter(counts, 3), origin="lower")
axes[0].set_title("Counts")

im = axes[1].imshow(result.components.flux_total_numpy, origin="lower", norm=norm_asinh, interpolation="bicubic")
axes[1].set_title("Deconvolved");
plt.colorbar(im);

In [None]:
# PSF kernel cutout of the "...-psf0" dataset, shown with a square-root
# stretch (presumably to bring out the faint wings).
plt.imshow(np.sqrt(datasets_jolideco["vela-junior-above-10GeV-data-psf0"]["psf"]["vela-junior"]))

In [None]:
# Calibrations after they were passed to `deconvolve.run` above.
print(calibrations)

In [None]:
# Save the deconvolution result to a FITS file in the working directory,
# overwriting an existing file of the same name.
result.write("fermi-lat-vela-junior.fits", overwrite=True)

In [None]:
# Predicted counts of the deconvolved flux for every dataset, without
# (`npreds`) and with (`npreds_calibrated`) the per-dataset calibration.
npreds = {}
npreds_calibrated = {}

# The flux tuple does not depend on the dataset, so compute it once.
fluxes = result.components.to_flux_tuple()


def _evaluate_npred(dataset, **kwargs):
    """Evaluate the NPred model of one Jolideco dataset and return it as an array.

    Extra keyword arguments (e.g. ``calibration``) are forwarded to
    ``NPredModels.from_dataset_numpy``. The ``[0, 0]`` index drops the two
    leading axes of the evaluated tensor (presumably batch/channel of size 1).
    """
    # Local name `npred_model` so the global point-source SkyModel `model`
    # (used for the TS-map inputs) is not overwritten.
    npred_model = NPredModels.from_dataset_numpy(
        dataset=dataset,
        components=result.components,
        **kwargs,
    )
    return npred_model.evaluate(fluxes=fluxes).detach().numpy()[0, 0]


for name, dataset in datasets_jolideco.items():
    npreds[name] = _evaluate_npred(dataset)
    npreds_calibrated[name] = _evaluate_npred(dataset, calibration=calibrations[name])
    

In [None]:
# Smoothed residuals (counts - npred) / sqrt(npred) per dataset, using the
# predicted counts without calibration.
# The residual images are cutouts, so use the WCS of the cutout geometry
# (`cutout_kwargs` is the specification used to build `datasets_jolideco`)
# rather than the full-dataset `wcs`, which would mislabel the sky coordinates.
wcs_cutout = datasets[0].counts.geom.to_image().cutout(**cutout_kwargs).wcs

fig, axes = plt.subplots(
    ncols=2,
    nrows=2,
    subplot_kw={"projection": wcs_cutout},
    gridspec_kw={"wspace": 0.2},
    figsize=(12, 12)
)


for name, ax in zip_longest(datasets_jolideco, axes.flat):
    if name is None:
        # Fewer datasets than panels: hide the unused axes.
        ax.set_visible(False)
        continue

    dataset = datasets_jolideco[name]

    residual = (dataset["counts"] - npreds[name]) / np.sqrt(npreds[name])
    # Gaussian smoothing with sigma = 5 pixels.
    smoothed = gaussian_filter(residual, 5)

    im = ax.imshow(smoothed, vmin=-0.5, vmax=0.5, cmap="RdBu")
    # The keys are dataset names (not observation IDs).
    ax.set_title(f"{name}")
    plt.colorbar(im, ax=ax)

In [None]:
# Smoothed residuals (counts - npred) / sqrt(npred) per dataset, using the
# calibrated predicted counts.
# The residual images are cutouts, so use the WCS of the cutout geometry
# (`cutout_kwargs` is the specification used to build `datasets_jolideco`)
# rather than the full-dataset `wcs`, which would mislabel the sky coordinates.
wcs_cutout = datasets[0].counts.geom.to_image().cutout(**cutout_kwargs).wcs

fig, axes = plt.subplots(
    ncols=2,
    nrows=2,
    subplot_kw={"projection": wcs_cutout},
    gridspec_kw={"wspace": 0.2},
    figsize=(12, 12)
)


for name, ax in zip_longest(datasets_jolideco, axes.flat):
    if name is None:
        # Fewer datasets than panels: hide the unused axes.
        ax.set_visible(False)
        continue

    dataset = datasets_jolideco[name]

    residual = (dataset["counts"] - npreds_calibrated[name]) / np.sqrt(npreds_calibrated[name])
    # Gaussian smoothing with sigma = 5 pixels.
    smoothed = gaussian_filter(residual, 5)

    im = ax.imshow(smoothed, vmin=-0.5, vmax=0.5, cmap="RdBu")
    # The keys are dataset names (not observation IDs).
    ax.set_title(f"{name}")
    plt.colorbar(im, ax=ax)

In [None]:
# Final state of the per-dataset calibrations.
print(calibrations)