# Jolideco Analysis of Fermi-LAT Data of Vela Junior

In [None]:
import numpy as np
from pathlib import Path
from matplotlib import pyplot as plt
import matplotlib as mpl
from astropy.visualization import simple_norm
from astropy import units as u
from jolideco.core import MAPDeconvolver
from jolideco.models import (
    SpatialFluxComponent,
    FluxComponents,
    NPredModels,
    NPredCalibration,
    NPredCalibrations
)
from jolideco.priors import GMMPatchPrior, GaussianMixtureModel
from jolideco.utils.norms import IdentityImageNorm
from itertools import zip_longest
from regions import CircleSkyRegion
from gammapy.maps import Maps, WcsGeom, Map

In [None]:
# Optional: for figures matching the VS Code dark theme, uncomment the two
# lines below.
# plt.style.use('dark_background')
# mpl.rcParams['figure.facecolor'] = '#25292E'


In [None]:
# Input/output paths: taken from Snakemake when running inside the workflow,
# otherwise fall back to the default results layout for interactive use.
if "snakemake" in globals():
    # Snakemake hands over plain strings; convert to Path so that `.stem`
    # can be used when the files are read.
    filenames = [Path(filename) for filename in snakemake.input]
    filename_jolideco_result = snakemake.output.filename_jolideco_result
    filenames_npred = list(snakemake.output.filenames_npred)
else:
    PATH_BASE = Path("../../results/vela-junior-above-10GeV-data/jolideco/")
    # Path.glob yields files in arbitrary order; sort so that the dataset
    # order (and anything set up per dataset) is deterministic.
    filenames = sorted((PATH_BASE / "input").glob("*.fits"))
    filename_jolideco_result = PATH_BASE / "vela-junior-above-10GeV-data-result-jolideco.fits"
    # NOTE(review): this assignment was left incomplete in the original.
    # The names below mirror the input files and are a guess; align them with
    # the Snakemake rule's `filenames_npred` output.
    filenames_npred = [
        PATH_BASE / "npred" / f"{filename.stem.replace('-maps', '')}-npred.fits"
        for filename in filenames
    ]

In [None]:
# Read every input file into a gammapy `Maps` container, keyed by the file
# stem with the "-maps" part removed (e.g. "vela-junior-above-10GeV-data-psf0").
datasets = {}

for filename in filenames:
    # Path() accepts both str and Path, so this also works when the filenames
    # are plain strings (as passed by Snakemake).
    filename = Path(filename)
    datasets[filename.stem.replace("-maps", "")] = Maps.read(filename)


## Counts

In [None]:
# Sky projection of the counts maps; reused by the later sky-image plots.
wcs = datasets["vela-junior-above-10GeV-data-psf0"]["counts"].geom.wcs

# One panel per dataset showing the counts summed over all non-spatial axes.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 12), subplot_kw={"projection": wcs})

for (name, maps), ax in zip(datasets.items(), axes.flat):
    counts_image = maps["counts"].sum_over_axes()
    counts_image.plot(ax=ax, cmap="viridis", add_cbar=True)
    ax.set_title(f"{name}")

## Background

In [None]:
# One panel per dataset showing the background model summed over all
# non-spatial axes, with a logarithmic stretch.
fig, axes = plt.subplots(nrows=2, ncols=2, figsize=(12, 12), subplot_kw={"projection": wcs})

for (name, maps), ax in zip(datasets.items(), axes.flat):
    background_image = maps["background"].sum_over_axes()
    background_image.plot(ax=ax, cmap="viridis", add_cbar=True, stretch="log")
    ax.set_title(f"{name}")

## PSF

In [None]:
# One panel per dataset showing the PSF map summed over all non-spatial axes,
# with a logarithmic stretch.
# Use a dedicated name: reassigning the global `wcs` here would make every
# later sky plot (results, residuals) use the PSF map's projection, which
# presumably differs from the counts geometry.
wcs_psf = datasets["vela-junior-above-10GeV-data-psf0"]["psf"].geom.wcs

fig, axes = plt.subplots(
    ncols=2,
    nrows=2,
    subplot_kw={"projection": wcs_psf},
    figsize=(12, 12),
)

for ax, (name, maps) in zip(axes.flat, datasets.items()):
    psf = maps["psf"].sum_over_axes()
    psf.plot(ax=ax, cmap="viridis", add_cbar=True, stretch="log")
    ax.set_title(f"{name}")

