# Fully Connected Nodes

In [119]:
import matplotlib as plt
import pandas as pd
import networkx as nx

import torch
import torch.nn.functional as F
import torch.optim as optim
from torch import Tensor
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
from torch.utils.data import Dataset, DataLoader
from torch_geometric.nn import SAGEConv

from itertools import product

from sklearn import metrics
from sklearn.metrics import accuracy_score, precision_score, recall_score, f1_score
from sklearn.metrics import classification_report

In [35]:
# Read the two dijet samples from CSV: the bb file is bound to `signal_df`
# and the qq file to `background_df`.
signal_df, background_df = (
    pd.read_csv(csv_name)
    for csv_name in ('Dijet_bb_pt10_15_dw.csv', 'Dijet_qq_pt10_15_dw.csv')
)

In [36]:
# Keep only the columns whose name contains "Jet0" from each sample and stack
# them (signal rows first, background rows second) into the training frame.
sig_jet0 = signal_df.filter(like="Jet0", axis=1)
back_jet0 = background_df.filter(like="Jet0", axis=1)
train_df = pd.concat([sig_jet0, back_jet0], axis=0)

# Same selection for the "Jet1" columns, stacked into the test frame.
sig_jet1 = signal_df.filter(like="Jet1", axis=1)
back_jet1 = background_df.filter(like="Jet1", axis=1)
test_df = pd.concat([sig_jet1, back_jet1], axis=0)

In [125]:
# Create Graph with Fully Connected Nodes

def fully_connected_graph(df, key_column='Jet0_FD_OWNPV'):
    """Build a fully connected PyTorch Geometric graph over the columns of ``df``.

    Every column name of ``df`` becomes one node, and every unordered pair of
    columns is joined by one undirected edge.

    Node features: for each node, the rows where ``df[key_column]`` equals the
    node (i.e. the column *name*) are looked up. The first matching row
    contributes its values from the third column onward; if nothing matches,
    the node receives a zero vector of the same width (``df.shape[1] - 2``).

    NOTE(review): the lookup compares column names against the *values* stored
    in ``key_column``. If that column is numeric this can never match, so every
    node ends up with the all-zero feature vector -- confirm what the node
    features are meant to be.
    NOTE(review): every node is labelled class 1 and both masks select every
    node, so a model that always predicts class 1 scores 100% -- confirm the
    intended labels and train/test split.

    Args:
        df: DataFrame whose columns define the nodes.
        key_column: Column whose values are compared against node names when
            looking up features. If ``df`` has no such column (for example a
            frame holding only Jet1 columns), a warning is emitted and all
            nodes get zero features instead of raising ``KeyError``.

    Returns:
        The ``Data`` object produced by ``from_networkx`` with ``y``
        (all ones, long), ``train_mask`` and ``test_mask`` (all True) attached.
    """
    nodes = list(df.columns)

    G = nx.Graph()
    G.add_nodes_from(nodes)
    # Connect every unordered pair of columns exactly once.
    G.add_edges_from(
        (node1, node2)
        for i, node1 in enumerate(nodes)
        for node2 in nodes[i + 1:]
    )

    # Feature vectors span every column after the first two; clamp at zero so
    # frames with fewer than two columns yield empty vectors, not an error.
    feature_width = max(df.shape[1] - 2, 0)

    if key_column in df.columns:
        key_values = df[key_column]
    else:
        import warnings
        warnings.warn(
            f"column {key_column!r} not found; all node features default to zeros",
            stacklevel=2,
        )
        key_values = None

    for node in G.nodes():
        if key_values is not None:
            node_features = df[key_values == node].iloc[:, 2:].values
        else:
            node_features = ()
        if len(node_features) > 0:
            # Use the first matching row only.
            G.nodes[node]['x'] = torch.tensor(node_features[0], dtype=torch.float)
        else:
            G.nodes[node]['x'] = torch.zeros(feature_width, dtype=torch.float)

    data = from_networkx(G)
    data.y = torch.ones(data.num_nodes, dtype=torch.long)
    data.train_mask = torch.ones(data.num_nodes, dtype=torch.bool)
    data.test_mask = torch.ones(data.num_nodes, dtype=torch.bool)
    return data

