# Using JAX

Here, we will use JAX both to apply the simple method (a hard cut implemented with the Heaviside function) and to explore taking the derivative of $S/\sqrt{B}$ with respect to the cut.

In [1]:
import jax
import jax.numpy as np
import numpy
from samples import data_sig, data_back

In [2]:
# Convert the data to jax arrays
data_sig_j = np.asarray(data_sig)
data_back_j = np.asarray(data_back)



In [3]:
def wts_by_cut(data, cut:float):
    'Calculate weights for a jax array by a cut using the heaviside function, simulating ">"'

    return np.heaviside(np.add(data, -cut), 0)

def sig_sqrt_b(cut):
    'Calculate S/sqrt(B) for the signal and background JAX arrays with the cut at cut.'

    # Weight the data and then do the sum
    wts_sig = wts_by_cut(data_sig_j, cut)
    wts_back = wts_by_cut(data_back_j, cut)

    S = np.sum(wts_sig)
    B = np.sum(wts_back)

    return S/np.sqrt(B)
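
To make the weighting concrete, here is a minimal sketch on a four-element toy array (the `toy_*` names and values are illustrative and not part of the samples used in this notebook). It imports `jax.numpy` as `jnp` so that it does not clash with the `np` alias used in the cells above. Note that an entry exactly equal to the cut gets weight 0 because the second argument of `heaviside` is 0, which is what makes the cut behave like a strict ">".

```python
import jax.numpy as jnp

# Hypothetical toy values, chosen so that one entry sits exactly on the cut
toy_data = jnp.array([1.0, 2.0, 3.0, 4.0])
toy_cut = 2.0

# Weight is 1 where data > cut and 0 otherwise (including data == cut)
toy_wts = jnp.heaviside(toy_data - toy_cut, 0.0)
print(toy_wts)  # [0. 0. 1. 1.]

# Summing the weights counts the entries that pass the cut
toy_pass = float(jnp.sum(toy_wts))
print(toy_pass)  # 2.0
```

Summing these 0/1 weights is how `sig_sqrt_b` counts the signal and background entries above the cut.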

In [4]:
cut_values = numpy.linspace(-10.0, 10.0, 100)
s_sqrt_b = numpy.array([sig_sqrt_b(c) for c in cut_values])

In [5]:
max_index = np.argmax(s_sqrt_b)
print(rf"Max value of $S/\sqrt{{B}}$ occurs at {cut_values[max_index]:0.4} and is {s_sqrt_b[max_index]:0.4}.")

Max value of $S/\sqrt{B}$ occurs at 3.737 and is 22.44.


## Gradient

Let's calculate the gradient of $S/\sqrt{B}$ with respect to the cut over the same range to see what it looks like.
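
As a rough sketch of what this calculation can look like, the block below applies `jax.grad` to a Heaviside-based $S/\sqrt{B}$ built on synthetic Gaussian samples. The `toy_*` names, the sample sizes, the Gaussian shapes, and the narrower cut range are all assumptions made so that the example is self-contained. They stand in for `data_sig` and `data_back`, which come from the `samples` module. The cut range is kept where the toy background count is non-zero so that the ratio stays finite.

```python
import jax
import jax.numpy as jnp

# Synthetic stand-ins for the signal and background samples (assumption)
key_sig, key_back = jax.random.split(jax.random.PRNGKey(0))
toy_sig = 2.0 + jax.random.normal(key_sig, (1000,))
toy_back = jax.random.normal(key_back, (5000,))

def toy_sig_sqrt_b(cut):
    "S/sqrt(B) for the toy samples with a hard Heaviside cut."
    S = jnp.sum(jnp.heaviside(toy_sig - cut, 0.0))
    B = jnp.sum(jnp.heaviside(toy_back - cut, 0.0))
    return S / jnp.sqrt(B)

# jax.grad returns a function that evaluates d(S/sqrt(B))/d(cut)
grad_fn = jax.grad(toy_sig_sqrt_b)

toy_cuts = jnp.linspace(-2.0, 2.0, 9)
toy_grads = jnp.array([grad_fn(c) for c in toy_cuts])
print(toy_grads)
```

Because the Heaviside step is piecewise constant in the cut, the gradient of this hard-cut version should come back as zero at every cut value where it is defined.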