In [None]:
def to_jolideco_dataset(maps, dtype=np.float32, component_name="vela-junior"):
    """Convert Gammapy maps to a Jolideco dataset dict.

    Parameters
    ----------
    maps : mapping
        Must provide "counts", "background", "psf" and "exposure" entries,
        each exposing a ``.data`` array. Only the first slice along axis 0 is
        used (NOTE(review): presumably a single energy bin — confirm).
    dtype : numpy dtype, optional
        Target dtype of every output array.
    component_name : str, optional
        Key under which the PSF is stored; presumably has to match the name of
        the flux component the PSF is applied to.

    Returns
    -------
    dict
        Keys "counts", "background", "exposure" (arrays) and "psf"
        (``{component_name: array}``).
    """

    def first_slice(key):
        # Drop the leading axis and cast to the requested dtype.
        return maps[key].data[0].astype(dtype)

    return {
        "counts": first_slice("counts"),
        "background": first_slice("background"),
        "psf": {component_name: first_slice("psf")},
        "exposure": first_slice("exposure"),
    }

In [None]:
# Numpy/Jolideco version of every dataset, under the same keys as `datasets`.
datasets_jolideco = {}
for name, maps in datasets.items():
    datasets_jolideco[name] = to_jolideco_dataset(maps)

## Run Jolideco

In [None]:
# Pretrained Gaussian mixture model from the Jolideco registry; its stride is
# set to 4, the same value passed to GMMPatchPrior.
GMM_REGISTRY_NAME = "chandra-snrs-v0.1"

gmm = GaussianMixtureModel.from_registry(GMM_REGISTRY_NAME)
gmm.stride = 4
print(gmm)

In [None]:
# Visual check of the GMM component mean images, laid out in 16 columns.
gmm.plot_mean_images(figsize=(12, 8), ncols=16)

In [None]:
# Patch prior built from the GMM: cycle spinning enabled, stride 4, and an
# identity image norm (no extra normalisation of the patches).
patch_prior = GMMPatchPrior(
    gmm=gmm,
    cycle_spin=True,
    stride=4,
    norm=IdentityImageNorm(),
)

# Initial flux guess on the counts pixel grid (upsampling_factor=1 below):
# values around 0.1 with small Gaussian noise. A seeded generator makes the
# starting point reproducible between runs.
RANDOM_SEED = 0
rng = np.random.default_rng(RANDOM_SEED)

shape = datasets_jolideco["vela-junior-above-10GeV-data-psf1"]["counts"].shape
flux_init = rng.normal(loc=0.1, scale=0.01, size=shape).astype(np.float32)

component = SpatialFluxComponent.from_numpy(
    flux=flux_init,
    prior=patch_prior,
    use_log_flux=True,
    upsampling_factor=1,
)

components = FluxComponents()
# Presumably must match the PSF key used in the Jolideco datasets ("vela-junior").
components["vela-junior"] = component

print(components)

In [None]:
# One background-normalisation calibration per dataset, not frozen (so
# presumably fitted together with the flux).
# Initial values are keyed by dataset name instead of by position. They
# therefore do not depend on the order in which the input files were found,
# and every dataset gets a calibration even when there are not exactly four.
# NOTE(review): assumes the original 0.5 was intended for psf0 and 1.2 for
# the other event types.
BACKGROUND_NORM_INIT = {"vela-junior-above-10GeV-data-psf0": 0.5}
BACKGROUND_NORM_INIT_DEFAULT = 1.2

calibrations = NPredCalibrations()

for name in datasets:
    value = BACKGROUND_NORM_INIT.get(name, BACKGROUND_NORM_INIT_DEFAULT)
    calibrations[name] = NPredCalibration(background_norm=value, frozen=False)

print(calibrations)

In [None]:
# MAP deconvolver: 500 epochs with a learning rate of 0.1.
deconvolve = MAPDeconvolver(learning_rate=0.1, n_epochs=500)
print(deconvolve)

In [None]:
# Optional: uncomment to exclude the psf0 dataset from the fit.
#datasets_jolideco.pop("vela-junior-above-10GeV-data-psf0")

In [None]:
# Run the deconvolution on all datasets in `datasets_jolideco`, with the flux
# components and per-dataset calibrations defined above.
result = deconvolve.run(
    components=components,
    calibrations=calibrations,
    datasets=datasets_jolideco,
)

In [None]:
# Loss trace of the optimisation, drawn on a fresh 12x8 figure.
fig = plt.figure(figsize=(12, 8))
result.plot_trace_loss()

## Results

In [None]:
# Side-by-side comparison: counts summed over all datasets vs. the total
# deconvolved flux.
counts = np.sum([_["counts"] for _ in datasets_jolideco.values()], axis=0)
flux_total = result.components.flux_total_numpy

fig, axes = plt.subplots(
    ncols=2,
    subplot_kw={"projection": wcs},
    figsize=(14, 6),
)

# Linear stretch (power=1) between fixed cuts for the deconvolved flux.
norm_flux = simple_norm(
    flux_total,
    min_cut=0.1,
    max_cut=0.5,
    stretch="power",
    power=1.0,
)