# Build the graph from the Jet 0 frame; later cells size the model from it and train on it.
graph = fully_connected_graph(train_df)

In [102]:
# Show the graph's summary repr (attribute shapes) as the cell output.
graph

Data(x=[170, 168], edge_index=[2, 28730], y=[170], train_mask=[170], test_mask=[170])

In [115]:
# Define GNN model
class GNN(torch.nn.Module):
    """Two-layer GraphSAGE node classifier.

    Args:
        in_channels: Width of the input node features. When left as ``None``
            it is read from the notebook-level ``graph`` (``graph.x.shape[1]``),
            which keeps ``GNN()`` working exactly as before while allowing the
            model to be built for other graphs explicitly.
        hidden_channels: Output width of the first SAGEConv layer.
        out_channels: Output width of the second SAGEConv layer (one entry
            per class).
    """

    def __init__(self, in_channels=None, hidden_channels=16, out_channels=2):
        super().__init__()
        if in_channels is None:
            # Backward-compatible default: size the input layer from the global graph.
            in_channels = graph.x.shape[1]
        self.conv1 = SAGEConv(in_channels=in_channels, out_channels=hidden_channels)
        self.conv2 = SAGEConv(in_channels=hidden_channels, out_channels=out_channels)

    def forward(self, data):
        """Return per-node log-probabilities (log-softmax over dim 1).

        ``data`` must expose ``x`` (node features) and ``edge_index``.
        """
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        # Output is already log-normalised; pair it with a loss that expects log-probabilities.
        return F.log_softmax(x, dim=1)

# Instantiate the network; with no arguments its input width comes from the global `graph`.
model = GNN()

In [117]:
# Training parameters
# The model's forward pass already returns log-probabilities (F.log_softmax),
# so the matching criterion is NLLLoss. CrossEntropyLoss expects raw logits and
# would apply log-softmax a second time; that second pass is a no-op on
# already-normalised input, so the loss value is unchanged by this switch.
criterion = torch.nn.NLLLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop
def train(model, data, optimizer, criterion):
    """Perform a single optimisation step on ``data`` and return the loss.

    The model is switched to training mode, gradients are cleared, and the
    loss is computed only on the nodes selected by ``data.train_mask`` before
    back-propagating and stepping the optimizer.

    Returns:
        The loss of this step as a Python float.
    """
    model.train()
    optimizer.zero_grad()

    selected = data.train_mask
    predictions = model(data)
    step_loss = criterion(predictions[selected], data.y[selected])

    step_loss.backward()
    optimizer.step()
    return step_loss.item()

# Run 10 training steps, each on the whole `graph`, printing the loss after every step.
for epoch in range(10):
    loss = train(model, graph, optimizer, criterion)
    print(f'Epoch {epoch}, Loss: {loss}')

# Evaluation
def test(model, data):
    model.eval()
    _, pred = model(data).max(dim=1)
    correct = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()
    acc = correct / data.test_mask.sum().item()
    return acc

# Score the trained model on `graph` -- the same object it was trained on.
# NOTE(review): fully_connected_graph sets every label to 1 and makes test_mask
# select every node, so this is training-set accuracy on a single-class target;
# a perfect score here does not demonstrate generalisation.
accuracy = test(model, graph)
print(f'Accuracy: {accuracy}')

Epoch 0, Loss: 0.008448219858109951
Epoch 1, Loss: 0.0068959081545472145
Epoch 2, Loss: 0.005616004578769207
Epoch 3, Loss: 0.0045694452710449696
Epoch 4, Loss: 0.0037197950296103954
Epoch 5, Loss: 0.003033443819731474
Epoch 6, Loss: 0.002481241011992097
Epoch 7, Loss: 0.0020378308836370707
Epoch 8, Loss: 0.001681939116679132
Epoch 9, Loss: 0.0013962768716737628
Accuracy: 1.0


# ROC Curve