**This code was adapted from [Alexander Held's "Example of a differentiable analysis" repository](https://github.com/alexander-held/differentiable-analysis-example/)**

In [None]:
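import os

from jax import grad, vmap, jit
import jax.numpy as jnp
import matplotlib.pyplot as plt
import numpy as np

np.random.seed(0)
# the savefig calls below write into this directory, so make sure it exists
os.makedirs("plots", exist_ok=True)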
plt.rcParams.update(
    {
        "font.size": 14,
        "figure.figsize": (7, 5),
        "figure.facecolor": (1, 1, 1, 1),
        "figure.dpi": 100,
        #         "figure.dpi": 300,
    }
)

### Introduction

We have two processes:

- signal `S`,
- background `B`,

and we observe events generated from these processes.
Each observation consists of measuring a single observable $x$.
We have Monte Carlo predictions in the form of events generated for both processes, which tell us how many events we expect as a function of the observable $x$.

We want to do a single-bin counting experiment to establish the presence of the signal process.
The significance calculation for this is found below.
To maximize the significance, this toy analysis introduces a cut.
We only let events with an observable $x>c$ enter our bin, where $c$ is a cut value we intend to optimize.

### Data generation and visualization
The background and signal events in this toy case have slightly different values of the observable $x$ on average, but the spread of the background along the observable is quite large. 

In [None]:
nBg = 8000
nSig = 300
background = np.random.normal(40, 10, nBg)
signal = np.random.normal(50, 5, nSig)

bins = jnp.linspace(0, 80, 40)

In [None]:
fig, ax = plt.subplots()
alpha = 1.0
ax.hist(background, bins=bins, alpha=alpha, label=["Background"])
ax.hist(signal, bins=bins, alpha=alpha, label=["Signal"])
ax.set_xlabel(r"Observable $x$")
ax.set_ylabel("Count")
ax.legend(loc="best")
fig.savefig("plots/signal_background_shapes.png")

fig, ax = plt.subplots()
ax.hist([background, signal], bins=bins, stacked=True, label=["Background", "Signal"])
ax.set_xlabel(r"Observable $x$")
ax.set_ylabel("Count")
ax.legend(loc="best")
fig.savefig("plots/signal_background_stacked.png");

### Calculate significance for given cut and scan
The significance is given by the number of signal events $S$ and background events $B$ in the bin as
$\sqrt{2(S+B) \log(1+\frac{S}{B}) -2S}$.
For $S \ll B$, expanding $\log(1+y) \approx y-\frac{y^2}{2}$ with $y = \frac{S}{B}$ recovers the familiar $\frac{S}{\sqrt{B}}$ to leading order.
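
As a quick numerical cross-check of this limit, the sketch below compares the full formula with $\frac{S}{\sqrt{B}}$ for a few illustrative yields. The function names and the chosen values of $S$ and $B$ are assumptions made for this illustration only.

```python
import jax.numpy as jnp


def full_significance(S, B):
    """Significance formula quoted above."""
    return jnp.sqrt(2 * ((S + B) * jnp.log(1 + S / B) - S))


def simple_significance(S, B):
    """Leading-order approximation, expected to hold for S << B."""
    return S / jnp.sqrt(B)


B = 8000.0  # illustrative background yield
for S in (10.0, 100.0, 1000.0):
    full = float(full_significance(S, B))
    simple = float(simple_significance(S, B))
    print(f"S={S:6.0f}  full={full:.4f}  S/sqrt(B)={simple:.4f}")
```

The two numbers should agree closely while $S$ is small compared to $B$ and drift apart as $S$ grows.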

In [None]:
def significance(S, B):
    """calculate the significance according to the formula above"""
    return jnp.sqrt(2 * ((S + B) * jnp.log(1 + S / B) - S))


def get_significance(cut, S, B):
    """calculate the significance at a given cut value for signal and background events"""
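    # count passing events by summing the boolean mask; unlike boolean
    # indexing, this can also be traced by jax when we call grad() later
    S_cut = jnp.sum(S > cut)
    B_cut = jnp.sum(B > cut)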
    return significance(S_cut, B_cut)

As a baseline, let's calculate the significance for a cut at $c=0$, which lets all events with $x>0$ pass.

In [None]:
baseline_significance = get_significance(0, signal, background)
print("the baseline significance is", baseline_significance, "σ")

In this example it is straightforward to simply scan for the optimal setting, so let's do that before we use any autodiff.

In [None]:
cuts = jnp.linspace(20, 70, 500)
significances = [get_significance(cut, signal, background) for cut in cuts]

In [None]:
fig, ax = plt.subplots()

ax.plot(cuts, significances, c="C2")
ax.set_xlabel(r"Cut position on $x$")
ax.set_ylabel("Significance")
fig.savefig("plots/significance_cut_scan.png");

Seems like we should use a cut of $x>45$ or so.
Now for the autodiff version.

### Let's calculate a gradient
With `jax` we get the gradient of `get_significance` by just calling `grad()`.
The argument `argnums=0` means that we want the gradient with respect to the first argument, which is the cut value.
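
To see what `argnums` does in isolation, here is a minimal sketch on a made-up two-argument function `f` (not part of the analysis), for which the derivatives are easy to verify by hand.

```python
from jax import grad


def f(a, b):
    # toy function: df/da = 2 * a * b, df/db = a**2
    return a**2 * b


df_da = grad(f, argnums=0)  # derivative with respect to the first argument
df_db = grad(f, argnums=1)  # derivative with respect to the second argument

print(df_da(3.0, 2.0))  # 2 * 3 * 2 = 12.0
print(df_db(3.0, 2.0))  # 3**2 = 9.0
```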

In [None]:
grad_significance = grad(get_significance, argnums=0)

# let's calculate the derivative for a few values
significances_prime = [grad_significance(cut, signal, background) for cut in cuts]
plt.plot(cuts, significances_prime, c="C3")
plt.xlabel("cut position c")
plt.ylabel("gradient of significance");

This did not quite work out.
The problem is that the significance is a step function of the cut position: it is flat almost everywhere and only jumps at the positions of the events.
The significance changes as single events leave the bin that we use for the measurement, so it changes at the positions of all the events.
The gradient is therefore zero wherever it is defined and tells us nothing about the direction in which to move the cut.
We need something else to properly do what we want.
To confirm the above, let's take a very zoomed-in look at the significance as a function of the cut position.

In [None]:
cuts_zoomed = jnp.linspace(60.7, 61.0, 500)  # zoom in to see step function behavior
significances_zoomed = [
    get_significance(cut, signal, background) for cut in cuts_zoomed
]
plt.plot(cuts_zoomed, significances_zoomed, c="C2")
plt.xlabel("cut position c")
plt.ylabel("significance");
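
The same behavior can be reproduced on a handful of made-up events. In this sketch (the `events` values and the helper `hard_count` are assumptions for illustration), the count of events passing a hard cut is well defined, but its gradient with respect to the cut is exactly zero.

```python
import jax.numpy as jnp
from jax import grad

events = jnp.array([41.0, 44.5, 47.2, 52.3])  # made-up observable values


def hard_count(c):
    # float-valued count of events with x > c; the comparison is a step function
    return jnp.sum(jnp.where(events > c, 1.0, 0.0))


print(hard_count(45.0))        # two events lie above 45 -> 2.0
print(grad(hard_count)(45.0))  # the count is locally constant -> 0.0
```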

### Making it work
Instead of working with the non-differentiable cut operation, let's replace it with something differentiable.
We give weights to all our events.
Those weights should be `1` in the limit where the events are very far above the cut, and `0` if they are very far below the cut.

We can use a sigmoid $1 / [1+e^{-\alpha(x-c)}]$ to calculate the weight, where $c$ is the cut value and $\alpha$ is a parameter that adjusts the steepness.
Larger values of $\alpha$ make the function steeper in the transition region around the cut.

In [None]:
def yield_after_cut(x, c, alpha=1):
    """calculate the sigmoid weight of each event for a cut at c (summing the weights gives the yield)"""
    # If alpha is too large, the exponential overflows -> NaNs later on...
    passed = 1 / (1 + jnp.exp(-alpha * (x - c)))
    return passed

We can visualize the weight as a function of observable for a given cut value.

In [None]:
x = jnp.linspace(30, 60, 100)
example_cut = 45
y_pass = yield_after_cut(x, example_cut, alpha=1)

fig, ax = plt.subplots()
# ax.plot(x, y_pass, c="C4", label=r"$\alpha=1$")
ax.plot(x, yield_after_cut(x, example_cut, alpha=0.5), label=r"$\alpha=0.5$")
ax.plot(x, y_pass, label=r"$\alpha=1$", color="C3")
ax.plot(x, yield_after_cut(x, example_cut, alpha=2), label=r"$\alpha=2$")
ax.plot([example_cut, example_cut], [0, 1], ":", c="k")
ax.set_xlabel(r"Observable $x$")
ax.set_ylabel("Event weight")
ax.set_title("Example cut at $x=45$")
ax.legend(loc="best")
fig.savefig("plots/sigmoid_event_weights.png");

Let's define a new function to calculate the significance that approximates the cut with sigmoid weights instead of applying a hard cut.
We can compare this way of calculating the significance to the approach using hard cuts from earlier.
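
Because the smooth version only uses array operations, a scan over many cut values can also be vectorized with `vmap` instead of a Python loop. The sketch below is self-contained and regenerates toy data with the same shapes as above; the names `toy_signal`, `toy_background`, `toy_cuts` and `smooth_significance` are assumptions for this illustration, not part of the notebook.

```python
import jax.numpy as jnp
import numpy as np
from jax import vmap

rng = np.random.default_rng(0)
toy_background = jnp.asarray(rng.normal(40, 10, 8000))
toy_signal = jnp.asarray(rng.normal(50, 5, 300))


def smooth_significance(cut, S, B, alpha=1.0):
    # sigmoid-weighted yields, followed by the significance formula
    s = jnp.sum(1 / (1 + jnp.exp(-alpha * (S - cut))))
    b = jnp.sum(1 / (1 + jnp.exp(-alpha * (B - cut))))
    return jnp.sqrt(2 * ((s + b) * jnp.log(1 + s / b) - s))


toy_cuts = jnp.linspace(20, 70, 500)
# map over the cut argument only; S and B are shared by all cut values
scan = vmap(smooth_significance, in_axes=(0, None, None))(
    toy_cuts, toy_signal, toy_background
)
print("best cut on the grid:", toy_cuts[jnp.argmax(scan)])
```

`in_axes=(0, None, None)` tells `vmap` to batch over the first argument and to pass the two event arrays unchanged to every evaluation.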

In [None]:
def get_significance_smooth(cut, S, B, alpha=1):
    """calculate significance for a given cut, but now approximate the cut with a sigmoid"""
    S_pass = yield_after_cut(S, cut, alpha)
    B_pass = yield_after_cut(B, cut, alpha)
    sum_S_weighted = jnp.sum(S_pass)
    sum_B_weighted = jnp.sum(B_pass)
    return significance(sum_S_weighted, sum_B_weighted)

In [None]:
cuts_smooth = jnp.linspace(20, 70, 500)
significances_smooth_alpha_1 = jnp.asarray(
    [get_significance_smooth(cut, signal, background, alpha=1) for cut in cuts_smooth]
)
significances_smooth_alpha_half = jnp.asarray(
    [get_significance_smooth(cut, signal, background, alpha=0.5) for cut in cuts_smooth]
)
significances_smooth_alpha_2 = jnp.asarray(
    [get_significance_smooth(cut, signal, background, alpha=2) for cut in cuts_smooth]
)
significances_smooth_alpha_3 = jnp.asarray(
    [get_significance_smooth(cut, signal, background, alpha=3) for cut in cuts_smooth]
)
print("Optimal cut is c =", cuts_smooth[jnp.argmax(significances_smooth_alpha_1)])

In [None]:
fig, ax = plt.subplots()

ax.plot(cuts, significances, label="Hard cuts", c="C2")
ax.plot(cuts_smooth, significances_smooth_alpha_half, label=r"Sigmoid, $\alpha=0.5$")
ax.plot(cuts_smooth, significances_smooth_alpha_1, label=r"Sigmoid, $\alpha=1$", c="C3")
ax.plot(cuts_smooth, significances_smooth_alpha_2, label=r"Sigmoid, $\alpha=2$")
ax.plot(cuts_smooth, significances_smooth_alpha_3, label=r"Sigmoid, $\alpha=3$")
ax.set_xlabel("Cut position $c$")
ax.set_ylabel("Significance")
ax.legend(loc="best")
fig.savefig("plots/significance_scan_compare.png")

fig, ax = plt.subplots()

ax.plot(cuts, significances, label="Hard cuts", c="C2")
ax.plot(cuts_smooth, significances_smooth_alpha_3, label=r"Sigmoid, $\alpha=3$")
ax.set_xlabel("Cut position $c$")
ax.set_ylabel("Significance")
ax.legend(loc="best")
fig.savefig("plots/significance_scan_compare_high_alpha.png")

The smooth version seems to be a decent approximation.
We can in principle improve it by increasing the steepness $\alpha$, but we then run into some `NaN` issues in the next step (insights on how to solve this are welcome; the `NaN`s can already be triggered at the stage above by `@jit`-ing the function).
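
One possible source of such `NaN`s is that `jnp.exp` overflows for events far below the cut, after which the backward pass multiplies zero by infinity. A hedged sketch of a possible remedy is to use `jax.nn.sigmoid`, which evaluates the same logistic function in a numerically stable way. The helper names `naive_weight` and `stable_weight` are assumptions for this illustration, and this has not been checked against every `NaN` seen in this notebook (a vanishing background yield at very high cuts would be a separate issue).

```python
import jax
import jax.numpy as jnp
from jax import grad


def naive_weight(x, c, alpha):
    # same expression as in yield_after_cut
    return 1 / (1 + jnp.exp(-alpha * (x - c)))


def stable_weight(x, c, alpha):
    # jax.nn.sigmoid computes 1 / (1 + exp(-z)) without overflowing
    return jax.nn.sigmoid(alpha * (x - c))


x, c, alpha = 10.0, 50.0, 3.0  # event far below the cut, steep sigmoid

# with 32-bit floats exp(120) overflows, so this gradient can come out as nan
print(grad(naive_weight, argnums=1)(x, c, alpha))
# the stable version stays finite (the weight is simply saturated at zero)
print(grad(stable_weight, argnums=1)(x, c, alpha))
```

For moderate arguments the two functions return the same weights, so swapping one for the other should not change the significance curves.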

Now let's calculate the gradient of the smooth function above, and evaluate it for a few cut values.

In [None]:
sig_gradient_function = grad(get_significance_smooth, argnums=0)
sig_prime_smooth = jnp.asarray(
    [sig_gradient_function(cut, signal, background, alpha=1) for cut in cuts_smooth]
    #     [sig_gradient_function(cut, signal, background, alpha=2) for cut in cuts_smooth]
)

The big moment has come!
Time to take a look at the results.

In [None]:
def get_intercept_indices(arr):
    """find where an array of values intercepts zero"""
    intercepts = []
    # stop one element early so that arr[i + 1] stays within bounds
    for i in range(len(arr) - 1):
        # check if sign changed, meaning zero was crossed
        if arr[i] * arr[i + 1] < 0:
            # pick side closer to zero and return index
            if jnp.abs(arr[i]) > jnp.abs(arr[i + 1]):
                intercepts.append(i + 1)
            else:
                intercepts.append(i)
    return jnp.asarray(intercepts)

In [None]:
intercepts = get_intercept_indices(sig_prime_smooth)
cut_values = jnp.asarray([cuts_smooth[intercept] for intercept in intercepts])

print(f"intercepts of the gradient with zero are located at {cut_values}")

In [None]:
fig, axs = plt.subplots(2, 1)
fig.set_size_inches(10, 10)
# draw significance again
axs[0].plot(cuts_smooth, significances_smooth_alpha_1, color="C3", label=r"$\alpha$=1")
# plt.plot(cuts_smooth, significances_smooth_alpha_2, color="C3")
axs[0].set_xlabel(r"Cut position $c$")
axs[0].set_ylabel("Significance")
axs[0].legend(loc="best")

# add gradient
axs[1].plot([20, 70], [0, 0], "--", color="grey")
axs[1].plot(cuts_smooth, sig_prime_smooth, color="C3")

xmin, xmax, ymin, ymax = axs[1].axis()

for cval in cut_values:
    axs[1].plot([cval, cval], [ymin, ymax], ":", color="k")
axs[1].set_xlabel(r"Cut position $c$")
axs[1].set_ylabel("Gradient of significance")
fig.tight_layout()

fig.savefig("plots/significance_gradient.png")

The maximum significance is located where the gradient of the significance crosses zero. In principle we could also check that the second derivative is negative there, to make sure the point is a maximum rather than a minimum or an inflection point.
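
Such a second-derivative check is cheap with autodiff, since `grad` can be applied twice. The following is a self-contained sketch on regenerated toy data; the names `toy_signal`, `toy_background`, `toy_cuts` and `smooth_significance` are assumptions for this illustration, and `jax.nn.sigmoid` with $\alpha=1$ stands in for the weight function defined earlier.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, vmap

rng = np.random.default_rng(0)
toy_background = jnp.asarray(rng.normal(40, 10, 8000))
toy_signal = jnp.asarray(rng.normal(50, 5, 300))


def smooth_significance(cut):
    # sigmoid-weighted yields with alpha = 1
    s = jnp.sum(jax.nn.sigmoid(toy_signal - cut))
    b = jnp.sum(jax.nn.sigmoid(toy_background - cut))
    return jnp.sqrt(2 * ((s + b) * jnp.log(1 + s / b) - s))


first_derivative = grad(smooth_significance)
second_derivative = grad(first_derivative)  # grad of grad

# locate the best cut on a grid, then inspect both derivatives there
toy_cuts = jnp.linspace(30, 60, 301)
best_cut = toy_cuts[jnp.argmax(vmap(smooth_significance)(toy_cuts))]
print("cut:", best_cut)
print("first derivative:", first_derivative(best_cut))    # close to zero
print("second derivative:", second_derivative(best_cut))  # negative at a maximum
```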

### Automatic analysis optimization

As a last step, let's write a simple gradient ascent function to find the point of maximum significance.
In this one-dimensional example, gradient ascent is computationally far less efficient than a scan, but with a better implementation and a higher-dimensional problem, where a scan over all parameters quickly becomes impractical, the approach really shines.
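
One ingredient of a faster implementation could be to compile the gradient with `jit`, so that the repeated calls inside the loop are cheap. The sketch below is self-contained and only illustrates the idea on regenerated toy data; the names, the starting point, the step size and the number of steps are all arbitrary assumptions.

```python
import jax
import jax.numpy as jnp
import numpy as np
from jax import grad, jit

rng = np.random.default_rng(0)
toy_background = jnp.asarray(rng.normal(40, 10, 8000))
toy_signal = jnp.asarray(rng.normal(50, 5, 300))


def smooth_significance(cut):
    # sigmoid-weighted yields with alpha = 1
    s = jnp.sum(jax.nn.sigmoid(toy_signal - cut))
    b = jnp.sum(jax.nn.sigmoid(toy_background - cut))
    return jnp.sqrt(2 * ((s + b) * jnp.log(1 + s / b) - s))


# compile the gradient so that repeated evaluations in the loop are fast
compiled_gradient = jit(grad(smooth_significance))

cut = 40.0  # starting point (arbitrary choice)
step_size = 2.0
for _ in range(200):
    # move uphill along the gradient
    cut = cut + step_size * compiled_gradient(cut)

print("cut after gradient ascent:", cut)
print("gradient at that cut:", compiled_gradient(cut))
```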

In [None]:
steps = 1_000

c0 = 20.0  # initial position
step_size = 2  # step size

steps_taken = []

for i in range(steps):
    # reduce step size after a while
    if i % 500 == 0 and i != 0:
        step_size = step_size / 5
    grad_at_pos = sig_gradient_function(c0, signal, background)
    if i % 250 == 0:
        print("current position is", c0, "and the gradient is", grad_at_pos)
    c0 = c0 + step_size * grad_at_pos
    if i % 50 == 0:
        steps_taken.append(c0)

print("final position is", c0)

# calculate the significance at a few steps along the way to visualize
sig_at_steps = jnp.asarray(
    [get_significance_smooth(step, signal, background, alpha=1) for step in steps_taken]
    #         [get_significance_smooth(step, signal, background, alpha=2) for step in steps_taken]
)

In [None]:
fig, ax = plt.subplots()
ax.plot(cuts_smooth, significances_smooth_alpha_1, c="C3", label=r"$\alpha=1$")
# ax.plot(cuts_smooth, significances_smooth_alpha_2, c="C3", label=r"$\alpha=2$")
ax.plot(steps_taken, sig_at_steps, "o", c="C4", label="Every 50th step")
ax.set_xlabel(r"Cut position $c$")
ax.set_ylabel("Significance")
ax.legend(loc="best")

fig.savefig("plots/automated_optimization.png")