In [28]:
import h5py
import numpy as np
import torch
from numba import jit
from torch_geometric.data import Data, Dataset

In [61]:
class EventDataset(Dataset):
    """Graph dataset built from events stored in a processed HDF5 file.

    Every event is returned as a fully connected, directed graph without
    self-loops: one node per particle, node features taken from the ``pmu``
    dataset, and one binary label per edge that is 1.0 when both endpoints
    are flagged in the event's ``is_from_W`` dataset.

    Parameters
    ----------
    transform : callable, optional
        Forwarded unchanged to the ``Dataset`` base class.
    pre_transform : callable, optional
        Forwarded unchanged to the ``Dataset`` base class.
    """

    # Single definition of the processed-event store, shared by
    # ``processed_file_names`` and ``get`` so the two cannot drift apart.
    _PROCESSED_PATH = '../data/processed/events.hdf5'

    # Number of events reported by ``len``.
    # NOTE(review): this is a fixed value and is not checked against the
    # actual contents of the HDF5 file -- confirm it matches the data.
    _NUM_EVENTS = 100000

    def __init__(self, transform=None, pre_transform=None):
        # No root directory is passed (first positional argument is None).
        super().__init__(None, transform, pre_transform)

    @property
    def raw_file_names(self):
        """Paths of the raw text files this dataset is derived from."""
        return ['../data/external/wboson.txt', '../data/external/qstar.txt']

    @property
    def processed_file_names(self):
        """Path of the processed HDF5 file read by ``get``."""
        return [self._PROCESSED_PATH]

    def len(self):
        """Return the number of events in the dataset."""
        return self._NUM_EVENTS

    def _get_edges(self, num_nodes):
        """Return COO-formatted edges of a fully connected directed graph.

        Self-loops are excluded, so the result contains every ordered pair
        ``(i, j)`` with ``i != j``.

        Parameters
        ----------
        num_nodes : int
            Number of nodes in the graph; may be 0.

        Returns
        -------
        numpy.ndarray
            ``int64`` array of shape ``(2, num_nodes * (num_nodes - 1))``;
            row 0 holds source nodes and row 1 holds destination nodes.
        """
        nodes = np.arange(num_nodes, dtype=np.int64)
        # Cartesian product of the node set with itself: sources vary
        # slowest, destinations vary fastest. ``np.tile`` also copes with
        # ``num_nodes == 0``, where a ``reshape(-1, 0)`` would raise.
        src = np.repeat(nodes, num_nodes)
        dst = np.tile(nodes, num_nodes)
        # Removing self-loops.
        keep = src != dst
        return np.vstack((src[keep], dst[keep]))

    def get(self, idx):
        """Load event ``idx`` from the HDF5 file and return it as a graph.

        Parameters
        ----------
        idx : int
            Event index; the event is looked up under the key
            ``event_<idx zero-padded to 6 digits>`` in the ``wboson`` group.

        Returns
        -------
        Data
            Graph with ``x`` (node 4-momenta), ``edge_index`` (int64 COO
            edges), ``y`` (float edge labels) and ``pdg`` (PDG codes, kept
            for posterity).
        """
        # LOAD DATA: ``[...]`` copies each HDF5 dataset into memory, so
        # nothing below needs the file to stay open.
        with h5py.File(self._PROCESSED_PATH, 'r') as f:
            evt = f['wboson'][f'event_{idx:06}']
            num_nodes = int(evt.attrs['num_pcls'])
            pmu = torch.from_numpy(evt['pmu'][...])  # 4-momentum for nodes
            pdg = torch.from_numpy(evt['pdg'][...])  # PDG for posterity
            is_from_W = evt['is_from_W'][...]

        edges = self._get_edges(num_nodes)

        # CONSTRUCT EDGE LABELS: look up the per-node flag for both
        # endpoints of every edge (shape (2, num_edges)) and AND the two
        # rows together, so an edge is labelled true only if both of its
        # nodes are. The lookup uses the numpy edge array directly rather
        # than a torch tensor as a numpy index.
        edge_labels = np.bitwise_and.reduce(is_from_W[edges], axis=0)

        # RETURN GRAPH
        return Data(x=pmu,
                    edge_index=torch.from_numpy(edges).long(),
                    y=torch.from_numpy(edge_labels).float(),
                    pdg=pdg)

        

In [62]:
# Instantiate the dataset with the default arguments (no transform and no
# pre_transform).
dataset = EventDataset()