In [None]:
import awkward0 as awkward
import matplotlib
import matplotlib.pyplot as plt
# `np` is the alias used by every cell below; the original `nps` alias is kept
# so nothing that may rely on it breaks.
import numpy as np
import numpy as nps
import pandas as pd
import uproot3_methods as uproot_methods

import torch
import torch.nn as nn

In [None]:
# Read rows [0, 10000) of the "table" dataset stored in test.h5.
df = pd.read_hdf(
    "test.h5",
    key="table",
    start=0,
    stop=10000,
)

In [None]:
#based on https://github.com/hqucms/ParticleNet/blob/master/tf-keras/convert_dataset.ipynb
def _col_list(prefix, max_particles=200):
    return ['%s_%d'%(prefix,i) for i in range(max_particles)]

def get_constituents(df):
    """Build jet-relative constituent features from a flat per-event table.

    Reads the ``PX_i`` / ``PY_i`` / ``PZ_i`` / ``E_i`` columns named by
    ``_col_list`` and the ``is_signal_new`` column of ``df``. Slots whose
    energy is not strictly positive are dropped (presumably zero padding --
    TODO confirm against the dataset description).

    Returns a 4-tuple ``(pt, eta, phi, label)``:

    * ``pt``    -- jagged array, constituent pt divided by the pt of the
                   summed four-vector of the event,
    * ``eta``   -- jagged array, eta of the summed vector minus constituent eta,
    * ``phi``   -- jagged array, ``delta_phi`` of the summed vector w.r.t. each
                   constituent,
    * ``label`` -- the ``is_signal_new`` values, one per event.
    """
    # Dense (event, slot) arrays, read in the order PX, PY, PZ, E.
    dense = {name: df[_col_list(name)].values for name in ('PX', 'PY', 'PZ', 'E')}

    # Keep only slots with positive energy and count them per event.
    keep = dense['E'] > 0
    counts = np.sum(keep, axis=1)

    def to_jagged(name):
        # Flattened kept values, re-split into one variable-length row per event.
        return awkward.JaggedArray.fromcounts(counts, dense[name][keep])

    constituents = uproot_methods.TLorentzVectorArray.from_cartesian(
        to_jagged('PX'), to_jagged('PY'), to_jagged('PZ'), to_jagged('E')
    )
    # Per-event sum of all constituent four-vectors.
    jet = constituents.sum()

    rel_pt = constituents.pt / jet.pt
    rel_eta = jet.eta - constituents.eta
    rel_phi = jet.delta_phi(constituents)

    return rel_pt, rel_eta, rel_phi, df['is_signal_new'].values

# Unpack the three per-constituent feature arrays and the per-event labels
# computed from the loaded rows.
pt, eta, phi, label = get_constituents(df)

In [None]:
class Net(torch.nn.Module):
    """Single self-attention block followed by a sum-pooled binary classifier.

    Pipeline: linear embedding -> LayerNorm -> multi-head self-attention with
    a residual connection -> LayerNorm -> sum over slots -> two-layer head
    with a sigmoid output.

    Parameters
    ----------
    num_node_features : int
        Number of input features per slot.
    embed_dim : int
        Width of the embedding and of the attention layer. It must be
        divisible by ``num_heads`` (a requirement of ``nn.MultiheadAttention``).
    num_heads : int
        Number of attention heads (default 8, the previously hard-coded value).
    """

    def __init__(self, num_node_features=3, embed_dim=16, num_heads=8):
        super(Net, self).__init__()

        self.embed = torch.nn.Linear(num_node_features, embed_dim)

        self.norm1 = torch.nn.LayerNorm(embed_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim,
            num_heads,
            dropout=0.0,
            add_bias_kv=False,
            batch_first=True
        )
        self.norm2 = torch.nn.LayerNorm(embed_dim)
        self.out1 = torch.nn.Linear(embed_dim, 128)
        self.out2 = torch.nn.Linear(128, 1)

    def forward(self, x):
        """Run the network on a padded batch.

        Parameters
        ----------
        x : torch.Tensor
            Padded input features, indexed as ``x[:, :, k]`` (batch first).
            A slot is treated as padding when its first two features are both
            exactly zero.

        Returns
        -------
        x : torch.Tensor
            Per-slot representation after the attention block, with padded
            slots set to zero.
        attention_matrix : torch.Tensor
            Attention weights as returned by ``nn.MultiheadAttention``.
        out : torch.Tensor
            Sigmoid output of the classification head, one value per event
            (last dimension of size 1).
        """
        # The padding mask must come from the tensor that was passed in. The
        # previous code read a notebook-level global here, so the model only
        # worked on that one tensor. It has to be computed before `x` is
        # overwritten by the embedding.
        x_mask = (x[:, :, 0] == 0) & (x[:, :, 1] == 0)

        x = self.embed(x)
        x = self.norm1(x)

        # Padded slots are excluded as attention keys.
        x_attn, attention_matrix = self.attn(x, x, x, key_padding_mask=x_mask)

        # 1.0 for real slots, 0.0 for padding, broadcastable over the features.
        x_mask_f = (~x_mask).to(dtype=x.dtype).unsqueeze(dim=-1)
        x = x + x_attn * x_mask_f
        # Zero the padded slots so they do not contribute to the pooled sum.
        x = self.norm2(x) * x_mask_f

        x_sum = torch.sum(x, dim=-2)

        x_sum = torch.selu(self.out1(x_sum))
        out = torch.sigmoid(self.out2(x_sum))

        return x, attention_matrix, out

