## Colab setup

In [None]:
# !pip install pandas numpy awkward0 uproot3_methods matplotlib
# !pip3 install torch torchvision torchaudio

# import os
# import torch
# os.environ['TORCH'] = torch.__version__
# print(torch.__version__)

# !pip install -q torch-geometric -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-scatter -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-sparse -f https://data.pyg.org/whl/torch-${TORCH}.html
# !pip install -q torch-cluster -f https://data.pyg.org/whl/torch-${TORCH}.html

In [None]:
import pandas as pd
import numpy as nps
import awkward0 as awkward
import uproot3_methods as uproot_methods
import matplotlib
import matplotlib.pyplot as plt
import tqdm

import torch
import torch.nn as nn
import numpy as np

## Data

In [None]:
# Fetch the input file into ./test.h5; -nc keeps an already-downloaded copy instead of fetching again.
!wget -nc -O test.h5 https://zenodo.org/record/2603256/files/test.h5?download=1

In [None]:
# Read only rows [0, 10000) of the "table" key from the downloaded file.
df = pd.read_hdf("test.h5", key="table", start=0, stop=10000)

In [None]:
# Show only the first rows; the frame has hundreds of columns, so a bounded preview is enough.
df.head()

In [None]:
#based on https://github.com/hqucms/ParticleNet/blob/master/tf-keras/convert_dataset.ipynb
def _col_list(prefix, max_particles=200):
    return ['%s_%d'%(prefix,i) for i in range(max_particles)]

def get_constituents(df):
    """Build jet-relative constituent features from a flat, fixed-width table.

    The frame is expected to hold the columns produced by ``_col_list`` for the
    prefixes ``PX``, ``PY``, ``PZ`` and ``E`` plus an ``is_signal_new`` column.
    Slots whose energy is not strictly positive are dropped, so every row
    becomes a variable-length (jagged) list of constituents.

    Returns:
        (pt, eta, phi, label) where
        pt    -- constituent pt divided by the pt of the summed four-vector,
        eta   -- eta of the summed four-vector minus the constituent eta,
        phi   -- ``jet_p4.delta_phi(p4)``,
        label -- the values of the ``is_signal_new`` column.
    """
    components = ('PX', 'PY', 'PZ', 'E')
    dense = {name: df[_col_list(name)].values for name in components}

    # A slot counts as occupied only when its energy is strictly positive.
    occupied = dense['E'] > 0
    counts = np.sum(occupied, axis=1)

    # Re-pack the surviving entries of each component into one jagged array per row.
    px, py, pz, energy = (
        awkward.JaggedArray.fromcounts(counts, dense[name][occupied])
        for name in components
    )

    p4 = uproot_methods.TLorentzVectorArray.from_cartesian(px, py, pz, energy)
    jet_p4 = p4.sum()

    rel_pt = p4.pt / jet_p4.pt
    rel_eta = jet_p4.eta - p4.eta
    rel_phi = jet_p4.delta_phi(p4)

    return rel_pt, rel_eta, rel_phi, df['is_signal_new'].values

# Jagged per-row constituent features and the per-row label for the loaded rows.
pt, eta, phi, label = get_constituents(df)

## Self-attention-based encoder

