# Defining the interaction layer

In [1]:
import torch
from torch_geometric.nn import MessagePassing

In [2]:
class Interaction(MessagePassing):
    """Message-passing layer that embeds both edges and nodes.

    The layer is configured with ``aggr='add'`` and
    ``flow="source_to_target"``. Each edge is embedded by an MLP fed with
    the two endpoint feature vectors (plus the edge features, when given);
    each node is then embedded by a second MLP fed with its own features
    concatenated with the aggregated edge embeddings.

    Args:
        in_edge: number of input features per edge (0 when there are none).
        in_node: number of input features per node.
        out_edge: size of the edge embedding.
        out_node: size of the node embedding.
    """

    def __init__(self, in_edge, in_node, out_edge, out_node):
        super().__init__(aggr='add', flow="source_to_target")
        # Edge MLP input width: both endpoints' features plus the edge features.
        self.in_edge = in_edge + 2 * in_node
        # Node MLP input width: node features plus the aggregated edge embedding.
        self.in_node = out_edge + in_node
        # Built edge-first, then node, so parameter creation order is stable.
        self.mlp_edge = self._two_layer_mlp(self.in_edge, out_edge)
        self.mlp_node = self._two_layer_mlp(self.in_node, out_node)

    @staticmethod
    def _two_layer_mlp(n_in, n_out):
        """Return a Linear -> ReLU -> Linear stack mapping n_in to n_out."""
        return torch.nn.Sequential(
            torch.nn.Linear(n_in, n_out, bias=True),
            torch.nn.ReLU(),
            torch.nn.Linear(n_out, n_out, bias=True),
        )

    def forward(self, x, edge_index, edge_attrs):
        """Run one round of message passing.

        Returns whatever ``propagate`` hands back from ``update``, i.e. the
        ``(edge_embed, node_embed)`` pair built there.
        """
        return self.propagate(
            edge_index=edge_index,
            x=x,
            edge_attrs=edge_attrs,
        )

    def message(self, x_i, x_j, edge_index, edge_attrs):
        """Embed every edge from its endpoint features and optional edge features."""
        if edge_attrs is None:
            pieces = (x_i, x_j)
        else:
            pieces = (x_i, x_j, edge_attrs)
        # Kept on the instance so that update() can return it next to the
        # node embedding.
        self.edge_embed = self.mlp_edge(torch.cat(pieces, dim=1))
        return self.edge_embed

    def update(self, aggr_out, x):
        """Embed every node and return it together with the edge embedding."""
        node_input = torch.cat((x, aggr_out), dim=1)
        return (self.edge_embed, self.mlp_node(node_input))

# Defining the network

In [3]:
class Net(torch.nn.Module):
    """Edge classifier: four stacked Interaction layers plus a sigmoid head.

    Args:
        dim_node: number of input features per node.
        dim_edge: number of input features per edge (0 for none).
        dim_embed_edge: width of the edge embedding used by every layer.
        dim_embed_node: width of the node embedding used by every layer.
    """

    def __init__(self, dim_node=4, dim_edge=0, dim_embed_edge=4, dim_embed_node=16):
        super().__init__()
        embed_dims = (dim_embed_edge, dim_embed_node)
        # The first layer consumes the raw feature sizes; the remaining three
        # map embeddings to embeddings of the same width.
        self.conv1 = Interaction(dim_edge, dim_node, *embed_dims)
        self.conv2 = Interaction(*embed_dims, *embed_dims)
        self.conv3 = Interaction(*embed_dims, *embed_dims)
        self.conv4 = Interaction(*embed_dims, *embed_dims)
        # One sigmoid output per edge, computed from the final edge embedding.
        self.classify = torch.nn.Sequential(
            torch.nn.Linear(dim_embed_edge, 1),
            torch.nn.Sigmoid(),
        )

    def forward(self, data):
        """Return one sigmoid score per edge of ``data``.

        ``data`` must expose ``x``, ``edge_attr`` and ``edge_index``.
        """
        nodes = data.x
        edges = data.edge_attr
        connectivity = data.edge_index

        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            edges, nodes = layer(nodes, connectivity, edges)

        return self.classify(edges)

# Create the event dataset

In [6]:
import numpy as np
from numba import jit
import h5py
from torch_geometric.data import Data, Dataset, DataLoader

In [7]:
class _EventCount(int):
    """Integer event count that can also be called like a method.

    ``EventDataset`` stores the number of events in the instance attribute
    ``len``, which hides the ``len()`` method of the same name. Making the
    stored value callable keeps both spellings working: ``dataset.len`` in
    arithmetic and ``dataset.len()`` / ``len(dataset)`` as calls.
    """

    def __call__(self):
        return int(self)


