# Hands-on: Introduction to NNs on a HEP Dataset Using JAX

### Many thanks to _Rafael Coelho Lopes De Sa, Fernando Torales Acosta, David Rousseau, Yann Coadou, Aishik Gosh_, and others for their help with this!

## Import Packages

In [2]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Because we like cool progress bars
from tqdm.notebook import tqdm

import jax

%matplotlib inline
import time
pd.set_option('display.max_columns', None) # to see all columns of df.head()

In [3]:
print(jax.__version__)

0.4.28


JAX can use GPUs automatically (with almost no change to the code below). However, for this tutorial we will just use the CPU!

# Load events

The data was created from ATLAS Open Data. Let's load it in! Jim's lectures will show us more of these details. We have a simple function that will load and clean up the data a bit and return it as a `pandas.DataFrame`.

* Feel free to inspect what it does in the `extra_functions.py` file!

In [4]:
from extra_functions import load_training_file
all_data = load_training_file()

And a quick look at a few things, so we can "see" what our data looks like. First - how many rows are there?

In [5]:
# get the length of the data
len(all_data)

567749

What is the list of columns that are in this dataset?

In [6]:
# List of columns
all_data.columns

Index(['index', 'eventNumber', 'label', 'met_et', 'met_phi', 'lep_n',
       'lep_pt_0', 'lep_pt_1', 'lep_eta_0', 'lep_eta_1', 'lep_phi_0',
       'lep_phi_1', 'lep_E_0', 'lep_E_1', 'lep_charge_0', 'lep_charge_1',
       'lep_type_0', 'lep_type_1', 'jet_n', 'jet_pt_0', 'jet_pt_1',
       'jet_eta_0', 'jet_eta_1', 'jet_phi_0', 'jet_phi_1', 'jet_E_0',
       'jet_E_1', 'mcWeight', 'runNumber', 'channelNumber'],
      dtype='object')

Note that we have leptons 0 and 1 and jets 0 and 1, but they are unrolled into separate flat columns (not, as you will learn, an `awkward` array).

The `label` column tells us if it is signal (`1`) or background (`0`). And we can make some plots. You'll learn more about plotting later in the week.

Let's print out the number of signal and background events - to make sure we have enough of each to test with.

In [8]:
# Length of signal and background
len(all_data[all_data['label']==1]), len(all_data[all_data['label']==0])

(390398, 177351)

One more very important prep step is to *shuffle* the data before we use it.

* If you train on sub-samples, this ensures there is a good mix.
* Files are often built by putting signal first and background second - meaning all the events come in order.

In [9]:
all_data = all_data.sample(frac=1).reset_index(drop=True)
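
Note that `sample(frac=1)` draws a different permutation every time the cell runs. If you would like the same shuffle on each run (handy when comparing two trainings), passing a seed should do it. A minimal sketch on a toy frame - the frame `toy` and the seed value `42` are placeholders for illustration, not part of the dataset:

```python
import pandas as pd

# Toy stand-in for all_data: signal rows first, background rows second.
toy = pd.DataFrame({
    "label": [1, 1, 1, 0, 0, 0],
    "met_et": [10.0, 20.0, 30.0, 40.0, 50.0, 60.0],
})

# frac=1 keeps every row; random_state fixes the permutation so reruns agree.
shuffled = toy.sample(frac=1, random_state=42).reset_index(drop=True)
shuffled_again = toy.sample(frac=1, random_state=42).reset_index(drop=True)

print(shuffled)
print("same order on rerun:", shuffled.equals(shuffled_again))
```

`reset_index(drop=True)` throws away the old row order so the index runs 0, 1, 2, ... again after shuffling.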

Now - let's look at the data and see, visually, how it looks. You'll learn later this week how to use `matplotlib` - the code below is very crude, but it looks at all the variables!

In [None]:
plt.figure()

ax=all_data[all_data.label==0].hist(weights=all_data.mcWeight[all_data.label==0],figsize=(15,12),color='b',alpha=0.5,density=True,bins=50,grid=False)
ax=ax.flatten()[:all_data.shape[1]] # to avoid error if holes in the grid of plots (like if 7 or 8 features)
all_data[all_data.label==1].hist(weights=all_data.mcWeight[all_data.label==1],figsize=(15,12),color='r',alpha=0.5,density=True,ax=ax,bins=50,grid=False)

plt.show()

What look like good variables? Do they make physics sense?

What variables make no physics sense to train on?

Also note that the phi distributions aren't that different between signal and background (as expected due to the symmetry of the beamline/detector). So, let's start with a safe set. You can come back later and modify this list if you want!

Let's start with a sub-set of columns to make this easy:  `["met_et","met_phi","lep_pt_0","lep_pt_1",'lep_phi_0', 'lep_phi_1']`

Create the variable `data` in the next cell which only has those columns from `all_data`.

In [12]:
# Just use the columns we want to train against.
data = all_data.loc[:, ["met_et","met_phi","lep_pt_0","lep_pt_1",'lep_phi_0', 'lep_phi_1']]

#### Feature engineering

Besides adding in variables like above, we can also create new variables using our physics knowledge. For example, we know that the opening angle in phi between the two leptons can be a good discriminator. The NN might be able to learn this - but since we know, we might as well help it out.

We'll leave this switched off behind the `use_delta_phi` flag for now so that you can come back and try it out later.

In [13]:
use_delta_phi = False
if use_delta_phi: 
    data["lep_deltaphi"]=np.abs(np.mod(data.lep_phi_1-data.lep_phi_0+3*np.pi,2*np.pi)-np.pi)
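
To see why the formula above uses `np.mod` rather than a plain difference, here is a small stand-alone sketch. The helper name `delta_phi` and the toy angles are made up for illustration; the formula itself is the one from the cell above.

```python
import numpy as np

def delta_phi(phi_a, phi_b):
    """Absolute azimuthal separation, folded into [0, pi]."""
    # Shift by 3*pi so the argument of mod is positive, wrap into [0, 2*pi),
    # then shift back by pi to land in [-pi, pi) before taking the magnitude.
    return np.abs(np.mod(phi_b - phi_a + 3 * np.pi, 2 * np.pi) - np.pi)

phi_0 = np.array([-3.0, 0.0, 1.0])
phi_1 = np.array([3.0, 0.5, -1.0])

naive = np.abs(phi_1 - phi_0)      # ignores the wrap-around at +-pi
wrapped = delta_phi(phi_0, phi_1)  # takes the short way around the circle

print(naive)    # first entry is 6.0, which is larger than pi
print(wrapped)  # first entry is 2*pi - 6, about 0.283
```

Two leptons at phi = -3.0 and phi = +3.0 are nearly on top of each other in the detector, and only the wrapped version reflects that.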

And a quick look at the variables...

In [14]:
data.head()

| | met_et | met_phi | lep_pt_0 | lep_pt_1 | lep_phi_0 | lep_phi_1 |
|---|---|---|---|---|---|---|
| 0 | 47.684 | -0.77197 | 37.166 | 28.309 | -2.7173 | -1.4484 |
| 1 | 92.215 | -0.20048 | 106.71 | 22.074 | -0.1741 | 2.1569 |
| 2 | 42.411 | 1.172 | 27.401 | 15.549 | -0.85048 | -1.4688 |
| 3 | 34.865 | 0.77953 | 40.732 | 7.4211 | -3.1212 | -0.68055 |
| 4 | 46.508 | -0.6176 | 73.652 | 35.769 | 0.29708 | 2.6752 |


# Training & Testing Samples

First we need to split the data into test and training samples. `scikit-learn` has some great utilities that make this a breeze.

First thing to consider - what fraction of the events do we want to use for training vs testing? A small fraction speeds up training (fewer samples to train on), while a larger fraction gives a more accurate training.

In [15]:
train_size = 0.1

Traditionally `X` is the data we train on, `y` is the _ground truth_ - what we are aiming for, and `weights` are per-event weights (usually from MC).

We also split things into `test` and `train` for testing and training. The network's performance on the training sample will be biased by the training, of course, so we keep the `test` sample independent.

In [16]:
X = data
y = all_data.label
weights = all_data.mcWeight

print(f"X shape: {X.shape}, Y shape: {y.shape}, weights shape: {weights.shape}")

X shape: (567749, 6), Y shape: (567749,), weights shape: (567749,)


Make sure that everything is the same length! If not, very bad things will happen below!

Ok - next we can use a very useful library routine, `train_test_split` to split everything up. Look up on the internet how to use this, and then code up the below to generate `X_train, X_test, y_train, y_test, weights_train, weights_test` using the fraction `train_size`. Finally, print out everyone's shape so that we can make sure we aren't making an obvious mistake.

In [19]:
from sklearn.model_selection import train_test_split

X_train, X_test, y_train, y_test, weights_train, weights_test = train_test_split(X, y, weights, train_size=train_size)
print(X_train.shape, X_test.shape)

(56774, 6) (510975, 6)


Often, when doing a real ML training, you'll want to split the test dataset in half - into a validation dataset and a test dataset:

- __Training Dataset:__ The sample of data used to fit the model.
- __Validation Dataset:__ The sample used to provide an unbiased evaluation of a model fit on the training dataset while tuning  hyperparameters.
- __Test Dataset:__ The sample of data used to provide an unbiased evaluation of a final model fit on the training dataset.
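
A three-way split like this can be sketched by calling `train_test_split` twice. The toy arrays (`X_toy`, `y_toy`, `w_toy`), the fixed `random_state`, and the 50/50 cut of the held-out part are assumptions for illustration; this notebook itself only uses the two-way split above.

```python
import numpy as np
from sklearn.model_selection import train_test_split

# Toy stand-ins for X, y and weights (100 events, 6 features).
rng = np.random.default_rng(0)
X_toy = rng.normal(size=(100, 6))
y_toy = rng.integers(0, 2, size=100)
w_toy = np.ones(100)

# First split: 10% for training, the remaining 90% held out.
X_tr, X_hold, y_tr, y_hold, w_tr, w_hold = train_test_split(
    X_toy, y_toy, w_toy, train_size=0.1, random_state=0
)

# Second split: cut the held-out part in half for validation and test.
X_val, X_te, y_val, y_te, w_val, w_te = train_test_split(
    X_hold, y_hold, w_hold, test_size=0.5, random_state=0
)

print(X_tr.shape, X_val.shape, X_te.shape)
```

The validation part is what you look at while tuning; the test part stays untouched until the very end.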

Why are we so worried about bias?

## Converting to the JAX data arrays

JAX has its own array type. This is because it wants to be able to work on both your CPU and your GPU - or even remotely. As a result it has the concept of an array that lives on a particular device, `jax.Array` (called `DeviceArray` in older JAX releases) - something `numpy` does not need.

In [20]:
import jax.numpy as jnp

X_train = jnp.array(X_train)
X_test = jnp.array(X_test)
y_train = jnp.array(y_train)
y_test = jnp.array(y_test)
weights_train = jnp.array(weights_train)
weights_test = jnp.array(weights_test)
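
A quick way to see what the conversion does, shown on a tiny hand-made array rather than the real samples (the values are just the first few `met_et` entries from above). One thing worth knowing: by default JAX stores floats as 32-bit, so `float64` input is down-cast unless you have enabled JAX's x64 mode.

```python
import numpy as np
import jax
import jax.numpy as jnp

x_np = np.array([47.684, 92.215, 42.411])  # NumPy defaults to float64
x_jax = jnp.array(x_np)                    # copied onto the default JAX device

print(type(x_np).__name__, x_np.dtype)
print(isinstance(x_jax, jax.Array), x_jax.dtype)  # float32 unless x64 mode is enabled
print(jax.devices())                              # the devices JAX can see

# Converting back for libraries that expect NumPy (e.g. scikit-learn, matplotlib).
x_back = np.asarray(x_jax)
```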

# Building a JAX Neural Network

First, let's build the JAX NN as we did in the earlier exercises. Think a little bit about what we want the output to look like - we want it to be 0 if it is background and 1 if it is signal. It should never go outside that range.

Use the JAX example from earlier today to figure out how to code this up. The final Haiku network should be called `net`.

In [21]:
import jax
import jax.numpy as jnp
import jax.nn
import haiku as hk
from optax import adam, apply_updates
def net_fn(x):
    mlp = hk.Sequential([
    hk.Linear(12), jax.nn.relu,
    hk.Linear(60), jax.nn.relu,
    hk.Linear(32), jax.nn.relu,
    hk.Linear(1), jax.nn.sigmoid
    ])
    return mlp(x)
net = hk.transform(net_fn)

Next, randomly initialize the parameters we'll be training on.

In [22]:
rng = jax.random.PRNGKey(42)
params = net.init(rng, X_train)

The optimizer we'll use - default to `adam`.

In [23]:
import optax
optimizer = optax.adam(0.0005)

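And the loss function: binary cross-entropy. Because `net` already ends in a sigmoid, we compute it directly from the predicted probabilities; `optax.sigmoid_binary_cross_entropy` expects raw logits and would apply the sigmoid a second time.

In [24]:
@jax.jit
def loss_fn(params, x, y):
    # `net` ends in a sigmoid, so these are probabilities in (0, 1), not logits.
    preds = net.apply(params, rng, x)
    # Flatten (N, 1) -> (N,) to match y; otherwise broadcasting against y
    # builds an (N, N) array, which gives the wrong loss and is much slower.
    preds = preds.reshape(-1)
    # Binary cross-entropy on probabilities; eps keeps log() finite.
    eps = 1e-7
    bce = -(y * jnp.log(preds + eps) + (1 - y) * jnp.log(1 - preds + eps))
    return jnp.mean(bce)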
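
The `reshape(-1)` in `loss_fn` is easy to overlook, so here is a tiny stand-alone illustration of what happens without it. The numbers are made up, and a plain subtraction stands in for the element-wise loss; the broadcasting rule is the same for any element-wise operation.

```python
import jax.numpy as jnp

preds_column = jnp.array([[0.9], [0.2], [0.7]])  # shape (3, 1), like the network output
labels = jnp.array([1.0, 0.0, 1.0])              # shape (3,), like y_train

# Without flattening, broadcasting pairs every prediction with every label.
wrong = preds_column - labels
print(wrong.shape)  # (3, 3)

# After flattening, each prediction is paired with its own label.
right = preds_column.reshape(-1) - labels
print(right.shape)  # (3,)
```

No error is raised in the first case, which is why it is worth printing shapes when a loss looks suspicious or slow.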


Finally, the update function that will generate an update of the parameters. Note that we should JIT the `train_step` function to speed it up. See the demo from this morning.

In [25]:
@jax.jit
def train_step(params, opt_state, x, y):
    grads = jax.grad(loss_fn)(params, x, y)
    updates, opt_state = optimizer.update(grads, opt_state)
    new_params = apply_updates(params, updates)
    return new_params, opt_state
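
If you also want the loss value out of the same pass (for logging, say), `jax.value_and_grad` returns the value and the gradient together. A minimal sketch on a one-parameter toy model; `toy_loss` and the data are invented for illustration and are not the network above:

```python
import jax
import jax.numpy as jnp

def toy_loss(w, x, y):
    # Mean squared error of a one-parameter linear model (illustration only).
    return jnp.mean((w * x - y) ** 2)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([2.0, 4.0, 6.0])
w = jnp.array(1.0)

# value_and_grad differentiates with respect to the first argument, like jax.grad.
loss, grad = jax.value_and_grad(toy_loss)(w, x, y)
print(loss, grad)
```

The gradient is negative here, telling us that increasing `w` (towards the true slope of 2) lowers the loss.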

In [None]:
opt_state = optimizer.init(params)

losses_training = []
losses_test = []

# How much of the training sample to calculate the loss on.
# Keep it large enough to be meaningful, but small enough to be fast.
n_loss = 10000

The training loop, finally! Note that we've designed this so that we can keep re-running it... so if the first pass isn't enough, we can just re-run the loop. Write a training loop that:

* Uses `tqdm` to run over 1000 epochs.
* Runs a `train_step` in each epoch.
* Calculates the loss on the test and training samples (on at most `n_loss` events, to keep it fast), and appends them to `losses_training` and `losses_test`.
* Can be run multiple times without erasing prior work.

In [None]:
...
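
If you get stuck, here is one possible shape for such a loop, written as a self-contained toy rather than as the solution for the cell above. Everything with a `toy_` prefix is made up for illustration: the data is synthetic, the model is a logistic regression instead of the Haiku `net`, the update is plain gradient descent instead of `adam`, and a bare `range` stands in for `tqdm`.

```python
import jax
import jax.numpy as jnp

# Synthetic two-class data: the label is 1 when the feature sum is positive.
key_train, key_test = jax.random.split(jax.random.PRNGKey(0))
toy_X_train = jax.random.normal(key_train, (200, 2))
toy_X_test = jax.random.normal(key_test, (200, 2))
toy_y_train = (toy_X_train.sum(axis=1) > 0).astype(jnp.float32)
toy_y_test = (toy_X_test.sum(axis=1) > 0).astype(jnp.float32)

TOY_LR = 0.1  # learning rate for the plain gradient-descent update

def toy_loss(params, x, y):
    # Logistic regression with binary cross-entropy on probabilities.
    p = jax.nn.sigmoid(x @ params["w"] + params["b"])
    eps = 1e-7  # keeps log() finite
    return -jnp.mean(y * jnp.log(p + eps) + (1 - y) * jnp.log(1 - p + eps))

@jax.jit
def toy_train_step(params, x, y):
    grads = jax.grad(toy_loss)(params, x, y)
    # Plain gradient descent: move each parameter against its gradient.
    return jax.tree_util.tree_map(lambda p, g: p - TOY_LR * g, params, grads)

# In a notebook these live in an earlier cell, so re-running the loop keeps them.
toy_params = {"w": jnp.zeros(2), "b": jnp.array(0.0)}
toy_losses_training, toy_losses_test = [], []

for epoch in range(100):  # wrap range(...) in tqdm(...) for a progress bar
    toy_params = toy_train_step(toy_params, toy_X_train, toy_y_train)
    # Append rather than overwrite, so re-running the loop extends the history.
    toy_losses_training.append(float(toy_loss(toy_params, toy_X_train, toy_y_train)))
    toy_losses_test.append(float(toy_loss(toy_params, toy_X_test, toy_y_test)))

print(toy_losses_training[0], toy_losses_training[-1])
```

The important design point is that the parameters and the two loss lists are created outside the loop cell: re-running the loop then continues from where the last run stopped.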

And let's plot the training and test losses to see how well we did!

In [None]:
plt.plot(losses_training,label="training loss")
plt.plot(losses_test,label="test loss")

plt.legend(fontsize=15)
plt.show()

## Evaluating the Training

Evaluate the model by comparing its predictions on `X_test` with the true labels `y_test` ($X_{test} \rightarrow y_{test}$).

In [None]:
y_pred_test = net.apply(params, rng, X_test)
y_pred_train = net.apply(params, rng, X_train)

### ROC curves and Area Under the Curve (AUC)

In [None]:
from sklearn.metrics import roc_auc_score # AUC is computed from the raw scores; no 0.5 threshold is applied
from sklearn.utils import class_weight # for class_weight="balanced" (not used below)

In [None]:
from sklearn.metrics import roc_curve
fpr,tpr,_ = roc_curve(y_true=y_test, y_score=y_pred_test,sample_weight=weights_test)
plt.plot(fpr, tpr, color='blue',lw=2)
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title('Receiver Operating Characteristic')

In [None]:
auc_test = roc_auc_score(y_true=y_test, y_score=y_pred_test,sample_weight=weights_test)
auc_train = roc_auc_score(y_true=y_train, y_score=y_pred_train,sample_weight=weights_train)
print("auc test:",auc_test)
print("auc train:", auc_train)
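
One way to build intuition for the AUC: it is the (weighted) fraction of signal/background pairs in which the signal event gets the higher score. The helper `pairwise_auc` below is a hypothetical illustration of that definition, not what `roc_auc_score` does internally; for non-negative weights I would expect the two to agree, but treat that as an assumption to check. It builds a full signal-by-background matrix, so only use it on small arrays.

```python
import numpy as np

def pairwise_auc(y_true, y_score, sample_weight=None):
    """Weighted fraction of (signal, background) pairs ranked correctly."""
    y_true = np.asarray(y_true)
    y_score = np.asarray(y_score, dtype=float)
    if sample_weight is None:
        w = np.ones_like(y_score)
    else:
        w = np.asarray(sample_weight, dtype=float)

    sig, bkg = y_true == 1, y_true == 0
    # Compare every signal score with every background score (ties count half).
    diff = y_score[sig][:, None] - y_score[bkg][None, :]
    wins = (diff > 0) + 0.5 * (diff == 0)
    pair_w = w[sig][:, None] * w[bkg][None, :]
    return float((wins * pair_w).sum() / pair_w.sum())

y_true = [0, 0, 1, 1]
y_score = [0.1, 0.4, 0.35, 0.8]
print(pairwise_auc(y_true, y_score))  # 0.75: three of the four pairs are ordered correctly
```

An AUC of 0.5 therefore means the score is no better than a coin flip at ranking signal above background, and 1.0 means perfect separation.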

### Plotting NN Score for Signal and Background

In [None]:
from extra_functions import compare_train_test
compare_train_test(y_pred_train.reshape(-1), y_train, y_pred_test.reshape(-1), y_test, 
                   xlabel="NN Score", title="NN", 
                   weights_train=weights_train, weights_test=weights_test)