In [None]:
class Net(torch.nn.Module):
    """Single-layer self-attention encoder with a binary classification head.

    Every input element is embedded by a two-layer MLP, layer-normalised,
    passed through one multi-head self-attention layer (8 heads) with a
    residual connection, normalised again, summed over the element axis and
    mapped to a sigmoid score.

    Args:
        num_node_features: number of features per input element.
        embed_dim: width of the embedding and attention layer; it must be
            divisible by the 8 attention heads.
    """

    def __init__(self, num_node_features=3, embed_dim=128):
        super(Net, self).__init__()

        # Per-element embedding MLP.
        self.embed1 = torch.nn.Linear(num_node_features, 128)
        self.embed2 = torch.nn.Linear(128, embed_dim)

        self.norm1 = torch.nn.LayerNorm(embed_dim)
        self.attn1 = nn.MultiheadAttention(
            embed_dim,
            8,
            dropout=0.0,
            add_bias_kv=False,
            batch_first=True
        )
        self.norm2 = torch.nn.LayerNorm(embed_dim)

        # Classification head applied to the pooled representation.
        self.out1 = torch.nn.Linear(embed_dim, 128)
        self.out2 = torch.nn.Linear(128, 1)

    def forward(self, x):
        """Encode a batch of (possibly zero-padded) element sequences.

        Args:
            x: tensor indexed as (batch, element, feature). Elements whose
                first two features are both exactly zero are treated as
                padding.

        Returns:
            A tuple ``(x, attention_matrix, out)``:
            x -- encoded elements, with padded positions set to zero,
            attention_matrix -- attention weights returned by the attention layer,
            out -- sigmoid score per sample, one column wide.
        """
        # The padding mask must come from the raw input: after the embedding
        # (which has a bias) and LayerNorm, padded rows are no longer zero, so
        # a mask computed on the embedded tensor would never fire.
        # NOTE: a sample consisting only of padding would mask every key.
        x_mask = (x[:, :, 0] == 0) & (x[:, :, 1] == 0)

        x = torch.selu(self.embed1(x))
        x = self.embed2(x)
        x = self.norm1(x)

        # Padded elements are excluded as attention keys.
        x_attn, attention_matrix = self.attn1(x, x, x, key_padding_mask=x_mask)

        # 1.0 for real elements, 0.0 for padding; broadcast over the feature axis.
        x_mask_f = (~x_mask).to(dtype=torch.float32).unsqueeze(-1)
        x = x + x_attn * x_mask_f
        x = self.norm2(x) * x_mask_f

        # Sum-pool over elements; padded rows are zero and do not contribute.
        x_sum = torch.sum(x, dim=-2)

        x_sum = torch.selu(self.out1(x_sum))
        out = torch.sigmoid(self.out2(x_sum))

        return x, attention_matrix, out

In [None]:
# Untrained model with default sizes, used below to inspect a forward pass on one sample.
n = Net()

In [None]:
from torch.nn.utils.rnn import pad_sequence

In [None]:
class Dataset(torch.utils.data.Dataset):
    """Map-style dataset pairing per-sample features with their targets.

    Args:
        x_feats: indexable collection of per-sample feature objects.
        y_vals: indexable collection of targets, aligned with ``x_feats``.
    """

    def __init__(self, x_feats, y_vals):
        """Store the features and targets."""
        self.x_feats = x_feats
        self.y_vals = y_vals

    def __len__(self):
        """Return the total number of samples."""
        return len(self.x_feats)

    def __getitem__(self, index):
        """Return the ``(features, target)`` pair stored at ``index``."""
        # Read from the instance, not from notebook globals, so the dataset
        # always serves the data it was constructed with.
        return self.x_feats[index], self.y_vals[index]

In [None]:
# One (n_elements, 3) float32 tensor per row, with columns (pt, eta, phi).
# The dtype is pinned so the inputs always match the float32 model parameters.
x_feats = [
    torch.tensor(np.stack([pt[i], eta[i], phi[i]], axis=-1), dtype=torch.float32)
    for i in range(len(df))
]
# Convert all labels in one call instead of stacking one scalar tensor per row.
y_vals = torch.tensor(np.asarray(label), dtype=torch.float32)

ds = Dataset(x_feats, y_vals)

In [None]:
# Features and label of the first sample, used for the inspection cells below.
x, y = ds[0]

In [None]:
# Forward pass of the untrained model on one sample; unsqueeze(0) adds the batch axis.
x_out, attention_matrix, out = n(x.unsqueeze(0))

In [None]:
# Raw input features of the first sample (rows: elements, columns: pt, eta, phi).
plt.imshow(x.numpy())
plt.colorbar()

In [None]:
# Encoded elements returned by the untrained model for the first sample.
plt.imshow(x_out[0].detach().cpu().numpy())
plt.colorbar()