class EventDataset(Dataset):
    """Graph-per-event dataset backed by an HDF5 file.

    The file is expected to hold a ``wboson`` group with a ``num_evts``
    attribute and one ``event_NNNNNN`` sub-group per event. Each event group
    provides a ``num_pcls`` attribute and the datasets ``pmu``, ``pdg`` and
    ``is_from_W``.

    Args:
        transform: forwarded to the ``Dataset`` base class.
        pre_transform: forwarded to the ``Dataset`` base class.
        events_file: path of the HDF5 file to read; defaults to the path the
            notebook has always used.
    """

    def __init__(self, transform=None, pre_transform=None,
                 events_file='../data/processed/events.hdf5'):
        # Set before the base-class constructor runs so that
        # processed_file_names is usable at any point.
        self.events_file = events_file
        super(EventDataset, self).__init__(None, transform, pre_transform)
        with h5py.File(self.events_file, 'r') as f:
            self._num_evts = int(f['wboson'].attrs['num_evts'])
        # Kept for existing code that reads ``dataset.len`` as a number; the
        # value is callable so the shadowed ``len()`` method still works.
        self.len = _EventCount(self._num_evts)

    @property
    def raw_file_names(self):
        return ['../data/external/wboson.txt', '../data/external/qstar.txt']

    @property
    def processed_file_names(self):
        return [self.events_file]

    def len(self):
        """Return the number of events recorded in the HDF5 file."""
        return self._num_evts

    @jit(forceobj=True)
    def _get_edges(self, num_nodes):
        """Returns COO formatted graph edges for
        full connected graph of given number of nodes.
        type: (2, num_nodes * (num_nodes - 1)) dim array
        """
        nodes = np.arange(num_nodes, dtype=np.int64)
        # Row 0 is 0,0,..,1,1,..; row 1 is 0,1,..,0,1,.. -> every ordered pair.
        edge_idx = np.vstack((
            np.repeat(nodes, num_nodes),
            np.tile(nodes, num_nodes),
        ))
        # removing self-loops
        return edge_idx[:, edge_idx[0] != edge_idx[1]]

    def get(self, idx):
        """Load event ``idx`` and return it as a fully connected graph.

        Returns:
            ``Data`` with ``x`` (the ``pmu`` array), ``edge_index`` (all
            ordered node pairs without self-loops), ``y`` (float edge labels,
            1.0 where both endpoints are flagged ``is_from_W``) and ``pdg``.
        """
        with h5py.File(self.events_file, 'r') as f:
            # LOAD DATA (everything is copied into memory before the file closes):
            evt = f['wboson'][f'event_{idx:06}']
            num_nodes = evt.attrs['num_pcls']
            pmu = torch.from_numpy(evt['pmu'][...])  # 4-momentum for nodes
            pdg = torch.from_numpy(evt['pdg'][...])  # PDG for posterity
            is_from_W = evt['is_from_W'][...]

        edges_np = self._get_edges(num_nodes)

        # CONSTRUCT EDGE LABELS:
        # Index with the NumPy edge array (not a tensor) to get the per-edge
        # pair of node flags, then reduce => True if both nodes True.
        pair_flags = is_from_W[edges_np]
        edge_labels = np.bitwise_and.reduce(pair_flags, axis=0)
        edge_labels = torch.from_numpy(edge_labels).float()

        # RETURN GRAPH
        edge_idx = torch.from_numpy(edges_np).long()
        return Data(x=pmu, edge_index=edge_idx,
                    y=edge_labels, pdg=pdg)

# Train the network

In [32]:
from tqdm import tqdm

In [33]:
# Seed the RNG so the weight initialisation is the same on every
# Restart & Run All.
SEED = 0
LEARNING_RATE = 0.01
torch.manual_seed(SEED)

model = Net()
optimiser = torch.optim.Adam(model.parameters(), lr=LEARNING_RATE)
# Net ends in a Sigmoid, so its outputs are probabilities as BCELoss expects.
loss_fn = torch.nn.BCELoss()
dataset = EventDataset()

In [35]:
# Fraction of the events used for this single training pass.
TRAIN_FRACTION = 0.1
num_train_evts = int(TRAIN_FRACTION * dataset.len)

# Make the training mode explicit so the cell behaves the same if the model
# was switched to eval mode elsewhere.
model.train()
progress = tqdm(range(num_train_evts))
for evt_num in progress:
    data = dataset.get(evt_num)
    optimiser.zero_grad()
    edge_pred = model(data)
    # edge_pred has one column per edge score; drop it to match data.y.
    loss = loss_fn(edge_pred.squeeze(1), data.y)
    loss.backward()
    optimiser.step()
    # Show the latest loss on the progress bar without forcing a redraw.
    progress.set_postfix(loss=f'{loss.item():.4f}', refresh=False)

  1%|          | 65/10000 [03:03<7:47:59,  2.83s/it] 


KeyboardInterrupt: 