In [1]:
# !pip install uproot

In [2]:
# !pip install --target=$nb_path -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
# !pip install --target=$nb_path -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
# !pip install --target=$nb_path -q torch-geometric

In [3]:
# !pip install --target=$nb_path -q torch-cluster -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html

In [4]:
# !pip install torch-cluster

In [5]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

import time
from datetime import datetime
import tqdm

import networkx as nx
import numpy as np
import torch
import torch.optim as optim

from torch_geometric.datasets import TUDataset
from torch_geometric.datasets import Planetoid
from torch_geometric.data import DataLoader

import torch_geometric.transforms as T

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import pickle
import torch_cluster
from torch_geometric.nn import GravNetConv

In [6]:
# getting the data

with open('/home/jonas/Documents/physics/lhcb/dfei/irishep/data/IRIS-HEP_DFEI/list_of_graphs_from_filtered_data_k10.pickle', 'rb') as f:
    list_of_graphs = pickle.load(f)

In [7]:
frac_test = 0.3
index_test_start = int(len(list_of_graphs) * (1 - frac_test))
list_of_graphs_training = list_of_graphs[:index_test_start]
list_of_graphs_test = list_of_graphs[index_test_start:]

In [8]:
k = 10
class MyModel(nn.Module):
    def __init__(self):
        super(MyModel, self).__init__()

        self.lin_00 = nn.Linear(in_features=13, out_features=64)
        self.ln_00 = nn.LayerNorm(64)
        self.lin_01 = nn.Linear(in_features=64, out_features=128)
        self.ln_01 = nn.LayerNorm(128)
        self.lin_02 = nn.Linear(in_features=128, out_features=13)
        self.ln_02 = nn.LayerNorm(13)

        self.conv1 = pyg_nn.GCNConv(in_channels=13, out_channels=13)
        self.conv2 = pyg_nn.GCNConv(in_channels=13, out_channels=13)

        self.gravnet1 = GravNetConv(in_channels=13, out_channels=13, space_dimensions=8, propagate_dimensions=7, k=k)
        self.gravnet2 = GravNetConv(in_channels=13, out_channels=13, space_dimensions=8, propagate_dimensions=7, k=k)

        self.lin1 = nn.Linear(in_features=13, out_features=256)
        self.ln1 = nn.LayerNorm(256)

        self.lin2 = nn.Linear(in_features=256, out_features=256)
        self.ln2 = nn.LayerNorm(256)

        self.lin3 = nn.Linear(in_features=256, out_features=1)

        self.relu = F.relu
        self.sigmoid = torch.sigmoid
        
    def forward(self, data):

        x = self.lin_00(data.x)
        x = self.relu(x)
        x = self.ln_00(x)
        x = self.lin_01(x)
        x = self.relu(x)
        x = self.ln_01(x)
        x = self.lin_02(x)
        x = self.relu(x)
        x = self.ln_02(x)
        
        x = self.gravnet1(data.x)
#         x = self.conv1(x=x, edge_index=data.edge_index)
        x = self.relu(x)
        
        x = self.gravnet2(x)
        # x = self.conv2(x=x, edge_index=data.edge_index)
        x = self.relu(x)

        x = self.lin1(x)
        x = self.relu(x)
        x = self.ln1(x)

        x = self.lin2(x)
        x = self.relu(x)
        x = self.ln2(x)

        x = self.lin3(x)

        out = self.sigmoid(x)

        return out

In [9]:
torch.cuda.is_available()

True

In [10]:
learning_rate = 0.01
num_epochs = 300

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MyModel()
model = model.to(device)
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
accs = []
losses = []

for epoch in range(num_epochs):
    correct = 0
    correct_test = 0
    total = 0
    total_test = 0
    total_loss = 0
    total_loss_test = 0
    

    for graph in tqdm.tqdm(list_of_graphs_training):
        model.train()

        optimizer.zero_grad()
        graph.to(device)
        preds = model(graph)
        labels = graph.y
        loss = criterion(preds, labels)
        loss.backward()

        optimizer.step()

        total_loss += loss.item()
        correct += sum(abs((torch.round(preds) - labels)))
        total += graph.num_nodes
        
    for graph in list_of_graphs_test:
        model.eval()
        graph.to(device)
        preds = model(graph)
        labels = graph.y
        loss = criterion(preds, labels)

        total_loss_test += loss.item()
        correct_test += sum(abs((torch.round(preds) - labels)))
        total_test += graph.num_nodes
        
    
    accuracy = 1 - correct.item() / total
    total_loss /= len(list_of_graphs_training)
    accuracy_test = 1 - correct_test.item() / total_test
    total_loss_test /= len(list_of_graphs_test)
    accs.append(accuracy)
    losses.append(total_loss)
    print('Epoch: %d | Loss: %.8f | Train Accuracy: %.8f | Loss Test: %.8f | Test Accuracy: %.8f' \
          %(epoch, total_loss, accuracy, total_loss_test, accuracy_test))

100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 100.60it/s]


