In [None]:
from google.colab import drive
drive.mount('/content/drive')

In [None]:
!pip install uproot

In [None]:
!pip install --target=$nb_path -q torch-scatter -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
!pip install --target=$nb_path -q torch-sparse -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html
!pip install --target=$nb_path -q torch-geometric

In [None]:
!pip install --target=$nb_path -q torch-cluster -f https://pytorch-geometric.com/whl/torch-1.7.0+cu101.html

In [None]:
import torch
import torch.nn as nn
import torch.nn.functional as F

import torch_geometric.nn as pyg_nn
import torch_geometric.utils as pyg_utils

import time
from datetime import datetime

import networkx as nx
import numpy as np
import torch
import torch.optim as optim

from torch_geometric.datasets import TUDataset
from torch_geometric.datasets import Planetoid
from torch_geometric.data import DataLoader

import torch_geometric.transforms as T

from sklearn.manifold import TSNE
import matplotlib.pyplot as plt
import pickle
import torch_cluster
from torch_geometric.nn import GravNetConv

In [None]:
# Load the pre-built training and testing graph lists from Drive.
# NOTE: pickle.load can execute arbitrary code — only load files you produced yourself.
def _load_graphs(path):
    """Unpickle and return the object stored at `path`."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


list_of_graphs_training = _load_graphs('/content/drive/MyDrive/IRIS-HEP_DFEI/list_of_graphs_from_filtered_data_k5_training.pickle')
list_of_graphs_testing = _load_graphs('/content/drive/MyDrive/IRIS-HEP_DFEI/list_of_graphs_from_filtered_data_k5_testing.pickle')

In [None]:
class MyModel(nn.Module):
    """Dense encoder -> one GCN layer -> dense head, producing one sigmoid score per node.

    `forward(data)` reads `data.x` (13 features per row, as fixed by `lin_00`) and
    `data.edge_index`, and returns a tensor with a single column of values in (0, 1).

    `conv2`, `gravnet1` and `gravnet2` are constructed but not used in `forward`.
    They still appear in `parameters()` and `state_dict()`, so removing them would
    change the keys of checkpoints saved with `torch.save(model.state_dict(), ...)`.
    """

    def __init__(self):
        super().__init__()

        # Per-node MLP encoder 13 -> 64 -> 128 -> 13; each stage is Linear -> ReLU -> LayerNorm.
        # (Construction order is kept as-is: it determines the random init sequence.)
        self.lin_00 = nn.Linear(in_features=13, out_features=64)
        self.ln_00 = nn.LayerNorm(64)
        self.lin_01 = nn.Linear(in_features=64, out_features=128)
        self.ln_01 = nn.LayerNorm(128)
        self.lin_02 = nn.Linear(in_features=128, out_features=13)
        self.ln_02 = nn.LayerNorm(13)

        # Graph layers; only conv1 is used in forward().
        self.conv1 = pyg_nn.GCNConv(in_channels=13, out_channels=13)
        self.conv2 = pyg_nn.GCNConv(in_channels=13, out_channels=13)

        self.gravnet1 = GravNetConv(in_channels=13, out_channels=13, space_dimensions=8, propagate_dimensions=7, k=5)
        self.gravnet2 = GravNetConv(in_channels=13, out_channels=13, space_dimensions=8, propagate_dimensions=7, k=5)

        # Dense head 13 -> 256 -> 256 -> 1.
        self.lin1 = nn.Linear(in_features=13, out_features=256)
        self.ln1 = nn.LayerNorm(256)

        self.lin2 = nn.Linear(in_features=256, out_features=256)
        self.ln2 = nn.LayerNorm(256)

        self.lin3 = nn.Linear(in_features=256, out_features=1)

        self.relu = F.relu
        self.sigmoid = torch.sigmoid

    def forward(self, data):
        encoder_stages = (
            (self.lin_00, self.ln_00),
            (self.lin_01, self.ln_01),
            (self.lin_02, self.ln_02),
        )
        hidden = data.x
        for linear, norm in encoder_stages:
            hidden = norm(self.relu(linear(hidden)))

        # One round of message passing over the graph's edges.
        hidden = self.relu(self.conv1(x=hidden, edge_index=data.edge_index))

        for linear, norm in ((self.lin1, self.ln1), (self.lin2, self.ln2)):
            hidden = norm(self.relu(linear(hidden)))

        return self.sigmoid(self.lin3(hidden))

In [None]:
# Training hyper-parameters and model / loss / optimizer setup.
learning_rate = 0.01
num_epochs = 150
seed = 42

# Seed torch's RNGs (CPU and CUDA) before building the model so the initial
# weights are reproducible across Restart & Run All. This does not make
# CUDA kernels themselves deterministic.
torch.manual_seed(seed)

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
model = MyModel().to(device)
# BCELoss takes probabilities; MyModel.forward already ends in a sigmoid.
criterion = torch.nn.BCELoss()
optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)

In [None]:
# Train for `num_epochs` passes over the training graphs (one optimizer step per graph),
# recording the per-epoch mean loss and node-level training accuracy.
accs = []
losses = []

for epoch in range(num_epochs):
    n_wrong = 0       # nodes whose rounded prediction differs from the label
    n_nodes = 0
    total_loss = 0.0
    model.train()

    for graph in list_of_graphs_training:
        graph = graph.to(device)
        optimizer.zero_grad()
        preds = model(graph)
        labels = graph.y
        loss = criterion(preds, labels)
        loss.backward()
        optimizer.step()

        total_loss += loss.item()
        # Detach and convert to a Python number right away: accumulating the raw
        # tensor would keep every graph's autograd history alive for the whole epoch.
        n_wrong += (torch.round(preds.detach()) - labels).abs().sum().item()
        n_nodes += graph.num_nodes

    if n_nodes == 0:
        raise ValueError("list_of_graphs_training is empty or contains no nodes")

    accuracy = 1 - n_wrong / n_nodes
    total_loss /= len(list_of_graphs_training)
    accs.append(accuracy)
    losses.append(total_loss)
    print(f'Epoch: {epoch} | Loss: {total_loss:.8f} | Train Accuracy: {accuracy:.8f}')

In [None]:
# Label used in plot titles (model_name) and in saved file names (model_name_cat).
model_name = model_name_cat = "GCNConv"

In [None]:
# Score every test graph with the trained model. Inference only: eval() puts the
# layers in inference mode, and no_grad() skips building autograd graphs.
model.eval()
preds = []
labels = []

with torch.no_grad():
    for graph in list_of_graphs_testing:
        graph = graph.to(device)
        preds.append(model(graph))
        labels.append(graph.y)

In [None]:
# Concatenate the per-graph tensors and move them to CPU NumPy arrays for sklearn.
preds, labels = [torch.cat(chunks).cpu().detach().numpy() for chunks in (preds, labels)]

In [None]:
# Training curves: per-epoch mean loss and node-level training accuracy.
fig, ax = plt.subplots()
ax.plot(losses, label="training loss")
ax.plot(accs, label="training accuracy")
ax.set_title(f"{model_name}")
ax.set_xlabel("epoch")
ax.legend()
plt.show()

In [None]:
from sklearn.metrics import roc_curve
from sklearn.metrics import auc

# ROC curve for the per-node test scores. The model emits a single-column output,
# so both arrays are flattened to the 1-D form sklearn documents.
fpr, tpr, roc_thresholds = roc_curve(labels.ravel(), preds.ravel())
roc_auc = auc(fpr, tpr)

fig, ax = plt.subplots()
ax.plot([0, 1], [0, 1], 'k--', label='chance')
ax.plot(fpr, tpr, label='(AUC = {:.3f})'.format(roc_auc))
ax.set_xlabel('False positive rate')
ax.set_ylabel('True positive rate')
ax.set_title(f'ROC curve ({model_name})')
ax.legend(loc='best')
plt.show()

In [None]:
from sklearn.metrics import precision_recall_curve, f1_score, auc

# Precision-recall curve for the per-node test scores; the "No Skill" line is the
# fraction of positive labels.
y_true = labels.ravel()
y_score = preds.ravel()
precision, recall, thresholds = precision_recall_curve(y_true, y_score)
# Separate name: assigning to `auc` would shadow sklearn's auc() for later cells.
pr_auc = auc(recall, precision)

no_skill = len(y_true[y_true == 1]) / len(y_true)

fig, ax = plt.subplots()
ax.plot([0, 1], [no_skill, no_skill], linestyle='--', label='No Skill')
ax.plot(recall, precision, marker='.', label='{} (AUC = {:.3f})'.format(model_name, pr_auc))
ax.set_xlabel('Recall')
ax.set_ylabel('Precision')
ax.set_title(f"{model_name} precision-recall curve")
ax.legend()
plt.show()

In [None]:
# Save the trained weights (state_dict only) to Drive. The epoch count in the file
# name comes from `num_epochs` so it cannot drift from the actual training run.
torch.save(
    model.state_dict(),
    f'/content/drive/MyDrive/IRIS-HEP_DFEI/{model_name_cat}_{num_epochs}epochs_first_save.pickle',
)

In [None]:
# Save the per-epoch training curves next to the weights; the file names use
# `num_epochs` instead of a hard-coded epoch count.
for suffix, values in (('accs', accs), ('losses', losses)):
    path = f'/content/drive/MyDrive/IRIS-HEP_DFEI/{model_name_cat}_{num_epochs}epochs_first_save_{suffix}.pickle'
    with open(path, 'wb') as handle:
        pickle.dump(values, handle, protocol=pickle.HIGHEST_PROTOCOL)