In [7]:
import torch
from utils.data_utils import HGCALTracksters
import torch_geometric.transforms as T
from torch_geometric.loader import DataLoader

In [6]:
# Feature transform handed to the dataset: a Compose wrapping a single
# NormalizeFeatures step (per the PyG docs this rescales each node's feature
# row to sum to 1; presumably applied on access, per PyG's `transform` convention).
transform = T.Compose(
    [
        T.NormalizeFeatures(),
    ]
)
# Build the photon trackster dataset rooted at the "data" directory.
# NOTE(review): the processing log for this cell lists tracksters_ds_pion.root
# even though kind="photon" — confirm HGCALTracksters filters input files by `kind`.
ds = HGCALTracksters(
    "data",
    transform=transform,
    kind="photon",
)

Processing...


Processing: tracksters_ds_10e.root
Processing: tracksters_ds_pion.root


Done!


In [None]:
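# TODO: prepare tuples (pairs), presumably with these target labels:
#   (complete, complete)   -> 0
#   (incomplete, complete) -> 0
# NOTE(review): both tuple kinds map to 0 here — confirm the intended labels.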

In [12]:
# Partition graphs by their binary label (y == 1 -> positive, y == 0 -> negative).
# NOTE(review): `ds.data` is the dataset's collated storage; newer PyG versions
# warn on direct access — confirm against the installed version.
pos, neg = ds[ds.data.y == 1], ds[ds.data.y == 0]
len_pos, len_neg = len(pos), len(neg)
# Size of the minority class; the next cell draws equally sized subsets from
# both classes using this value, which is what actually balances the data.
shorter = min(len_pos, len_neg)

In [15]:
# Class-balanced split: from each class, the first `shorter - N_TEST` graphs go
# to training and the next N_TEST graphs go to testing.
# NOTE(review): graphs are taken in stored order, not shuffled, so the test split
# is whatever comes last in storage — consider shuffling (seeded) before slicing.
N_TEST = 100     # graphs per class held out for testing
BATCH_SIZE = 8

# Without this guard, `shorter - N_TEST` <= 0 makes the slices wrap around
# (negative indices), silently producing an empty or overlapping train/test split.
assert shorter > N_TEST, (
    f"need more than {N_TEST} graphs per class, got {shorter}"
)

n_train = shorter - N_TEST
# `+` on datasets concatenates them (torch Dataset.__add__ -> ConcatDataset).
train_set = pos[:n_train] + neg[:n_train]
test_set = pos[n_train:shorter] + neg[n_train:shorter]
train_dl = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True)
# Evaluation does not need shuffling; a fixed order keeps results reproducible.
test_dl = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False)

In [None]:
# Run on the GPU when CUDA is available, otherwise fall back to the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"Using device: {device}")

In [None]:
from torch_geometric.nn import GCNConv, global_mean_pool


class TracksterMerger(torch.nn.Module):
    """Graph-level binary classifier built from two GCN layers.

    Node features pass through two GCNConv layers (ReLU in between), are
    mean-pooled per graph, and are mapped to a single sigmoid output.

    Args:
        in_channels: number of input node features.
        out_channels: width of both GCN layers (the hidden size).
        **kwargs: forwarded to ``torch.nn.Module.__init__``.
    """

    def __init__(self, in_channels, out_channels, **kwargs):
        super().__init__(**kwargs)

        self.conv1 = GCNConv(in_channels, out_channels)
        self.conv2 = GCNConv(out_channels, out_channels)
        self.dense = torch.nn.Linear(out_channels, 1)

    def forward(self, data):
        """Return one probability per graph in ``data.batch``.

        Args:
            data: a PyG batch exposing ``x``, ``edge_index`` and ``batch``.

        Returns:
            Tensor whose last dimension is 1, with one row per graph
            (rows produced by ``global_mean_pool``), values in (0, 1).
        """
        x, edge_index, batch = data.x, data.edge_index, data.batch
        # Without a non-linearity between them, the two convolutions compose
        # into a single linear operator and the second layer adds no capacity.
        x = self.conv1(x, edge_index).relu()
        x = self.conv2(x, edge_index)
        x = global_mean_pool(x, batch)
        x = self.dense(x)
        # Returns probabilities, so pair with BCELoss. NOTE(review): returning raw
        # logits with BCEWithLogitsLoss would be more numerically stable.
        return torch.sigmoid(x)