# KDE demo, with histosys!

> It works :)

![](assets/kde_pyhf_animation.gif)

In [14]:
#!pip install neos matplotlib celluloid
#!pip install git+http://github.com/scikit-hep/pyhf.git@diffable_json

In [15]:
import time

import jax
import jax.experimental.optimizers as optimizers
import jax.experimental.stax as stax
import jax.random
from jax.random import PRNGKey
import numpy as np
import jax.scipy as jsp
import jax.numpy as jnp
import pyhf

pyhf.set_backend("jax")
pyhf.default_backend = pyhf.tensor.jax_backend(precision="64b")

from neos import data, makers
from relaxed import infer

rng = PRNGKey(22)

In [16]:
# regression net
init_random_params, predict = stax.serial(
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1024),
    stax.Relu,
    stax.Dense(1),
    stax.Sigmoid,
)
# import jax.numpy as jnp
# def inv_sigmoid(x):
#     return jnp.log(x/(1-x))

## Compose differentiable workflow

In [17]:
dgen = data.generate_blobs(rng, blobs=4)

# Specify our hyperparameters ahead of time for the kde histograms
nbins = 5
bins = np.linspace(0, 1, nbins + 1)
# bins = jnp.array([-jnp.inf, *bins, jnp.inf])
# bins = inv_sigmoid(obins)
bandwidth = 0.25
# bandwidth = inv_sigmoid(bandwidth)
reflect_infinite_bins = False  # True
hmaker = makers.hists_from_nn(
    dgen,
    predict,
    hpar_dict=dict(bins=bins, bandwidth=bandwidth),
    method="kde",
    reflect_infinities=reflect_infinite_bins,
)

In [18]:
nnm = makers.histosys_model_from_hists(hmaker)
get_cls = infer.make_hypotest(
    nnm,
    solver_kwargs=dict(pdf_transform=False),
    metrics=["CLs", "pull", "pull_err", "errors"],
)

# loss returns a list of metrics -- let's just index into one (CLs)
def loss(params):
    trics = get_cls(params, test_mu=1.0)
    return jnp.log(trics["CLs"])  # neos
    # return jnp.log(trics["errors"][0])  # inferno
    # return jnp.log(trics["errors"][0]) + 0.1*jnp.log(trics["CLs"])


loss_name = "log(cls)"
# loss name = 'log(sigma_mu)'

### Randomly initialise nn weights and check that we can get the gradient of the loss wrt nn params

In [19]:
# nnm(network)[0].config.suggested_init()
# get_cls(network, test_mu=1.0)["pull"][0]
bandwidth

0.25

In [20]:
_, network = init_random_params(jax.random.PRNGKey(2), (-1, 2))

# gradient wrt nn weights
jax.value_and_grad(loss)(network)