In [None]:
# Attention weights of the untrained model for the first sample, one tick per element.
plt.imshow(attention_matrix.detach().cpu().numpy()[0])
plt.colorbar()
plt.xticks(range(len(x_out[0])));
plt.yticks(range(len(x_out[0])));

In [None]:
def collate_fn(inputs):
    """Collate ``(features, target)`` samples into one padded batch.

    Feature tensors of different lengths are zero-padded to the longest one
    in the batch (batch axis first); targets are stacked into one tensor.
    """
    features = []
    targets = []
    for sample in inputs:
        features.append(sample[0])
        targets.append(sample[1])

    padded = pad_sequence(features, batch_first=True)
    stacked = torch.stack(targets)
    return padded, stacked

In [None]:
# Batches of 128 samples, zero-padded per batch by collate_fn; no shuffle argument is passed.
training_generator = torch.utils.data.DataLoader(ds, batch_size=128, collate_fn=collate_fn)

In [None]:
# Use the GPU when one is available, otherwise fall back to the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Net(embed_dim=256).to(device)

# NOTE(review): lr=1e-6 is a very small step size for Adam — confirm it is intentional.
optimizer = torch.optim.Adam(model.parameters(), lr=1e-6)

model.train()
losses_train = []  # mean training loss of each epoch

for epoch in range(5):
    
    loss_train_epoch = []  # per-batch losses of the current epoch
    
    for X, y in tqdm.tqdm(training_generator):
        X = X.to(device)
        y = y.to(device)

        optimizer.zero_grad()
        out = model(X)[2]  # third return value of Net.forward: the sigmoid score
        # out has one column; take it so the prediction lines up with y.
        loss = torch.nn.functional.binary_cross_entropy(out[:, 0], y)

        loss.backward()
        loss_train_epoch.append(loss.item())
        optimizer.step()

    # The per-batch list is replaced by its mean before being recorded.
    loss_train_epoch = np.mean(loss_train_epoch)
    losses_train.append(loss_train_epoch)
    print(epoch, loss_train_epoch)
    
# Switch to evaluation mode for the inference cells that follow.
model.eval()

In [None]:
# The trained model lives on `device`, so the input has to be moved there as well;
# otherwise this call fails with a device mismatch whenever a GPU is in use.
x_attn, attn_matrix, out = model(ds[0][0].unsqueeze(0).to(device))

In [None]:
# Attention weights of the trained model for the first sample.
plt.imshow(attn_matrix.detach().cpu().numpy()[0])
plt.colorbar()

In [None]:
# Collect predictions and labels in a single pass so they are guaranteed to be aligned.
# Inputs are moved to the model's device, gradients are not tracked, and the
# predictions are brought back to the CPU for the plotting cells.
y_pred_batches, y_true_batches = [], []
with torch.no_grad():
    for X_batch, y_batch in training_generator:
        y_pred_batches.append(model(X_batch.to(device))[2][:, 0].cpu())
        y_true_batches.append(y_batch)
y_pred = torch.concat(y_pred_batches)
y_true = torch.concat(y_true_batches)

In [None]:
# Sanity check: predictions and labels should have the same shape.
y_pred.shape, y_true.shape

In [None]:
# Model-output distributions, split by the true label.
b = np.linspace(0,1,21)
# Bring both tensors to the CPU first: .numpy() is not available on CUDA tensors.
y_pred_cpu = y_pred.detach().cpu()
y_true_cpu = y_true.detach().cpu()
plt.hist(y_pred_cpu[y_true_cpu==1].numpy(), bins=b, histtype="step", lw=2, label="is_signal_new == 1")
plt.hist(y_pred_cpu[y_true_cpu==0].numpy(), bins=b, histtype="step", lw=2, label="is_signal_new == 0")
plt.xlabel("model output")
plt.ylabel("entries")
plt.legend();
#plt.yscale("log")

## Exercises

### 1. Add another stacked attention layer and check the performance of the model.

### 2. Query the second attention layer with a learnable vector, instead of the encoded elements. Check the performance of the model.  

### 3. Change the model to output a per-particle classification score (e.g. for pileup (PU) rejection in a jet).