Epoch: 0 | Loss: 0.24197411 | Train Accuracy: 0.94560717 | Loss Test: 0.22222280 | Test Accuracy: 0.94678227


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 101.29it/s]


Epoch: 1 | Loss: 0.19257607 | Train Accuracy: 0.94752239 | Loss Test: 0.17055855 | Test Accuracy: 0.95069104


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 101.05it/s]


Epoch: 2 | Loss: 0.17411900 | Train Accuracy: 0.95309311 | Loss Test: 0.17614316 | Test Accuracy: 0.95430267


100%|███████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:43<00:00, 98.33it/s]


Epoch: 3 | Loss: 0.17405637 | Train Accuracy: 0.95328265 | Loss Test: 0.15730427 | Test Accuracy: 0.95530522


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 101.06it/s]


Epoch: 4 | Loss: 0.16388870 | Train Accuracy: 0.95571652 | Loss Test: 0.15445665 | Test Accuracy: 0.95720163


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:40<00:00, 105.39it/s]


Epoch: 5 | Loss: 0.16296661 | Train Accuracy: 0.95655905 | Loss Test: 0.15657344 | Test Accuracy: 0.95574973


100%|███████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:43<00:00, 98.26it/s]


Epoch: 6 | Loss: 0.16251264 | Train Accuracy: 0.95727348 | Loss Test: 0.15120667 | Test Accuracy: 0.95882505


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:41<00:00, 102.83it/s]


Epoch: 7 | Loss: 0.16029168 | Train Accuracy: 0.95757446 | Loss Test: 0.16011102 | Test Accuracy: 0.95882022


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:41<00:00, 102.69it/s]


Epoch: 8 | Loss: 0.16533335 | Train Accuracy: 0.95568527 | Loss Test: 0.17222234 | Test Accuracy: 0.95396445


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 100.35it/s]


Epoch: 9 | Loss: 0.15962611 | Train Accuracy: 0.95748073 | Loss Test: 0.15190577 | Test Accuracy: 0.95843369


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 100.48it/s]


Epoch: 10 | Loss: 0.16267959 | Train Accuracy: 0.95714747 | Loss Test: 0.15969424 | Test Accuracy: 0.95581254


100%|███████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:43<00:00, 99.30it/s]


Epoch: 11 | Loss: 0.15976403 | Train Accuracy: 0.95739325 | Loss Test: 0.16019471 | Test Accuracy: 0.95450318


100%|███████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:44<00:00, 97.74it/s]


Epoch: 12 | Loss: 0.16461731 | Train Accuracy: 0.95616017 | Loss Test: 0.15567129 | Test Accuracy: 0.95801576


100%|███████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:43<00:00, 98.98it/s]


Epoch: 13 | Loss: 0.15989537 | Train Accuracy: 0.95751927 | Loss Test: 0.15713674 | Test Accuracy: 0.95748428


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 100.64it/s]


Epoch: 14 | Loss: 0.15806235 | Train Accuracy: 0.95811185 | Loss Test: 0.15160110 | Test Accuracy: 0.95914635


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 102.00it/s]


Epoch: 15 | Loss: 0.15580276 | Train Accuracy: 0.95862529 | Loss Test: 0.15041033 | Test Accuracy: 0.95892168


100%|███████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:44<00:00, 97.20it/s]


Epoch: 16 | Loss: 0.15874560 | Train Accuracy: 0.95766090 | Loss Test: 0.15269232 | Test Accuracy: 0.95784423


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:43<00:00, 100.11it/s]


Epoch: 17 | Loss: 0.15932912 | Train Accuracy: 0.95763174 | Loss Test: 0.15499209 | Test Accuracy: 0.95918742


100%|███████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:43<00:00, 99.88it/s]


Epoch: 18 | Loss: 0.15687547 | Train Accuracy: 0.95874609 | Loss Test: 0.15227566 | Test Accuracy: 0.95966817


100%|███████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:43<00:00, 98.72it/s]


Epoch: 19 | Loss: 0.15577970 | Train Accuracy: 0.95883670 | Loss Test: 0.15166904 | Test Accuracy: 0.95785148


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 101.88it/s]


Epoch: 20 | Loss: 0.15711375 | Train Accuracy: 0.95824099 | Loss Test: 0.16514154 | Test Accuracy: 0.95131673


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 101.57it/s]


Epoch: 21 | Loss: 0.15750728 | Train Accuracy: 0.95857321 | Loss Test: 0.15191419 | Test Accuracy: 0.95990250


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 102.46it/s]


Epoch: 22 | Loss: 0.15831379 | Train Accuracy: 0.95814934 | Loss Test: 0.14899272 | Test Accuracy: 0.95994357


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:41<00:00, 104.64it/s]


Epoch: 23 | Loss: 0.15707111 | Train Accuracy: 0.95830556 | Loss Test: 0.15542797 | Test Accuracy: 0.95699870


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:41<00:00, 103.76it/s]


Epoch: 24 | Loss: 0.15609657 | Train Accuracy: 0.95891064 | Loss Test: 0.15396459 | Test Accuracy: 0.95874533


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:41<00:00, 102.90it/s]