(DeviceArray(-2.71744949, dtype=float64),
 [(DeviceArray([[-0.00011819, -0.00244951,  0.0085968 , ..., -0.00045336,
                 -0.00924117, -0.00115594],
                [ 0.0006744 ,  0.00219908, -0.00393667, ..., -0.00055081,
                  0.00490358,  0.00018834]], dtype=float32),
   DeviceArray([-0.00022378, -0.00014939,  0.00444462, ...,  0.00122349,
                -0.00305481, -0.00034657], dtype=float32)),
  (),
  (DeviceArray([[ 2.20489801e-06,  2.98863524e-05,  2.71029247e-04, ...,
                  1.41429223e-06,  8.76574850e-05, -9.36609285e-06],
                [ 1.33727883e-06, -7.01249428e-06,  1.77396825e-04, ...,
                 -8.26982443e-07,  4.22285630e-05, -1.43008147e-05],
                [-4.20932253e-07, -2.11330072e-04,  6.39680438e-05, ...,
                 -1.04921965e-05, -1.89122555e-04,  5.25775067e-05],
                ...,
                [ 4.18832224e-06,  1.07475651e-04,  4.72832238e-04, ...,
                  5.06049309e-06,  1.93377462e

### Define training loop!

In [21]:
lr = lambda i: 3e-4 if i < 60 else 3e-5

opt_init, opt_update, opt_params = optimizers.adam(lr)


def train_network(N):
    cls_vals = []
    _, network = init_random_params(jax.random.PRNGKey(1), (-1, 2))
    state = opt_init(network)
    losses = []
    np_errs = []
    mu_errs = []

    # parameter update function
    # @jax.jit
    def update_and_value(i, opt_state, mu):
        net = opt_params(opt_state)
        value, grad = jax.value_and_grad(loss)(net)
        return opt_update(i, grad, state), value, net

    for i in range(N):
        start_time = time.time()
        state, value, network = update_and_value(i, state, 1.0)
        epoch_time = time.time() - start_time

        metrs = get_cls(network, test_mu=1.0)
        losses.append(metrs["CLs"])
        mu_errs.append(metrs["errors"][0])
        np_errs.append(metrs["errors"][1])
        metrics = {
            "loss": losses,
            "pull": metrs["pull"],
            "pull_err": metrs["pull_err"],
            "np_err": np_errs,
            "mu_err": mu_errs,
        }

        yield network, metrics, epoch_time

### Plotting helper function for awesome animations :)

In [22]:
[f"[{a:.2g},{b:.2g}]" for a, b in zip(bins[:-1], bins[1:])]

['[0,0.2]', '[0.2,0.4]', '[0.4,0.6]', '[0.6,0.8]', '[0.8,1]']

In [23]:
def make_kde(data, bw):
    @jax.jit
    def get_kde(x):
        return jnp.mean(
            jsp.stats.norm.pdf(x, loc=data.reshape(-1, 1), scale=bw), axis=0
        )

    return get_kde


def bar_plot(ax, data, colors=None, total_width=0.8, single_width=1, legend=True):
    """Draws a bar plot with multiple bars per data point.

    Parameters
    ----------
    ax : matplotlib.pyplot.axis
        The axis we want to draw our plot on.

    data: dictionary
        A dictionary containing the data we want to plot. Keys are the names of the
        data, the items is a list of the values.

        Example:
        data = {
            "x":[1,2,3],
            "y":[1,2,3],
            "z":[1,2,3],
        }

    colors : array-like, optional
        A list of colors which are used for the bars. If None, the colors
        will be the standard matplotlib color cyle. (default: None)

    total_width : float, optional, default: 0.8
        The width of a bar group. 0.8 means that 80% of the x-axis is covered
        by bars and 20% will be spaces between the bars.

    single_width: float, optional, default: 1
        The relative width of a single bar within a group. 1 means the bars
        will touch eachother within a group, values less than 1 will make
        these bars thinner.

    legend: bool, optional, default: True
        If this is set to true, a legend will be added to the axis.
    """

    # Check if colors where provided, otherwhise use the default color cycle
    if colors is None:
        colors = plt.rcParams["axes.prop_cycle"].by_key()["color"]

    # Number of bars per group
    n_bars = len(data)

    # The width of a single bar
    bar_width = total_width / n_bars

    # List containing handles for the drawn bars, used for the legend
    bars = []

    # Iterate over all data
    for i, (name, values) in enumerate(data.items()):
        # The offset in x direction of that bar
        x_offset = (i - n_bars / 2) * bar_width + bar_width / 2

        # Draw a bar for every value of that type
        for x, y in enumerate(values):
            bar = ax.bar(
                x + x_offset,
                y,
                width=bar_width * single_width,
                color=colors[i % len(colors)],
            )

        # Add a handle to the last drawn bar, which we'll need for the legend
        bars.append(bar[0])

    labels = [f"[{a:.1g},{b:.1g}]" for a, b in zip(bins[:-1], bins[1:])]
    ax.set_xticks(range(len(labels)))
    ax.set_xticklabels(labels)

    # Draw legend if we need
    if legend:
        ax.legend(bars, data.keys(), fontsize="x-small")


def plot(axs, axins, network, metrics, maxN, legend=False):
    ax = axs["Data space"]
    g = np.mgrid[-5:5:101j, -5:5:101j]
    if jnp.inf in bins:
        levels = bins[1:-1]  # infinite
    else:
        levels = bins
    ax.contourf(
        g[0],
        g[1],
        predict(network, np.moveaxis(g, 0, -1)).reshape(101, 101, 1)[:, :, 0],
        levels=levels,
        cmap="binary",
    )
    ax.contour(
        g[0],
        g[1],
        predict(network, np.moveaxis(g, 0, -1)).reshape(101, 101, 1)[:, :, 0],
        colors="w",
        levels=levels,
    )
    sig, bkg_nom, bkg_up, bkg_down = dgen()

    ax.scatter(sig[:, 0], sig[:, 1], alpha=0.3, c="C9", label="signal")
    ax.scatter(
        bkg_up[:, 0], bkg_up[:, 1], alpha=0.1, c="orangered", marker=6, label="bkg up"
    )
    ax.scatter(
        bkg_down[:, 0], bkg_down[:, 1], alpha=0.1, c="gold", marker=7, label="bkg down"
    )
    ax.scatter(bkg_nom[:, 0], bkg_nom[:, 1], alpha=0.3, c="C1", label="bkg")

    ax.set_xlim(-5, 5)
    ax.set_ylim(-5, 5)
    ax.set_xlabel("x")
    ax.set_ylabel("y")
    if legend:
        ax.legend(fontsize="x-small", loc="upper right")
    ax = axs["CLs per epoch"]
    # ax.axhline(0.05, c="slategray", linestyle="--")
    ax.plot(metrics["loss"], c="C9", linewidth=2.0, label=r"$CL_s$")
    ax.set_yscale("log")
    # ax.set_ylim(1e-4, 0.06)
    ax.set_xlim(0, maxN)
    ax.set_xlabel("epoch")
    ax.set_ylabel(r"$CL_s$")
    # if legend:
    #     ax.legend(fontsize="x-small", loc="upper right")

    ax = axs["Uncertainties"]
    ax.plot(
        metrics["np_err"],
        c="slategray",
        linewidth=2.0,
        label=r"$\sigma_{\mathsf{nuisance}}$",
    )
    ax.plot(metrics["mu_err"], c="steelblue", linewidth=2.0, label=r"$\sigma_\mu$")
    # ax.set_ylim(1e-4, 0.06)
    ax.set_xlim(0, maxN)
    ax.set_xlabel("epoch")
    ax.set_ylabel(r"metric value")
    if legend:
        ax.legend(fontsize="x-small", loc="upper right")

    ax = axs["Histogram model"]
    s, b, bup, bdown = hmaker(network)

    noinf = bins[1:-1]
    bin_width = 1 / (len(noinf) - 1)
    centers = noinf[:-1] + np.diff(noinf) / 2.0
    centers = jnp.array([noinf[0] - bin_width, *centers, noinf[-1] + bin_width])

    dct = {
        "signal": s,
        "bkg up": bup,
        "bkg": b,
        "bkg down": bdown,
    }

    bar_plot(
        ax,
        dct,
        colors=["C9", "orangered", "C1", "gold"],
        total_width=0.8,
        single_width=1,
        legend=legend,
    )

    # bunc = np.asarray([[x, y] if x > y else [y, x] for x, y in zip(bup, bdown)])
    # plot_unc = []
    # for unc, be in zip(bunc, b):
    #     if all(unc > be):
    #         plot_unc.append([max(unc), be])
    #     elif all(unc < be):
    #         plot_unc.append([be, min(unc)])
    #     else:
    #         plot_unc.append(unc)

    # plot_unc = np.asarray(plot_unc)
    # b_up, b_down = plot_unc[:, 0], plot_unc[:, 1]

    # ax.bar(centers, bup - b, bottom=b, alpha=0.4, color="red", width=bin_width, hatch="+", label="bkg_up")
    # ax.bar(
    #     centers, b - bdown, bottom=bdown, alpha=0.4, color="green", width=bin_width, hatch="-", label="bkg_down"
    # )

    # ax.set_ylim(0, 120)
    ax.set_ylabel("frequency")
    ax.set_xlabel("interval over nn output")

    ax = axs["Nuisance pull"]

    pulls = metrics["pull"]
    pullerr = metrics["pull_err"]

    ax.set_ylabel(r"$(\theta - \hat{\theta})\,/ \Delta \theta$", fontsize=18)

    # draw the +/- 2.0 horizontal lines
    ax.hlines([-2, 2], -0.5, len(pulls) - 0.5, colors="black", linestyles="dotted")
    # draw the +/- 1.0 horizontal lines
    ax.hlines([-1, 1], -0.5, len(pulls) - 0.5, colors="black", linestyles="dashdot")
    # draw the +/- 2.0 sigma band
    ax.fill_between([-0.5, len(pulls) - 0.5], [-2, -2], [2, 2], facecolor="yellow")
    # drawe the +/- 1.0 sigma band
    ax.fill_between([-0.5, len(pulls) - 0.5], [-1, -1], [1, 1], facecolor="green")
    # draw a horizontal line at pull=0.0
    ax.hlines([0], -0.5, len(pulls) - 0.5, colors="black", linestyles="dashed")

    ax.scatter(range(len(pulls)), pulls, color="black")
    # and their uncertainties
    ax.errorbar(
        range(len(pulls)),
        pulls,
        color="black",
        xerr=0,
        yerr=pullerr[0],
        marker=".",
        fmt="none",
    )

    ax = axs["Example KDE"]
    _, b_data, _, _ = dgen()
    d = np.array(predict(network, b_data).ravel().tolist())
    kde = make_kde(d, bandwidth)
    yields = b
    x = np.linspace(-1, 2, 300)
    db = jnp.array(jnp.diff(bins), float)  # bin spacing
    yields = yields / db / yields.sum(axis=0)  # normalize to bin width
    if jnp.inf in bins:
        pbins = [-999, *noinf, 999]
    else:
        pbins = bins
    ax.stairs(yields, pbins, label="KDE hist", color="C1")
    ax.plot(x, kde(x), label="KDE", color="C0")

    ax.set_xlim(-1, 2)

    # rug plot of the data
    ax.plot(
        d,
        jnp.zeros_like(d) - 0.01,
        "|",
        linewidth=3,
        alpha=0.4,
        color="black",
        label="data",
    )

    if legend:
        width = jnp.diff(bins[1:-1])[0]
        xlim = (
            [(width / 2) - (1.1 * bandwidth), (width / 2) + (1.1 * bandwidth)]
            if (width / 2) - bandwidth < 0
            else [-width / 3, width + width / 3]
        )
        axins.stairs([1], [0, width], color="C1")
        y = jnp.linspace(xlim[0], xlim[1], 300)
        demo = jsp.stats.norm.pdf(y, loc=width / 2, scale=bandwidth)
        axins.plot(y, demo / max(demo), color="C0", linestyle="dashed", label="kernel")
        # draw two vertical lines at ((width/2)-bandwidth)/2 and ((width/2)+bandwidth)/2
        axins.vlines(
            [(width / 2) - bandwidth, (width / 2) + bandwidth],
            0,
            1,
            colors="black",
            linestyles="dotted",
            label=r"$\pm$bandwidth",
        )
        # write text in the middle of the vertical lines with the value of the bandwidth
        ratio = bandwidth / width
        axins.text(
            width / 2,
            -0.3,
            r"$\mathsf{\frac{bandwidth}{bin\,width}}=$" + f"{ratio:.2f}",
            ha="center",
            va="center",
            size="x-small",
        )

        axins.set_xlim(*xlim)

        handles, labels = ax.get_legend_handles_labels()
        handles1, labels1 = axins.get_legend_handles_labels()
        ax.legend(
            handles + handles1, labels + labels1, loc="upper right", fontsize="x-small"
        )

### Let's run it!!

In [24]:
# slow
import numpy as np
from IPython.display import HTML

from matplotlib import pyplot as plt

plt.style.use("default")

plt.rcParams.update(
    {
        "axes.labelsize": 13,
        "axes.linewidth": 1.2,
        "xtick.labelsize": 13,
        "ytick.labelsize": 13,
        "figure.figsize": [16.0, 9.0],
        "font.size": 13,
        "xtick.major.size": 3,
        "ytick.major.size": 3,
        "legend.fontsize": 11,
    }
)

plt.rc("figure", dpi=120)

fig, axs = plt.subplot_mosaic(
    [
        ["Data space", "Histogram model", "Example KDE"],
        ["CLs per epoch", "Uncertainties", "Nuisance pull"],
    ]
)

for label, ax in axs.items():
    ax.set_title(label, fontstyle="italic")
axins = axs["Example KDE"].inset_axes([0.01, 0.79, 0.3, 0.2])
axins.axis("off")
maxN = 100  # make me bigger for better results!

animate = True  # animations fail tests...
ax_cpy = axs
axins_cpy = axins
if animate:
    from celluloid import Camera

    camera = Camera(fig)

# Training
for i, (network, metrics, epoch_time) in enumerate(train_network(maxN)):
    print(
        f"epoch {i}:",
        f'pull={metrics["pull"]}+-{metrics["pull_err"]}',
        f'CLs = {metrics["loss"][-1]}, took {epoch_time}s',
    )
    if animate:
        if i == 0:
            plot(
                axs,
                axins=axins,
                network=network,
                metrics=metrics,
                maxN=maxN,
                legend=True,
            )
            plt.tight_layout()
            camera.snap()
        elif i == maxN - 1:
            plot(axs, axins=axins, network=network, metrics=metrics, maxN=maxN)
            plt.tight_layout()
            camera.snap()
            fig2, axs2 = plt.subplot_mosaic(
                [
                    ["Data space", "Histogram model", "Example KDE"],
                    ["CLs per epoch", "Uncertainties", "Nuisance pull"],
                ]
            )

            for label, ax in axs2.items():
                ax.set_title(label, fontstyle="italic")
            axins2 = axs2["Example KDE"].inset_axes([0.01, 0.79, 0.3, 0.2])
            axins2.axis("off")
            plot(
                axs2,
                axins=axins2,
                network=network,
                metrics=metrics,
                maxN=maxN,
                legend=True,
            )
            plt.tight_layout()
            fig2.savefig(
                f"{loss_name}-{nbins}bins-{bandwidth}bandwidth-{maxN}epochs.pdf"
            )
        else:
            plot(axs, axins=axins, network=network, metrics=metrics, maxN=maxN)
            plt.tight_layout()
            camera.snap()

        axs = ax_cpy
        axins = axins_cpy
        # if i % 10 == 0:
        #     camera.animate().save("animation.gif", writer="imagemagick", fps=8)
        # HTML(camera.animate().to_html5_video())
    plt.close()

epoch 0: pull=[-0.02935935]+-[[0.9267604]] CLs = 0.06614223240644135, took 1.0668590068817139s
epoch 1: pull=[-2.4272476e-06]+-[[0.39255115]] CLs = 0.05681475959127047, took 1.7591149806976318s
epoch 2: pull=[2.80025945e-06]+-[[0.2541726]] CLs = 0.04285568963146935, took 1.7482740879058838s
epoch 3: pull=[2.3988122e-06]+-[[0.20984319]] CLs = 0.029178288482402337, took 1.7594192028045654s
epoch 4: pull=[2.2351657e-06]+-[[0.19426487]] CLs = 0.019101392487009372, took 1.7977440357208252s
epoch 5: pull=[2.17563924e-06]+-[[0.19175002]] CLs = 0.012649010059314403, took 1.8346030712127686s
epoch 6: pull=[2.178135e-06]+-[[0.19710273]] CLs = 0.008728891073170653, took 1.8102209568023682s
epoch 7: pull=[2.23047663e-06]+-[[0.20847263]] CLs = 0.006325636875318397, took 1.7899460792541504s
epoch 8: pull=[2.33650055e-06]+-[[0.22544448]] CLs = 0.004770890515893189, took 1.8206119537353516s
epoch 9: pull=[2.50752464e-06]+-[[0.24806661]] CLs = 0.0037004300786214195, took 1.7789978981018066s
epoch 10: p

<Figure size 1920x1080 with 0 Axes>

In [25]:
if animate:
    camera.animate().save("a3.gif", writer="imagemagick", fps=12)

In [None]:
if animate:
    camera.animate().save("a2.gif", writer="imagemagick", fps=10)

CalledProcessError: Command '['convert', '-size', '1920x1080', '-depth', '8', '-delay', '8.333333333333334', '-loop', '0', 'rgba:-', 'a3.gif']' returned non-zero exit status 2.

In [None]:
import jax
import jax.numpy as jnp


def inv_sigmoid(x):
    return jnp.log(x / (1 - x))


inv_sigmoid(jnp.linspace(0, 1, 4))

DeviceArray([       -inf, -0.69314718,  0.69314718,         inf], dtype=float64)

In [None]:
from IPython.display import HTML

animation = camera.animate()
HTML(animation.to_html5_video())

RuntimeError: Requested MovieWriter (ffmpeg) not available

In [None]:
camera.animate().save("/workspaces/relaxed/aregsoft.gif", writer="imagemagick", fps=10)



In [None]:
camera._offsets

{'collections': defaultdict(int, {0: 19, 1: 0, 2: 0, 3: 0, 4: 16}),
 'patches': defaultdict(int, {0: 0, 1: 0, 2: 40, 3: 2, 4: 0}),
 'lines': defaultdict(int, {0: 0, 1: 6, 2: 0, 3: 4, 4: 0}),
 'texts': defaultdict(int, {0: 0, 1: 0, 2: 0, 3: 0, 4: 0}),
 'artists': defaultdict(int, {0: 2, 1: 2, 2: 2, 3: 2, 4: 0}),
 'images': defaultdict(int, {0: 0, 1: 0, 2: 0, 3: 0, 4: 0})}