In [None]:
# Freshly initialised network with the default constructor arguments.
n = Net()

In [None]:
from torch.nn.utils.rnn import pad_sequence

In [None]:
# Number of events used below, clipped to what was actually loaded so a short
# file does not raise an IndexError.
n_events = min(2000, len(label))

# One (n_constituents, 3) tensor per event with columns (pt, eta, phi),
# zero-padded to a common length. The dtype is set explicitly to match y_vals,
# since float64 input would not match float32 nn.Linear weights.
x_feats = pad_sequence(
    [
        torch.tensor(np.stack([pt[i], eta[i], phi[i]], axis=-1), dtype=torch.float32)
        for i in range(n_events)
    ],
    batch_first=True,
)
# Per-event labels as a 1-D float tensor.
y_vals = torch.tensor(label[:n_events], dtype=torch.float32)

In [None]:
# Display the dimensions of the padded feature tensor.
x_feats.shape

In [None]:
# Sum of the label tensor; if the labels are 0/1 this is the number of
# label-1 events in the selection -- TODO confirm the label encoding.
y_vals.sum()

In [None]:
# Forward pass through the freshly constructed network: per-slot
# representation, attention weights and per-event output.
x, attention_matrix, out = n(x_feats)

In [None]:
# Padded input features of the first event (columns: pt, eta, phi).
# Fixed the undefined name `x_feat`; the tensor is called `x_feats`.
plt.imshow(x_feats.cpu().numpy()[0], cmap="Blues")

In [None]:
# Per-slot representation of the first event as returned by the network,
# drawn on a fixed symmetric colour scale.
plt.imshow(
    x[0].detach().cpu().numpy(),
    cmap="bwr",
    norm=matplotlib.colors.Normalize(vmin=-2, vmax=2),
)
plt.colorbar()

In [None]:
# Attention weights of the event at index 2, from the untrained network.
plt.imshow(
    attention_matrix[2].detach().cpu().numpy(),
    cmap="bwr",
    norm=matplotlib.colors.Normalize(vmin=-0.1, vmax=0.1),
)
plt.colorbar()

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

model = Net(embed_dim=256).to(device)

optimizer = torch.optim.Adam(model.parameters(), lr=1e-4)

# Move the data to the model's device once, outside the loop.
# `x_feats` is deliberately rebound (as the original loop did) so that later
# cells calling `model(x_feats)` pass a tensor on the right device.
x_feats = x_feats.to(device)
# The targets must be on the same device as the predictions, otherwise the
# loss raises a device-mismatch error when CUDA is available.
y_train = y_vals.to(device)

model.train()
losses_train = []

# Each "epoch" is a single full-batch gradient step over all selected events.
for epoch in range(20):
    optimizer.zero_grad()
    out = model(x_feats)[2]
    # out has a trailing dimension of size 1; drop it to match the 1-D targets.
    loss = torch.nn.functional.binary_cross_entropy(out[:, 0], y_train)

    loss.backward()
    optimizer.step()

    loss_train_epoch = loss.item()
    losses_train.append(loss_train_epoch)
    print(epoch, loss_train_epoch)

In [None]:
# Forward pass through the trained model: per-slot representation, attention
# weights and per-event output.
x_attn, attn_matrix, out = model(x_feats)

In [None]:
# Attention weights of the event at index 8, from the trained model.
plt.imshow(
    attn_matrix[8].detach().cpu().numpy(),
    cmap="bwr",
    norm=matplotlib.colors.Normalize(vmin=-0.1, vmax=0.1),
)
plt.colorbar()

In [None]:
# Distribution of the network output, split by label value.
# Both tensors are copied to host memory first so the cell also works when
# the model ran on a GPU (`.numpy()` is not available on CUDA tensors).
b = np.linspace(0, 1, 21)
scores = out[:, 0].detach().cpu().numpy()
labels = y_vals.detach().cpu().numpy()
plt.hist(scores[labels == 1], bins=b, histtype="step", lw=2, label="label = 1")
plt.hist(scores[labels == 0], bins=b, histtype="step", lw=2, label="label = 0")
plt.yscale("log")
plt.xlabel("network output")
plt.legend();