Epoch: 25 | Loss: 0.15749241 | Train Accuracy: 0.95812018 | Loss Test: 0.15051514 | Test Accuracy: 0.96031318


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:41<00:00, 104.25it/s]


Epoch: 26 | Loss: 0.15682927 | Train Accuracy: 0.95870756 | Loss Test: 0.17828270 | Test Accuracy: 0.95402002


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:41<00:00, 102.73it/s]


Epoch: 27 | Loss: 0.15500879 | Train Accuracy: 0.95909602 | Loss Test: 0.14974452 | Test Accuracy: 0.96008852


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 101.17it/s]


Epoch: 28 | Loss: 0.15560179 | Train Accuracy: 0.95886482 | Loss Test: 0.15988205 | Test Accuracy: 0.95670881


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 101.35it/s]


Epoch: 29 | Loss: 0.15734711 | Train Accuracy: 0.95803478 | Loss Test: 0.15568715 | Test Accuracy: 0.95807615


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 101.52it/s]


Epoch: 30 | Loss: 0.16339674 | Train Accuracy: 0.95599979 | Loss Test: 0.15300454 | Test Accuracy: 0.95769929


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 101.17it/s]


Epoch: 31 | Loss: 0.15586045 | Train Accuracy: 0.95863987 | Loss Test: 0.14941346 | Test Accuracy: 0.96003054


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 101.94it/s]


Epoch: 32 | Loss: 0.15436986 | Train Accuracy: 0.95925745 | Loss Test: 0.15043973 | Test Accuracy: 0.96004745


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:42<00:00, 100.58it/s]


Epoch: 33 | Loss: 0.15479997 | Train Accuracy: 0.95924495 | Loss Test: 0.15311111 | Test Accuracy: 0.95883471


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:41<00:00, 104.20it/s]


Epoch: 34 | Loss: 0.15511626 | Train Accuracy: 0.95917309 | Loss Test: 0.15049521 | Test Accuracy: 0.96000879


100%|██████████████████████████████████████████████████████████████████████████████| 4312/4312 [00:41<00:00, 104.10it/s]


Epoch: 35 | Loss: 0.15465656 | Train Accuracy: 0.95935951 | Loss Test: 0.14994428 | Test Accuracy: 0.96029869


 36%|████████████████████████████                                                   | 1535/4312 [00:15<00:28, 98.29it/s]

In [None]:
# model_name = "GCNConv"
# model_name_cat = "GCNConv"
model_name = f"GravNet_k{K}"
model_name_cat = f"GravNet_k{K}"

In [None]:
preds = []
labels = []

for graph in list_of_graphs_test:
  graph.to(device)
  pred = model(graph)
  label = graph.y
  preds.append(pred)
  labels.append(label)

In [None]:
preds = torch.cat(preds).to("cpu").detach().numpy()
labels = torch.cat(labels).to("cpu").detach().numpy()

In [None]:
plt.title(f"{model_name}")
plt.plot(losses, label="training loss")
plt.plot(accs, label="training accuracy")
plt.legend()

In [None]:
from sklearn.metrics import roc_curve
from sklearn.metrics import auc

fpr_keras, tpr_keras, thresholds_keras = roc_curve(labels, preds)
auc_keras = auc(fpr_keras, tpr_keras)

plt.figure(1)
plt.plot([0, 1], [0, 1], 'k--')
plt.plot(fpr_keras, tpr_keras, label='(AUC = {:.3f})'.format(auc_keras))
plt.xlabel('False positive rate')
plt.ylabel('True positive rate')
plt.title(f'ROC curve ({model_name})')
plt.legend(loc='best')
plt.show()

In [None]:
from sklearn.metrics import precision_recall_curve, f1_score, auc

precision, recall, thresholds = precision_recall_curve(labels, preds)
auc = auc(recall, precision)

no_skill = len(labels[labels==1]) / len(labels)
plt.plot([0, 1], [no_skill, no_skill], linestyle='--', label='No Skill')
plt.plot(recall, precision, marker='.', label=f'{model_name}')
plt.xlabel('Recall')
plt.ylabel('Precision')
plt.text(0.8, 0.8, '(AUC = {:.3f})'.format(auc))
plt.title(f"{model_name} precision-recall curve")
plt.legend()
plt.show()

In [None]:
torch.save(model.state_dict(), f'/home/jonas/Documents/physics/lhcb/dfei/irishep/data/IRIS-HEP_DFEI/{model_name_cat}_150epochs_first_save.pickle')

In [None]:
with open(f'/home/jonas/Documents/physics/lhcb/dfei/irishep/data/IRIS-HEP_DFEI/{model_name_cat}_150epochs_first_save_accs.pickle', 'wb') as handle:
    pickle.dump(accs, handle, protocol=pickle.HIGHEST_PROTOCOL)

with open(f'/home/jonas/Documents/physics/lhcb/dfei/irishep/data/IRIS-HEP_DFEI/{model_name_cat}_150epochs_first_save_losses.pickle', 'wb') as handle:
    pickle.dump(losses, handle, protocol=pickle.HIGHEST_PROTOCOL)