In [1]:
# TODO: explore EM shower profile

import uproot
import torch

In [10]:
from torch_geometric.data import Data
from torch_geometric.data import InMemoryDataset
from graph_utils import load_tree


class HGCALTracksters(InMemoryDataset):
    
    def __init__(self, root, transform=None, pre_transform=None, pre_filter=None):
        super().__init__(root, transform, pre_transform, pre_filter)
        self.data, self.slices = torch.load(self.processed_paths[0])

    @property
    def raw_file_names(self):
        return ['tracksters_ds_pion.root']

    @property
    def processed_file_names(self):
        return ['data.pt']


    def process(self):

        filename = f"{self.root}/{self.raw_file_names[0]}"
        tracksters = uproot.open({filename: "tracksters"})

        dataset = []

        for g, label in load_tree(tracksters, N=2):
            x = torch.tensor([pos for _, pos in g.nodes("pos")])
            edge_index = torch.tensor(list(g.edges())).T
            y = torch.tensor(label)
            dataset.append(Data(x, edge_index=edge_index, y=y))

        data, slices = self.collate(dataset)
        torch.save((data, slices), self.processed_paths[0])

In [11]:
import torch_geometric.transforms as T
transform = T.Compose([T.NormalizeFeatures()])

ds = HGCALTracksters("data", transform=transform)

In [43]:
# balance the pions dataset
pos = ds[ds.data.y == 1]
neg = ds[ds.data.y == 0][:3660]
train_set = pos[:3500] + neg[:3500]
test_set = pos[3500:3660] + neg[3500:3660]

HGCALTracksters(3660)

In [53]:
print(f"PyTorch version: {torch.__version__}")
# Check PyTorch has access to MPS (Metal Performance Shader, Apple's GPU architecture)
# print(f"Is MPS (Metal Performance Shader) built? {torch.backends.mps.is_built()}")
# print(f"Is MPS available? {torch.backends.mps.is_available()}")
# device = torch.device('cuda' if torch.cuda.is_available() else 'mps' if torch.backends.mps.is_available() else "cpu")

device = torch.device('cuda' if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

PyTorch version: 1.13.0.dev20220622
Using device: cpu


In [54]:
from torch_geometric.nn import GCNConv, global_mean_pool
from torch_geometric.loader import DataLoader

In [55]:
class TracksterClassifier(torch.nn.Module):
    def __init__(self, in_channels, out_channels, **kwargs):
        super(TracksterClassifier, self).__init__(**kwargs)

        self.conv1 = GCNConv(in_channels, out_channels)
        self.conv2 = GCNConv(out_channels, out_channels)
        self.dense = torch.nn.Linear(out_channels, 1)

    def forward(self, data):
        x, edge_index, batch = data.x, data.edge_index, data.batch
        x = self.conv1(x, edge_index)
        x = self.conv2(x, edge_index)
        x = global_mean_pool(x, batch)
        x = self.dense(x)
        return torch.sigmoid(x)

In [56]:
loss_obj = torch.nn.BCELoss()

def train(model, loader):
    epoch_loss = 0
    for batch in loader:
        model.train()
        batch = batch.to(device)
        optimizer.zero_grad()
        z = model(batch).reshape(-1)
        loss = loss_obj(z, batch.y.type(torch.float))
        epoch_loss += loss
        loss.backward()
        optimizer.step()
    return float(epoch_loss)

@torch.no_grad()
def test(model, data):
    total = 0
    correct = 0
    for batch in data:
        model.eval()
        prediction = (model(batch).reshape(-1) > 0.5).type(torch.int)
        total += len(prediction) 
        correct += sum(prediction == batch.y)
    return correct / total

In [58]:
model = TracksterClassifier(ds.num_node_features, 64)
model = model.to(device)
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
train_loader = DataLoader(train_set, batch_size=8, shuffle=True)
test_loader = DataLoader(test_set, batch_size=8, shuffle=True)

for epoch in range(101):
    loss = train(model, train_loader)
    train_acc = test(model, train_loader)
    test_acc = test(model, test_loader)
    if epoch % 5 == 0:
        print(f'Epoch: {epoch}, loss: {loss:.4f}, train acc: {train_acc:.4f}, test acc: {test_acc:.4f}')

Epoch: 0, loss: 600.7867, train acc: 0.5194, test acc: 0.5031
Epoch: 5, loss: 425.5844, train acc: 0.6564, test acc: 0.6406
Epoch: 10, loss: 407.8367, train acc: 0.7421, test acc: 0.7469
Epoch: 15, loss: 398.4371, train acc: 0.7501, test acc: 0.7437
Epoch: 20, loss: 387.6735, train acc: 0.7447, test acc: 0.7469
Epoch: 25, loss: 383.5847, train acc: 0.8509, test acc: 0.8625
Epoch: 30, loss: 369.4208, train acc: 0.8121, test acc: 0.8125
Epoch: 35, loss: 366.3902, train acc: 0.8474, test acc: 0.8469
Epoch: 40, loss: 359.6199, train acc: 0.8513, test acc: 0.8500
Epoch: 45, loss: 360.0508, train acc: 0.8046, test acc: 0.8125
Epoch: 50, loss: 353.3960, train acc: 0.8633, test acc: 0.8813
Epoch: 55, loss: 356.5569, train acc: 0.8219, test acc: 0.8406
Epoch: 60, loss: 354.2986, train acc: 0.8574, test acc: 0.8656
Epoch: 65, loss: 352.4052, train acc: 0.8716, test acc: 0.8750
Epoch: 70, loss: 351.9672, train acc: 0.8671, test acc: 0.8813
Epoch: 75, loss: 348.5856, train acc: 0.8661, test acc: 0