im = axes[0].imshow(counts, origin="lower", interpolation="none")
axes[0].set_title("Counts")
# Attach each colorbar explicitly to its own panel.
plt.colorbar(im, ax=axes[0]);

im = axes[1].imshow(flux_total, origin="lower", norm=norm_flux, interpolation="bicubic")
axes[1].set_title("Deconvolved");
plt.colorbar(im, ax=axes[1]);


In [None]:
# Show the calibrations after the fit.
# NOTE(review): assumes `deconvolve.run` updates them in place — confirm.
print(str(calibrations))

In [None]:
# could visually compare aginst https://arxiv.org/abs/2303.12686
# Re-grid the deconvolved flux onto a 3 deg wide, 0.02 deg/pixel ICRS map
# centred on the counts map, to compare visually against
# https://arxiv.org/abs/2303.12686.
# `geom` (the counts image geometry) is also reused for the npred maps.
geom = datasets["vela-junior-above-10GeV-data-psf0"]["counts"].geom.to_image()
# Take the flux from the fit result (as in the results plot) rather than from
# the `component` object created before the fit.
flux = Map.from_geom(geom, data=result.components["vela-junior"].flux_numpy)

geom_icrs = WcsGeom.create(
    skydir=geom.center_skydir,
    width=3 * u.deg,
    binsz=0.02,
    frame="icrs",
)

flux_icrs = flux.interp_to_geom(geom_icrs)

norm_pwr = simple_norm(
    flux_icrs.data,
    min_cut=0,
    max_cut=0.6,
    stretch="power",
    power=1.3,
)
flux_icrs.plot(cmap="cubehelix", norm=norm_pwr)

In [None]:
# Persist the deconvolution result; an existing file at this path is replaced.
output_path = filename_jolideco_result
result.write(output_path, overwrite=True)

## Residuals

In [None]:
# Predicted counts per dataset from the fitted flux components, both without
# and with the per-dataset calibration, as gammapy maps on `geom`.

# The flux tuple does not depend on the dataset, so build it once.
fluxes = result.components.to_flux_tuple()


def compute_npred_map(dataset, **kwargs):
    """Evaluate the predicted counts of one Jolideco dataset as a gammapy Map.

    Extra keyword arguments (e.g. ``calibration``) are forwarded to
    ``NPredModels.from_dataset_numpy``.
    """
    model = NPredModels.from_dataset_numpy(
        dataset=dataset,
        components=result.components,
        **kwargs,
    )
    # NOTE(review): [0, 0] drops two leading axes of the evaluated tensor,
    # presumably batch and channel — confirm.
    npred = model.evaluate(fluxes=fluxes).detach().numpy()[0, 0]
    return Map.from_geom(data=npred, geom=geom)


npreds = {
    name: compute_npred_map(dataset)
    for name, dataset in datasets_jolideco.items()
}

npreds_calibrated = {
    name: compute_npred_map(dataset, calibration=calibrations[name])
    for name, dataset in datasets_jolideco.items()
}
    

In [None]:
# Residuals (counts - npred) / sqrt(npred) per dataset, both maps smoothed
# first, using the predictions WITHOUT calibration.
fig, axes = plt.subplots(
    nrows=2,
    ncols=2,
    figsize=(12, 12),
    gridspec_kw={"wspace": 0.2},
    subplot_kw={"projection": wcs},
)

for name, ax in zip_longest(datasets_jolideco, axes.flat):
    if name is None:
        # More panels than datasets: hide the spare ones.
        ax.set_visible(False)
        continue

    counts_smooth = datasets[name]["counts"].sum_over_axes(keepdims=False).smooth(5)
    npred_smooth = npreds[name].smooth(5)
    residual = (counts_smooth - npred_smooth) / np.sqrt(npred_smooth)

    residual.plot(ax=ax, vmin=-0.5, vmax=0.5, cmap="RdBu", add_cbar=True)
    ax.set_title(f"{name}")

In [None]:
# Residuals (counts - npred) / sqrt(npred) per dataset, both maps smoothed
# first, using the CALIBRATED predictions.
fig, axes = plt.subplots(
    nrows=2,
    ncols=2,
    figsize=(12, 12),
    gridspec_kw={"wspace": 0.2},
    subplot_kw={"projection": wcs},
)

for name, ax in zip_longest(datasets_jolideco, axes.flat):
    if name is None:
        # More panels than datasets: hide the spare ones.
        ax.set_visible(False)
        continue

    counts_smooth = datasets[name]["counts"].sum_over_axes(keepdims=False).smooth(5)
    npred_smooth = npreds_calibrated[name].smooth(5)
    residual = (counts_smooth - npred_smooth) / np.sqrt(npred_smooth)

    residual.plot(ax=ax, vmin=-0.5, vmax=0.5, cmap="RdBu", add_cbar=True)
    ax.set_title(f"{name}")