# Gamma-ray Analysis from Scratch with JAX

This notebook implements a simple, high-level gamma-ray analysis from scratch using JAX. The goal is to make simple performance comparisons between JAX and NumPy
and to gain some experience with JAX.

In [25]:
import numpy as np

import jax
import jax.numpy as jnp
from jax.scipy.optimize import minimize as minimize_jax

from gammapy.datasets import MapDataset
from regions import CircleSkyRegion
from astropy import units as u

from iminuit import minimize as minimize_iminuit
from iminuit import Minuit

# Switch JAX to 64-bit mode via the `jax_enable_x64` flag, so the fit runs in
# double precision (JAX's documented default is 32-bit).
jax.config.update("jax_enable_x64", True)

In [2]:
# Smallest positive normal double; `_safe_log` uses it as a lower bound for
# the argument of the logarithm.
_TINY_FLOAT = np.finfo(np.float64).tiny

# Scale factor multiplied onto the fitted amplitude parameter inside `loss`,
# so that the optimizer works with an amplitude expressed in units of 1e-8.
FLUX_FACTOR = 1e-8


def _safe_log(x):
    """Natural logarithm of ``x`` after clamping it from below at `_TINY_FLOAT`.

    Values smaller than `_TINY_FLOAT` (including zero and negatives) are
    replaced by `_TINY_FLOAT` before the logarithm is taken.
    """
    clamped = jnp.maximum(_TINY_FLOAT, x)
    return jnp.log(clamped)


def cash_stat(counts, npred):
    """Cash statistic.

    Evaluates ``2 * sum(npred - counts + counts * (log(counts) - log(npred)))``
    over all elements. Both logarithms go through `_safe_log`, which clamps
    its argument from below, so zero entries do not produce ``log(0)``.

    Implementation taken from https://github.com/scikit-hep/iminuit/blob/main/src/iminuit/cost.py#L304

    Parameters
    ----------
    counts : array
        Observed counts.
    npred : array
        Predicted counts; must broadcast against ``counts``.

    Returns
    -------
    stat : scalar array
        Value of the statistic summed over all elements.
    """
    log_ratio = _safe_log(counts) - _safe_log(npred)
    # Reduce with jnp.sum: np.sum only handles JAX arrays by falling back to
    # the object's own `.sum` method, so call the JAX reduction directly.
    return 2 * jnp.sum(npred - counts + counts * log_ratio)


def compute_npred(flux, exposure, background, edisp, psf=None):
    """Predicted number of counts for a given flux.

    The flux is multiplied element-wise by the exposure, contracted with the
    energy dispersion matrix along the leading axis, and the background is
    added to the result.

    Parameters
    ----------
    flux, exposure : array
        Multiplied element-wise; the leading axis is the one contracted with
        the first axis of ``edisp``.
    background : array
        Added to the folded counts.
    edisp : 2D array
        Energy dispersion matrix; its first axis is contracted, its second
        axis becomes the leading axis of the result.
    psf : optional
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    npred : array
        Folded counts plus background.
    """
    counts_true = flux * exposure

    # jnp.dot contracts the *last* axis of its first argument, while the energy
    # axis is the *first* one here: transpose going in, transpose back coming out.
    folded = jnp.dot(counts_true.T, edisp)
    return folded.T + background


def integrate_power_law(amplitude, index, x_min, x_max):
    """Analytic integral of ``amplitude * x ** (-index)`` from ``x_min`` to ``x_max``.

    The closed form divides by ``1 - index``, so ``index == 1`` is not
    supported by this expression.
    """
    exponent = 1 - index
    return amplitude * (x_max**exponent - x_min**exponent) / exponent


def to_jax_dataset_dict(dataset):
    """Convert a `MapDataset` to a dictionary of JAX arrays.

    The returned dict has the keys ``counts``, ``exposure``, ``background``,
    ``edisp`` and ``energy_edges_true``, in that order. The order matters:
    the fitting cells pass ``tuple(jax_data.values())`` positionally to
    `loss`. The true-energy edges get two trailing singleton axes appended.
    """
    kernel = dataset.edisp.get_edisp_kernel()
    true_energy_axis = dataset.exposure.geom.axes["energy_true"]
    edges = jnp.array(true_energy_axis.edges)

    # Keyword order defines the dict's insertion order; keep it in sync with
    # the parameter order of `loss`.
    return dict(
        counts=jnp.array(dataset.counts.data),
        exposure=jnp.array(dataset.exposure.data),
        background=jnp.array(dataset.background.data),
        edisp=jnp.array(kernel.data),
        energy_edges_true=edges[:, None, None],
    )
   

def point_source():
    """Placeholder without an implementation; does nothing and returns ``None``.

    TODO: presumably meant to hold a point-source model -- confirm or remove.
    """
    return None


In [3]:
# Read the example map dataset from the location referenced by `$GAMMAPY_DATA`
# (the file name suggests CTA 1DC Galactic-centre data -- confirm).
dataset = MapDataset.read("$GAMMAPY_DATA/cta-1dc-gc/cta-1dc-gc.fits.gz", name="dataset-cta")

