## HEP Example: Likelihood gradients

A prime example where this is the case is statistical analysis. In a maximum likelihood fit we look for the parameters that maximize the likelihood, or equivalently minimize the negative log likelihood:

$\theta^* = \mathrm{argmin}_\theta\left(-\log L(\theta)\right)$

In [None]:
import jax
import jax.numpy as jnp
import numpy as np
import pyhf
import matplotlib.pyplot as plt

# Notebook-wide matplotlib defaults applied to every figure created below.
plt.rcParams.update(
    [
        ("font.size", 14),
        ("figure.facecolor", (1, 1, 1, 1)),  # opaque white background (RGBA)
        ("figure.dpi", 100),
    ]
)

In [None]:
# Switch pyhf's tensor backend to JAX so model.logpdf can be traced by jax.vmap / jax.grad.
pyhf.set_backend(backend="jax")

In [None]:
# One-bin model: signal 5.0, background 10.0, background uncertainty 3.5.
model = pyhf.simplemodels.hepdata_like(
    signal_data=[5.0],
    bkg_data=[10.0],
    bkg_uncerts=[3.5],
)
# Parameter values suggested by the model configuration.
pars = jnp.array(model.config.suggested_init())

# Full data vector = observed bin counts followed by the model's auxiliary data.
observations = [15.0]
data = jnp.array(observations + model.config.auxdata)

# Maximum-likelihood fit of the model parameters to `data`.
best_fit = pyhf.infer.mle.fit(data, model)

In [None]:
def _grid_points(x_range, y_range, n_points):
    """Sample a regular ``n_points`` x ``n_points`` grid over a rectangle.

    Args:
        x_range: ``(low, high)`` bounds of the first coordinate (inclusive).
        y_range: ``(low, high)`` bounds of the second coordinate (inclusive).
        n_points: Number of samples along each axis.

    Returns:
        ``(x, y, points)``: ``x`` and ``y`` are the ``(n_points, n_points)``
        coordinate arrays produced by ``np.mgrid`` (``x[i, j]`` is the i-th
        x sample, ``y[i, j]`` the j-th y sample). ``points`` is an
        ``(n_points**2, 2)`` array of ``(x, y)`` pairs with x varying fastest,
        i.e. row ``j * n_points + i`` equals ``(x[i, j], y[i, j])``.
    """
    # A complex step tells np.mgrid how many points to produce, endpoints included.
    step = complex(0, n_points)
    grid = np.mgrid[x_range[0] : x_range[1] : step, y_range[0] : y_range[1] : step]
    points = np.swapaxes(grid, 0, -1).reshape(-1, 2)
    return grid[0], grid[1], points


def plot_gradient_map(
    data,
    model,
    best_fit,
    *,
    x_range=(0.5, 1.5),
    y_range=(0.5, 1.5),
    contour_points=101,
    arrow_points=11,
    arrow_scale=75,
    **kwargs,
):
    """Plot the log-likelihood surface of ``model`` together with its gradient field.

    Draws filled contours of ``model.logpdf(pars, data)`` over a 2D grid of
    ``pars``, overlays arrows of ``d logpdf / d pars`` on a coarser grid, and
    marks ``best_fit`` with a red dot. The arrows are gradients of the
    log-likelihood itself (not its negative), so they point towards higher
    likelihood.

    Args:
        data: Data vector passed as the second argument of ``model.logpdf``.
        model: Object whose ``logpdf(pars, data)`` returns a length-1 array and
            can be traced by ``jax.vmap`` / ``jax.grad`` (e.g. a pyhf model on
            the JAX backend). ``pars`` always has exactly two entries here,
            because the grid is two-dimensional.
        best_fit: Sequence whose first two entries are drawn as the best-fit point.
        x_range: ``(low, high)`` range of the first parameter (x axis, labelled mu).
        y_range: ``(low, high)`` range of the second parameter (y axis, labelled theta).
        contour_points: Grid samples per axis used for the contour plot.
        arrow_points: Grid samples per axis used for the gradient arrows.
        arrow_scale: ``scale`` forwarded to ``Axes.quiver``; larger values draw
            shorter arrows.
        **kwargs: Accepted for backward compatibility only. Any entry is ignored
            and reported with a warning, so typos do not pass silently.

    Returns:
        ``(fig, ax)``: the matplotlib figure and axes.
    """
    if kwargs:
        import warnings

        warnings.warn(
            "plot_gradient_map ignoring unexpected keyword arguments: "
            + ", ".join(sorted(kwargs)),
            stacklevel=2,
        )

    fig, ax = plt.subplots(figsize=(7, 7))

    # Contours: evaluate logpdf on a fine grid in a single vectorized call.
    x, y, points = _grid_points(x_range, y_range, contour_points)
    v = jax.vmap(model.logpdf, in_axes=(0, None))(points, data)
    # `points` runs x-fastest, so the reshape yields a (y, x) layout; transpose
    # it back so v[i, j] lines up with x[i, j], y[i, j].
    v = np.swapaxes(v.reshape(contour_points, contour_points), 0, -1)
    ax.contourf(x, y, v, levels=100)
    ax.contour(x, y, v, levels=20, colors="w")

    # Gradients on a coarser grid. The [0] unwraps logpdf's length-1 output so
    # that jax.grad differentiates a scalar function.
    _, _, points = _grid_points(x_range, y_range, arrow_points)
    logpdf_grad = jax.grad(lambda p, d: model.logpdf(p, d)[0])
    gradients = jax.vmap(logpdf_grad, in_axes=(0, None))(points, data)

    ax.quiver(
        points[:, 0],
        points[:, 1],
        gradients[:, 0],
        gradients[:, 1],
        angles="xy",  # orient arrows along (u, v) in data coordinates
        scale=arrow_scale,
    )
    # NOTE(review): assumes best_fit is ordered (mu, theta) to match the axes.
    ax.scatter(best_fit[0], best_fit[1], color="red")

    ax.set_xlabel(r"$\mu$")
    ax.set_ylabel(r"$\theta$")
    ax.set_xlim(x_range)
    ax.set_ylim(y_range)

    fig.tight_layout()

    return fig, ax

In [None]:
from pathlib import Path

# savefig does not create missing directories, so make sure the output folder exists.
Path("plots").mkdir(parents=True, exist_ok=True)
fig, ax = plot_gradient_map(data, model, best_fit)
fig.savefig("plots/MLE_grad_map.png")

In [None]:
from pathlib import Path

# savefig does not create missing directories, so make sure the output folder exists
# (this cell may be run without the previous one).
Path("plots").mkdir(parents=True, exist_ok=True)
fig, ax = plot_gradient_map(data, model, best_fit, x_range=(0, 5), y_range=(0, 5))
fig.savefig("plots/MLE_grad_map_full.png")