In [2]:
import pandas as pd
import networkx as nx
import torch
import torch.nn.functional as F
import torch.optim as optim
from torch_geometric.data import Data
from torch_geometric.utils import from_networkx
from torch_geometric.nn import GCNConv

# Load data
signal_df = pd.read_csv('Dijet_bb_pt10_15_dw.csv')
background_df = pd.read_csv('Dijet_qq_pt10_15_dw.csv')


# Data is loaded, now we want a function to create the graph objects from the dataframes.  This is the trickiest part, because here we need to define all of our nodes, edges, and features.

In [7]:
# Create graph from DataFrame
def create_graph_from_df(df):
    G = nx.Graph()
    for i, row in df.iterrows():
        G.add_edge(row['Jet0_PT'], row['Jet0_Eta']) # We need to define what oure edges are
    for node in G.nodes(): 
        node_features = df[df['Jet0_FD_OWNPV'] == node].iloc[:, 2:].values
        if len(node_features) > 0:
            G.nodes[node]['x'] = torch.tensor(node_features[0], dtype=torch.float)
        else:
            G.nodes[node]['x'] = torch.tensor([0.0] * (df.shape[1] - 2), dtype=torch.float)
    data = from_networkx(G)
    return data

signal_graph = create_graph_from_df(signal_df)
background_graph = create_graph_from_df(background_df)

In [8]:
# Define GNN model
class GNN(torch.nn.Module):
    def __init__(self):
        super(GNN, self).__init__()
        self.conv1 = GCNConv(in_channels=signal_graph.num_node_features, out_channels=16)
        self.conv2 = GCNConv(in_channels=16, out_channels=2)
    
    def forward(self, data):
        x, edge_index = data.x, data.edge_index
        x = self.conv1(x, edge_index)
        x = F.relu(x)
        x = self.conv2(x, edge_index)
        return F.log_softmax(x, dim=1)

model = GNN()

# After we have the model, we need to combine the two datasets` graph to use in the training.

In [9]:
# Combine signal and background graphs
def create_combined_dataset(signal_graph, background_graph):
    signal_graph.y = torch.ones(signal_graph.num_nodes, dtype=torch.long)
    background_graph.y = torch.zeros(background_graph.num_nodes, dtype=torch.long)
    signal_graph.train_mask = torch.ones(signal_graph.num_nodes, dtype=torch.bool)
    background_graph.train_mask = torch.ones(background_graph.num_nodes, dtype=torch.bool)
    signal_graph.test_mask = torch.ones(signal_graph.num_nodes, dtype=torch.bool)
    background_graph.test_mask = torch.ones(background_graph.num_nodes, dtype=torch.bool)
    combined_graph = Data(
        x=torch.cat([signal_graph.x, background_graph.x], dim=0),
        edge_index=torch.cat([signal_graph.edge_index, background_graph.edge_index + signal_graph.num_nodes], dim=1),
        y=torch.cat([signal_graph.y, background_graph.y], dim=0),
        train_mask=torch.cat([signal_graph.train_mask, background_graph.train_mask], dim=0),
        test_mask=torch.cat([signal_graph.test_mask, background_graph.test_mask], dim=0),
    )
    return combined_graph

combined_graph = create_combined_dataset(signal_graph, background_graph)

# Then we can do the training and testing, evaluating the ability at the end.

In [10]:
# Training parameters
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.01)

# Training loop
def train(model, data, optimizer, criterion):
    model.train()
    optimizer.zero_grad()
    out = model(data)
    loss = criterion(out[data.train_mask], data.y[data.train_mask])
    loss.backward()
    optimizer.step()
    return loss.item()

for epoch in range(200):
    loss = train(model, combined_graph, optimizer, criterion)
    print(f'Epoch {epoch}, Loss: {loss}')

# Evaluation
def test(model, data):
    model.eval()
    _, pred = model(data).max(dim=1)
    correct = pred[data.test_mask].eq(data.y[data.test_mask]).sum().item()
    acc = correct / data.test_mask.sum().item()
    return acc

accuracy = test(model, combined_graph)
print(f'Accuracy: {accuracy}')

Epoch 0, Loss: 0.6931473612785339
Epoch 1, Loss: 0.6931905150413513
Epoch 2, Loss: 0.6931498646736145
Epoch 3, Loss: 0.6931565403938293
Epoch 4, Loss: 0.6931718587875366
Epoch 5, Loss: 0.6931644678115845
Epoch 6, Loss: 0.6931509375572205
Epoch 7, Loss: 0.6931474208831787
Epoch 8, Loss: 0.6931541562080383
Epoch 9, Loss: 0.6931596994400024
Epoch 10, Loss: 0.6931570768356323
Epoch 11, Loss: 0.6931506991386414
Epoch 12, Loss: 0.6931470036506653
Epoch 13, Loss: 0.6931485533714294
Epoch 14, Loss: 0.6931524276733398
Epoch 15, Loss: 0.693153977394104
Epoch 16, Loss: 0.6931514143943787
Epoch 17, Loss: 0.6931478381156921
Epoch 18, Loss: 0.6931470632553101
Epoch 19, Loss: 0.6931487321853638
Epoch 20, Loss: 0.693150520324707
Epoch 21, Loss: 0.6931505799293518
Epoch 22, Loss: 0.6931488513946533
Epoch 23, Loss: 0.6931472420692444
Epoch 24, Loss: 0.6931472420692444
Epoch 25, Loss: 0.6931483149528503
Epoch 26, Loss: 0.6931489706039429
Epoch 27, Loss: 0.6931485533714294
Epoch 28, Loss: 0.69314742088317