# Circular on-region of radius 0.1 deg around the centre of the counts geometry.
region = CircleSkyRegion(
    center=dataset.counts.geom.center_skydir,
    radius=0.1 * u.deg
)
# Reduce the map dataset to a spectrum dataset over that on-region.
spectrum_dataset = dataset.to_spectrum_dataset(on_region=region)

In [4]:
# Plain dict of JAX arrays extracted from the spectrum dataset; this is the
# data every fit and timing cell below operates on.
jax_data = to_jax_dataset_dict(spectrum_dataset)

In [46]:
def loss(x0, counts, exposure, background, edisp, energy_edges_true, psf=None):
    """Loss function to minimize: Cash statistic of a power-law model.

    Parameters
    ----------
    x0 : sequence of two values
        ``(amplitude, index)``; the amplitude is multiplied by `FLUX_FACTOR`
        before it enters the power law.
    counts, exposure, background, edisp : array
        Forwarded to `cash_stat` / `compute_npred`.
    energy_edges_true : array
        Bin edges along the leading axis; consecutive edges are used as the
        integration bounds of the power law.
    psf : optional
        Accepted for interface compatibility; not used by this function.

    Returns
    -------
    stat : scalar array
        Cash statistic of the observed counts against the predicted counts.
    """
    amplitude, index = x0

    lower_edges = energy_edges_true[:-1]
    upper_edges = energy_edges_true[1:]
    flux = integrate_power_law(
        FLUX_FACTOR * amplitude, index, x_min=lower_edges, x_max=upper_edges
    )

    predicted = compute_npred(flux, exposure, background, edisp)
    return cash_stat(counts, predicted)


def loss_iminuit(x):
    """Wrapper to use `iminuit` with JAX.

    Takes the fit parameters as one array-like argument, converts them to a
    JAX array and evaluates `loss` on the module-level `jax_data`.
    """
    params = jnp.array(x)
    return loss(params, **jax_data)

### Jax Fit

In [26]:
# BFGS fit with JAX's minimizer, starting from amplitude=0.1 and index=2.
# NOTE: the data arrays are handed over as a positional tuple, so the insertion
# order of `jax_data` has to match the parameter order of `loss` after `x0`.
result_jax = minimize_jax(loss, x0=jnp.array([0.1, 2]), args=tuple(jax_data.values()), method="BFGS")

In [27]:
# Report the best-fit parameters (amplitude in units of FLUX_FACTOR) and the
# statistic value at the minimum.
print(f"Best fit pars jax: ampl={result_jax.x[0]:.2f}, index={result_jax.x[1]:.2f}")
print(f"Stat val jax: {result_jax.fun:.2f}")

Best fit pars jax: ampl=1.83, index=2.17
Stat val jax: 20.40


### IMinuit Fit

In [30]:
# Same fit through iminuit's `minimize` function, with the same start values
# and the plain (not jit-compiled) `loss`. As above, the data arrays are passed
# positionally in the insertion order of `jax_data`.
result_minuit = minimize_iminuit(loss, x0=jnp.array([0.1, 2]), args=tuple(jax_data.values()))

In [32]:
# Report the iminuit result in the same format as the JAX result for comparison.
print(f"Best fit pars minuit: ampl={result_minuit.x[0]:.2f}, index={result_minuit.x[1]:.2f}")
print(f"Stat val minuit: {result_minuit.fun:.2f}")

Best fit pars minuit: ampl=1.83, index=2.17
Stat val minuit: 20.40


### Performance Comparison

In [34]:
# jit-compiled variants of the two loss entry points.
# `loss_iminuit` reads the module-level `jax_data`; `loss` takes the data as arguments.
loss_iminuit_jit = jax.jit(loss_iminuit)
loss_jit = jax.jit(loss)

In [39]:
# Call each jitted function once so that tracing and compilation happen here
# rather than inside the timed cells.
loss_jit((1, 2), **jax_data)
# jax.jit caches one compiled version per input structure and dtype. The
# Minuit benchmark starts from `np.array([0.1, 2])`, so warm up with a float
# NumPy array as well; a tuple of ints would compile a different
# specialization than the one that is actually timed.
loss_iminuit_jit(np.array([0.1, 2]));

In [40]:
%%timeit
# Single evaluation of the jit-compiled loss, same input form as the warm-up call.
loss_jit((1, 2), **jax_data)

8.48 μs ± 52.2 ns per loop (mean ± std. dev. of 7 runs, 100,000 loops each)


In [41]:
%%timeit
# Single evaluation of the plain (not jit-compiled) loss, for comparison with `loss_jit`.
loss((1, 2), **jax_data)

247 μs ± 11.4 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)


In [44]:
%%timeit
# Full fit through iminuit's `minimize` using the plain (not jit-compiled) `loss`.
result = minimize_iminuit(loss, x0=jnp.array([0.1, 2]), args=tuple(jax_data.values()))

22.5 ms ± 757 μs per loop (mean ± std. dev. of 7 runs, 10 loops each)


In [45]:
%%timeit
# Full MIGRAD fit driven by the jit-compiled loss. The preceding fit timing used
# the plain `loss` through `minimize_iminuit`, so the two numbers differ in both
# the entry point and the loss implementation.
minuit = Minuit(loss_iminuit_jit, np.array([0.1, 2]))
minuit.migrad()

1.27 ms ± 5.23 μs per loop (mean ± std. dev. of 7 runs, 1,000 loops each)
