# End-to-End Architecture Performance Comparisons

In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
# System imports
import os
import sys
from time import time as tt
import importlib

# External imports
import matplotlib.pyplot as plt
import scipy as sp
import numpy as np
import pandas as pd
import seaborn as sns
import torch
from torch_geometric.data import DataLoader

from itertools import chain
from random import shuffle, sample
from scipy.optimize import root_scalar as root

from torch.nn import Linear
import torch.nn.functional as F
from torch_cluster import knn_graph, radius_graph
import trackml.dataset
import torch_geometric
from itertools import permutations
import itertools
from sklearn import metrics
import pytorch_lightning as pl
from pytorch_lightning import LightningModule, Trainer, LightningDataModule
from pytorch_lightning.loggers import WandbLogger
from torch.utils.checkpoint import checkpoint

from argparse import Namespace
from trackml.score import score_event
from sklearn.cluster import DBSCAN

# Limit CPU usage on Jupyter
os.environ['OMP_NUM_THREADS'] = '4'

# Pick up local packages
sys.path.append('..')
sys.path.append('../../Tracking-ML-Exa.TrkX/src/Pipelines/Examples')
from LightningModules.GNN.Models.checkpoint_agnn import CheckpointedResAGNN
# from LightningModules.Embedding.utils import get_best_run, build_edges, res, graph_intersection

# Local imports
from prepare_utils import *
from performance_utils import *
from toy_utils import *
from models import *
from trainers import *
from lightning_modules.filter_scanner import Filter_Model
%matplotlib inline


# Get rid of RuntimeWarnings, gross
import warnings
warnings.filterwarnings("ignore", category=RuntimeWarning)
device = 'cuda' if torch.cuda.is_available() else 'cpu'
import wandb
import faiss
props = dict(boxstyle='round', facecolor='wheat', alpha=0.5)
torch_seed = 0

## Data Preparation

### Lightning Load Filter

In [2]:
run_label = "akfxzlsc"

best_run_path = get_best_run(run_label, wandb_save_dir="/global/cscratch1/sd/danieltm/test_20200930/wandb_data")

chkpnt = torch.load(best_run_path)
hparams = chkpnt['hyper_parameters']
model = Filter_Model.load_from_checkpoint(best_run_path)
model = model.to(device)

### Prepare GNN Candidates

In [3]:
pt_cut = 0
train_number = 1000
test_number = 100
load_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/trackml_processed/filter_processed/"
save_dir = "/global/cscratch1/sd/danieltm/test_20200930/trackml_processed/graphs_processed/0_pt_cut/all_events"

basename = os.path.join(load_dir, str(pt_cut) + "_pt_cut")
train_path = os.path.join(basename, str(train_number) + "_events_train_cell_info.pkl")
test_path = os.path.join(basename, str(test_number) + "_events_test_cell_info.pkl")

In [4]:
%%time
train_dataset = torch.load(train_path)
test_dataset = torch.load(test_path)
train_loader = DataLoader(train_dataset, batch_size=1, shuffle=True)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=True)

CPU times: user 1.69 s, sys: 43 s, total: 44.6 s
Wall time: 50.5 s


In [5]:
sections = 4

In [6]:
%%time
model.eval()
with torch.no_grad():
    for i, batch in enumerate(train_dataset):
        tic = tt()
        if not os.path.exists(os.path.join(save_dir, batch.event_file[-4:])):
            data = batch.to(device)
            emb = (None if (hparams["emb_channels"] == 0) 
                       else data.embedding) 

            cut_list = []
            for j in range(sections):
                subset_ind = torch.chunk(torch.arange(data.e_radius.shape[1]), sections)[j]
#                 print(subset_ind)
                output = model(torch.cat([data.cell_data, data.x], axis=-1), data.e_radius[:, subset_ind], emb).squeeze() if ('ci' in hparams["regime"]) else model(data.x, data.e_radius[:, subset_ind], emb).squeeze()
                cut = F.sigmoid(output) > 0.35
                cut_list.append(cut)

            cut_list = torch.cat(cut_list)
            batch.edge_index = batch.e_radius[:, cut_list]
            batch.e_radius = None
            batch.embedding = None
    #         batch.x = batch.x.cpu()
    #         batch.y = torch.from_numpy(y[combined_indices]).float()
            batch.y = batch.y[cut_list]

            with open(os.path.join(save_dir, batch.event_file[-4:]), 'wb') as pickle_file:
                torch.save(batch, pickle_file)

        print(i, "saved in time", tt()-tic)
#         break


0 saved in time 0.001600503921508789
1 saved in time 0.00089263916015625
2 saved in time 0.0006067752838134766
3 saved in time 0.0007889270782470703
4 saved in time 0.0006837844848632812
5 saved in time 0.0007073879241943359
6 saved in time 0.0008025169372558594
7 saved in time 0.0006792545318603516
8 saved in time 0.0005385875701904297
9 saved in time 0.0006489753723144531
10 saved in time 0.0008313655853271484
11 saved in time 0.0006642341613769531
12 saved in time 0.0007109642028808594
13 saved in time 0.0005555152893066406
14 saved in time 0.0008051395416259766
15 saved in time 0.0006551742553710938
16 saved in time 0.0006959438323974609
17 saved in time 0.0005574226379394531
18 saved in time 0.0007450580596923828
19 saved in time 0.0008175373077392578
20 saved in time 0.0004887580871582031
21 saved in time 0.0006873607635498047
22 saved in time 0.0006415843963623047
23 saved in time 0.0005433559417724609
24 saved in time 0.0006670951843261719
25 saved in time 0.0007145404815673828

269 saved in time 0.0006635189056396484
270 saved in time 0.0006654262542724609
271 saved in time 0.0005879402160644531
272 saved in time 0.0006017684936523438
273 saved in time 0.0007472038269042969
274 saved in time 0.0005445480346679688
275 saved in time 0.0006933212280273438
276 saved in time 0.0007627010345458984
277 saved in time 0.0006377696990966797
278 saved in time 0.0005397796630859375
279 saved in time 0.0005831718444824219
280 saved in time 0.0005257129669189453
281 saved in time 0.0007479190826416016
282 saved in time 0.0005679130554199219
283 saved in time 0.0006432533264160156
284 saved in time 0.0005216598510742188
285 saved in time 0.0005011558532714844
286 saved in time 0.0006444454193115234
287 saved in time 0.0007166862487792969
288 saved in time 0.0005717277526855469
289 saved in time 0.0006115436553955078
290 saved in time 0.0007181167602539062
291 saved in time 0.0006628036499023438
292 saved in time 0.0004987716674804688
293 saved in time 0.0007331371307373047


579 saved in time 0.0007283687591552734
580 saved in time 0.00061798095703125
581 saved in time 0.0007302761077880859
582 saved in time 0.0005774497985839844
583 saved in time 0.000484466552734375
584 saved in time 0.0005459785461425781
585 saved in time 0.0004937648773193359
586 saved in time 0.000568389892578125
587 saved in time 0.0006349086761474609
588 saved in time 0.0004875659942626953
589 saved in time 0.0006122589111328125
590 saved in time 0.0005452632904052734
591 saved in time 0.0005803108215332031
592 saved in time 0.0005762577056884766
593 saved in time 0.0006225109100341797
594 saved in time 0.0005397796630859375
595 saved in time 0.0005300045013427734
596 saved in time 0.0005548000335693359
597 saved in time 0.0005755424499511719
598 saved in time 0.0005104541778564453
599 saved in time 0.0004775524139404297
600 saved in time 0.0005898475646972656
601 saved in time 0.0005791187286376953
602 saved in time 0.0005290508270263672
603 saved in time 0.0005464553833007812
604 

901 saved in time 0.0005047321319580078
902 saved in time 0.00043845176696777344
903 saved in time 0.0006113052368164062
904 saved in time 0.0004611015319824219
905 saved in time 0.0005676746368408203
906 saved in time 0.0005767345428466797
907 saved in time 0.0006110668182373047
908 saved in time 0.0004892349243164062
909 saved in time 0.0006556510925292969
910 saved in time 0.0004911422729492188
911 saved in time 0.0004258155822753906
912 saved in time 0.0006051063537597656
913 saved in time 0.0004315376281738281
914 saved in time 0.0004937648773193359
915 saved in time 0.0005159378051757812
916 saved in time 0.000453948974609375
917 saved in time 0.0005321502685546875
918 saved in time 0.0005867481231689453
919 saved in time 0.0006082057952880859
920 saved in time 0.0007765293121337891
921 saved in time 0.0004868507385253906
922 saved in time 0.0016207695007324219
923 saved in time 0.0007171630859375
924 saved in time 0.0005331039428710938
925 saved in time 0.0006148815155029297
926

### Load & Save Raw Data

Load events:

In [16]:
pt_cut = 0
train_number = 1000
test_number = 100

In [17]:
save_dir = "/global/cscratch1/sd/danieltm/test_20200930/trackml_processed/graphs_processed/"
# basename = os.path.join(save_dir, str(pt_cut) + "_pt_cut_endcaps")
basename = os.path.join(save_dir, str(pt_cut) + "_pt_cut")
load_path = os.path.join(basename, "all_events")
all_events = os.listdir(load_path)
all_events = sorted([os.path.join(load_path, event) for event in all_events])

In [18]:
%%time
train_dataset = [torch.load(event, map_location="cpu") for event in all_events[:train_number]]
test_dataset = [torch.load(event, map_location="cpu") for event in all_events[-test_number:]]

CPU times: user 2.47 s, sys: 5.14 s, total: 7.6 s
Wall time: 8.08 s


In [19]:
%%time
with open(os.path.join(basename, str(train_number) + "_events_train.pkl"), 'wb') as pickle_file:
#     pickle.dump(train_dataset, pickle_file)
    torch.save(train_dataset, pickle_file)
with open(os.path.join(basename, str(test_number) + "_events_test.pkl"), 'wb') as pickle_file:
#     pickle.dump(train_dataset, pickle_file)
    torch.save(test_dataset, pickle_file)

CPU times: user 221 ms, sys: 7.52 s, total: 7.74 s
Wall time: 7.87 s


### Load Scrubbed Filter-ready Events

No performance gain from prebuilt torch pickle

In [3]:
with open("../configs/dev_gnn.yaml") as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

In [4]:
def load_dataset(input_dir, num):
    all_events = os.listdir(input_dir)
    all_events = sorted([os.path.join(input_dir, event) for event in all_events])
    loaded_events = [torch.load(event, map_location=torch.device('cpu')) for event in all_events[:num]]

    return loaded_events

In [5]:
class GraphDataModule(LightningDataModule):

    def __init__(self, hparams):
        super().__init__()
        
        # Assign hyperparameters
        self.hparams = hparams
        
    def prepare_data(self):
        # called only on 1 GPU
        datatypes = ["train", "val", "test"]
        input_dirs = [os.path.join(self.hparams["input_dir"], datatype) for datatype in datatypes]
        self.trainset, self.valset, self.testset = [load_dataset(input_dir, self.hparams["train_split"][i]) for i, input_dir in enumerate(input_dirs)]

    def train_dataloader(self):
        if len(self.trainset) > 0:
            return DataLoader(self.trainset, batch_size=1, num_workers=1)
        else:
            return None

    def val_dataloader(self):
        if len(self.valset) > 0:
            return DataLoader(self.valset, batch_size=1, num_workers=1)
        else:
            return None

    def test_dataloader(self):
        if len(self.testset) > 0:
            return DataLoader(self.testset, batch_size=1, num_workers=1)
        else:
            return None
        
class SplitGraphDataModule(GraphDataModule):
    
    def __init__(self, hparams):
        super().__init__(hparams)
        
    def prepare_data(self):
        # called only on 1 GPU
        datatypes = ["train_split", "val", "test"]
        input_dirs = [os.path.join(self.hparams["input_dir"], datatype) for datatype in datatypes]
        self.trainset, self.valset, self.testset = [load_dataset(input_dir, self.hparams["train_split"][i]) for i, input_dir in enumerate(input_dirs)]

In [6]:
%%time
graph_dm = GraphDataModule(config)
# graph_dm = SplitGraphDataModule(config)
graph_dm.prepare_data()

CPU times: user 53.2 ms, sys: 221 ms, total: 274 ms
Wall time: 493 ms


In [7]:
sample = graph_dm.train_dataloader().dataset[0]

## Sampling Experimentation

In [9]:
def random_edge_slice(delta_phi, batch):
    # 1. Select random phi
    random_phi = np.random.rand()*2 - 1
    e = batch.edge_index.to('cpu')
    x = batch.x.to('cpu')

    # 2. Find hits within delta_phi of random_phi
    dif = abs(x[:,1] - random_phi)
    subset_hits = np.where((dif < delta_phi) | ((2-dif) < delta_phi))[0]

    # 3. Filter edges with subset_hits
    subset_edges_ind = (np.isin(e[0], subset_hits) | np.isin(e[1], subset_hits))

    subset_hits = np.unique(e[:, subset_edges_ind])
    subset_edges_extended = (np.isin(e[0], subset_hits) | np.isin(e[1], subset_hits))
    nested_ind = np.isin(np.where(subset_edges_extended)[0], np.where(subset_edges_ind)[0])
    
    return subset_edges_ind, subset_edges_extended, nested_ind

In [249]:
def random_edge_slice_v2(delta_phi, batch):
    # 1. Select random phi
    random_phi = np.random.rand()*2 - 1
    e = batch.edge_index.to('cpu').numpy()
    x = batch.x.to('cpu')

    # 2. Find edges within delta_phi of random_phi
    e_average = (x[e[0], 1] + x[e[1], 1])/2
    dif = abs(e_average - random_phi)
    subset_edges = ((dif < delta_phi) | ((2-dif) < delta_phi)).numpy()

    # 3. Find connected edges to this subset   
    e_ones = cp.array([1]*e_length).astype('Float32')
    subset_ones = cp.array([1]*subset_edges.sum()).astype('Float32')

    e_csr_in = cp.sparse.coo_matrix((e_ones, (cp.array(e[0]).astype('Float32'), cp.arange(e_length).astype('Float32'))), shape=(e.max()+1,e_length)).tocsr()
    e_csr_out = cp.sparse.coo_matrix((e_ones, (cp.array(e[0]).astype('Float32'), cp.arange(e_length).astype('Float32'))), shape=(e.max()+1,e_length)).tocsr()
    e_csr = e_csr_in + e_csr_out

    subset_csr_in = cp.sparse.coo_matrix((subset_ones, (cp.array(e[0, subset_edges]).astype('Float32'), cp.arange(e_length)[subset_edges].astype('Float32'))), shape=(e.max()+1,e_length)).tocsr()
    subset_csr_out = cp.sparse.coo_matrix((subset_ones, (cp.array(e[0, subset_edges]).astype('Float32'), cp.arange(e_length)[subset_edges].astype('Float32'))), shape=(e.max()+1,e_length)).tocsr()
    subset_csr = subset_csr_in + subset_csr_out

    summed = (subset_csr.T * e_csr).sum(axis=0)
    subset_edges_extended = (summed>0)[0].get()
    
    return subset_edges, subset_edges_extended

In [17]:
def split_edges(phi_sections, batch):
    # 1. Loop over 0 -> dPhi -> 1
    delta_phi =  1/phi_sections
    for phi in np.arange(0, 1, delta_phi):
        random_phi = np.random.rand()*2 - 1
        e = batch.edge_index.to('cpu')
        x = batch.x.to('cpu')

        # 2. Find hits within delta_phi of random_phi
        dif = abs(x[:,1] - random_phi)
        subset_hits = np.where((dif < delta_phi) | ((2-dif) < delta_phi))[0]

        # 3. Filter edges with subset_hits
        subset_edges_ind = (np.isin(e[0], subset_hits) | np.isin(e[1], subset_hits))

        subset_hits = np.unique(e[:, subset_edges_ind])
        subset_edges_extended = (np.isin(e[0], subset_hits) | np.isin(e[1], subset_hits))
        nested_ind = np.isin(np.where(subset_edges_extended)[0], np.where(subset_edges_ind)[0])
    
    return subset_edges_ind, subset_edges_extended, nested_ind

In [11]:
batch = graph_dm.val_dataloader().dataset[0]

In [257]:
%%time
subsets = random_edge_slice(0.1, batch)

CPU times: user 453 ms, sys: 441 µs, total: 453 ms
Wall time: 424 ms


In [258]:
%%time
subsets = random_edge_slice_v2(0.1, batch)

CPU times: user 234 ms, sys: 28.8 ms, total: 263 ms
Wall time: 235 ms


  del sys.path[0]
  
  app.launch_new_instance()


In [256]:
%%time
split_edges(12, batch)

CPU times: user 5.19 s, sys: 281 µs, total: 5.19 s
Wall time: 4.87 s


(array([False,  True, False, ..., False, False, False]),
 array([ True,  True,  True, ..., False, False,  True]),
 array([False,  True, False, ..., False, False, False]))

## GNN Training

### Models

In [44]:
class GNNBase(LightningModule):

    def __init__(self, hparams):
        super().__init__()
        '''
        Initialise the Lightning Module that can scan over different GNN training regimes
        '''
        # Assign hyperparameters
        self.hparams = hparams

    def configure_optimizers(self):
        optimizer = [torch.optim.AdamW(self.parameters(), lr=(self.hparams["lr"]), betas=(0.9, 0.999), eps=1e-08, amsgrad=True)]
        scheduler = [
            {
                'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer[0], factor=self.hparams["factor"], patience=self.hparams["patience"]),
                'monitor': 'checkpoint_on',
                'interval': 'epoch',
                'frequency': 1
            }
        ]
#         scheduler = [torch.optim.lr_scheduler.StepLR(optimizer[0], step_size=1, gamma=0.3)]
        return optimizer, scheduler

    def training_step(self, batch, batch_idx):
        
        weight = (torch.tensor(self.hparams["weight"]) if ("weight" in self.hparams)
                      else torch.tensor((~batch.y_pid.bool()).sum() / batch.y_pid.sum()))

        output = (self(torch.cat([batch.cell_data, batch.x], axis=-1), 
                       batch.edge_index).squeeze()
                  if ('ci' in self.hparams["regime"])
                  else self(batch.x, batch.edge_index).squeeze())

        
        if ('pid' in self.hparams["regime"]):
            y_pid = (batch.pid[batch.edge_index[0, batch.nested_ind[0]]] == batch.pid[batch.edge_index[1, batch.nested_ind[0]]]).float()
            loss = F.binary_cross_entropy_with_logits(output[batch.nested_ind[0]], y_pid.float(), pos_weight = weight)
        else:
            loss = F.binary_cross_entropy_with_logits(output[batch.nested_ind[0]], batch.y[batch.nested_ind[0]], pos_weight = weight)
            
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss, prog_bar=True)

        return result

    def validation_step(self, batch, batch_idx):

        weight = (torch.tensor(self.hparams["weight"]) if ("weight" in self.hparams)
                      else torch.tensor((~batch.y_pid.bool()).sum() / batch.y_pid.sum()))

        output = (self(torch.cat([batch.cell_data, batch.x], axis=-1), batch.edge_index).squeeze()
                  if ('ci' in self.hparams["regime"])
                  else self(batch.x, batch.edge_index).squeeze())

        if ('pid' in self.hparams["regime"]):
            y_pid = (batch.pid[batch.edge_index[0]] == batch.pid[batch.edge_index[1]]).float()
            val_loss = F.binary_cross_entropy_with_logits(output, y_pid.float(), pos_weight = weight)
        else:
            val_loss = F.binary_cross_entropy_with_logits(output, batch.y, pos_weight = weight)

        result = pl.EvalResult(checkpoint_on=val_loss)
        result.log('val_loss', val_loss)

        #Edge filter performance
        preds = F.sigmoid(output) > 0.5 #Maybe send to CPU??
        edge_positive = preds.sum().float()

        if ('pid' in self.hparams["regime"]):
            y_pid = batch.pid[batch.edge_index[0]] == batch.pid[batch.edge_index[1]]
            edge_true = y_pid.sum().float()
            edge_true_positive = (y_pid & preds).sum().float()
        else:
            edge_true = batch.y.sum()
            edge_true_positive = (batch.y.bool() & preds).sum().float()

        result.log_dict({'eff': torch.tensor(edge_true_positive/edge_true), 'pur': torch.tensor(edge_true_positive/edge_positive)})

        return result

    def optimizer_step(self, current_epoch, batch_nb, optimizer, optimizer_idx, second_order_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
        # warm up lr
        if (self.hparams["warmup"] is not None) and (self.trainer.global_step < self.hparams["warmup"]):
            lr_scale = min(1., float(self.trainer.global_step + 1) / self.hparams["warmup"])
            for pg in optimizer.param_groups:
                pg['lr'] = lr_scale * self.hparams["lr"]

        # update params
        optimizer.step()
        optimizer.zero_grad()
        

class InteractionGNN(GNNBase):
    
    def __init__(self, hparams):
        super().__init__(hparams)
        '''
        Initialise the Lightning Module that can scan over different GNN training regimes
        '''

        # Setup input network
        self.node_encoder = make_mlp(hparams["in_channels"], [hparams["hidden"]],
                                      output_activation=hparams["hidden_activation"],
                                      layer_norm=hparams["layernorm"])
   
        # The edge network computes new edge features from connected nodes
        self.edge_encoder = make_mlp(2*(hparams["hidden"]),
                                     [hparams["hidden"]]*hparams["nb_edge_layer"],
                                     layer_norm=hparams["layernorm"],
                                     output_activation=None,
                                     hidden_activation = hparams["hidden_activation"])

        # The edge network computes new edge features from connected nodes
        self.edge_network = make_mlp(4*hparams["hidden"],
                                     [hparams["hidden"]]*hparams["nb_edge_layer"],
                                     layer_norm=hparams["layernorm"],
                                     output_activation=None,
                                     hidden_activation = hparams["hidden_activation"])

        # The node network computes new node features
        self.node_network = make_mlp(4*hparams["hidden"],
                                     [hparams["hidden"]]*hparams["nb_node_layer"],
                                     layer_norm=hparams["layernorm"],
                                     output_activation=None,
                                     hidden_activation = hparams["hidden_activation"])
        
        # Final edge output classification network
        self.output_edge_classifier = make_mlp(3*hparams["hidden"],
                                     [hparams["hidden"], 1],
                                     layer_norm=hparams["layernorm"],
                                     output_activation=None,
                                     hidden_activation = hparams["hidden_activation"])
        

    def forward(self, x, edge_index):
        
        start, end = edge_index
        
        # Encode the graph features into the hidden space
        x = self.node_encoder(x)
        e = self.edge_encoder(torch.cat([x[start], x[end]], dim=1))
        input_x = x
        input_e = e      

        edge_outputs = []
        # Loop over iterations of edge and node networks
        for i in range(self.hparams["n_graph_iters"]):
            
            # Cocnatenate with initial latent space
            x = torch.cat([x, input_x], dim=-1)
            e = torch.cat([e, input_e], dim=-1)

            # Compute new node features
            edge_messages = scatter_add(e, end, dim=0, dim_size=x.shape[0]) + scatter_add(e, start, dim=0, dim_size=x.shape[0])
            node_inputs = torch.cat([x, edge_messages], dim=-1)
            x = self.node_network(node_inputs)
          
            # Compute new edge features            
            edge_inputs = torch.cat([x[start], x[end], e], dim=-1)
            e = self.edge_network(edge_inputs)
            e = torch.sigmoid(e)
            
            classifier_inputs = torch.cat([x[start], x[end], e], dim=1)
            edge_outputs.append(self.output_edge_classifier(classifier_inputs).squeeze(-1))
    
        # Compute final edge scores; use original edge directions only        
        
        return torch.cat(edge_outputs)
#         classifier_inputs = torch.cat([x[start], x[end], e], dim=1)
#         return self.output_edge_classifier(classifier_inputs).squeeze(-1)

    def training_step(self, batch, batch_idx):
        
        weight = (torch.tensor(self.hparams["weight"]) if ("weight" in self.hparams)
                      else torch.tensor((~batch.y_pid.bool()).sum() / batch.y_pid.sum()))
        
        output = (self(torch.cat([batch.cell_data, batch.x], axis=-1), 
                       batch.edge_index).squeeze()
                  if ('ci' in self.hparams["regime"])
                  else self(batch.x, batch.edge_index).squeeze())

        if ('pid' in self.hparams["regime"]):
            y_pid = (batch.pid[batch.edge_index[0]] == batch.pid[batch.edge_index[1]]).float()
            y_pid = y_pid.repeat((self.hparams["n_graph_iters"]))
            loss = F.binary_cross_entropy_with_logits(output, y_pid.float(), pos_weight = weight)
        else:
            y = batch.y.repeat((self.hparams["n_graph_iters"]))
            loss = F.binary_cross_entropy_with_logits(output, y, pos_weight = weight)
            
        result = pl.TrainResult(minimize=loss)
        result.log('train_loss', loss, prog_bar=True)

        return result

    
    def validation_step(self, batch, batch_idx):

        weight = (torch.tensor(self.hparams["weight"]) if ("weight" in self.hparams)
                      else torch.tensor((~batch.y_pid.bool()).sum() / batch.y_pid.sum()))

        output = (self(torch.cat([batch.cell_data, batch.x], axis=-1), batch.edge_index).squeeze()
                  if ('ci' in self.hparams["regime"])
                  else self(batch.x, batch.edge_index).squeeze())

        if ('pid' in self.hparams["regime"]):
            y_pid = (batch.pid[batch.edge_index[0]] == batch.pid[batch.edge_index[1]]).float()
            y_pid = y_pid.repeat((self.hparams["n_graph_iters"]))
            val_loss = F.binary_cross_entropy_with_logits(output, y_pid.float(), pos_weight = weight)
        else:
            y = batch.y.repeat((self.hparams["n_graph_iters"]))
            val_loss = F.binary_cross_entropy_with_logits(output, y, pos_weight = weight)

        result = pl.EvalResult(checkpoint_on=val_loss)
        result.log('val_loss', val_loss)

        #Edge filter performance
        preds = F.sigmoid(output) > 0.5 #Maybe send to CPU??
        edge_positive = preds.sum().float()

        if ('pid' in self.hparams["regime"]):
            y_pid = batch.pid[batch.edge_index[0]] == batch.pid[batch.edge_index[1]]
            y_pid = y_pid.repeat((self.hparams["n_graph_iters"]))
            edge_true = y_pid.sum().float()
            edge_true_positive = (y_pid & preds).sum().float()
        else:
            edge_true = y.sum()
            edge_true_positive = (y.bool() & preds).sum().float()

        result.log_dict({'eff': torch.tensor(edge_true_positive/edge_true), 'pur': torch.tensor(edge_true_positive/edge_positive)})

        return result
    
class CheckpointedResAGNN(GNNBase):

    def __init__(self, hparams):
        super().__init__(hparams)
        '''
        Initialise the Lightning Module that can scan over different GNN training regimes
        '''

        # Setup input network
        self.node_encoder = make_mlp(hparams["in_channels"], [hparams["hidden"]],
                                      output_activation=hparams["hidden_activation"],
                                      layer_norm=hparams["layernorm"])
   
        # The edge network computes new edge features from connected nodes
        self.edge_network = make_mlp(2*(hparams["in_channels"] + hparams["hidden"]),
                                     [hparams["hidden"]]*hparams["nb_edge_layer"]+[1],
                                     layer_norm=hparams["layernorm"],
                                     output_activation=None,
                                     hidden_activation = hparams["hidden_activation"])

        # The node network computes new node features
        self.node_network = make_mlp((hparams["in_channels"] + hparams["hidden"])*2,
                                     [hparams["hidden"]]*hparams["nb_node_layer"],
                                     layer_norm=hparams["layernorm"],
                                     output_activation=None,
                                     hidden_activation = hparams["hidden_activation"])

    def forward(self, x, edge_index):
        
        # Encode the graph features into the hidden space
        input_x = x
        x = self.node_encoder(x)
        x = torch.cat([x, input_x], dim=-1)
        
        start, end = edge_index

        # Loop over iterations of edge and node networks
        for i in range(self.hparams["n_graph_iters"]):
            # Previous hidden state
            x0 = x

            # Compute new edge score
            edge_inputs = torch.cat([x[start], x[end]], dim=1)
            e = checkpoint(self.edge_network, edge_inputs)
            e = torch.sigmoid(e)
            
            # Sum weighted node features coming into each node
#             weighted_messages_in = scatter_add(e * x[start], end, dim=0, dim_size=x.shape[0])
#             weighted_messages_out = scatter_add(e * x[end], start, dim=0, dim_size=x.shape[0])
            
            weighted_messages = scatter_add(e * x[start], end, dim=0, dim_size=x.shape[0]) + scatter_add(e * x[end], start, dim=0, dim_size=x.shape[0])

            # Compute new node features
#             node_inputs = torch.cat([x, weighted_messages_in, weighted_messages_out], dim=1)
            node_inputs = torch.cat([x, weighted_messages], dim=1)
            x = checkpoint(self.node_network, node_inputs)

            # Residual connection
            x = torch.cat([x, input_x], dim=-1)
            x = x + x0

        # Compute final edge scores; use original edge directions only
        clf_inputs = torch.cat([x[start], x[end]], dim=1)
        return checkpoint(self.edge_network, clf_inputs).squeeze(-1)

### Lightning Train

In [72]:
with open("../configs/dev_gnn.yaml") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

In [73]:
# torch.cuda.reset_max_memory_allocated()
# wandb_logger = WandbLogger(project=config["project"], group="SubgraphSampling", log_model=True, save_dir = config["wandb_save_dir"])
model = CheckpointedResAGNN(config)
# wandb_logger.log_hyperparams(config)

Data processed


In [74]:
trainer = Trainer(gpus=1, max_epochs=1)
# trainer = Trainer(max_epochs=1) #callbacks=[lr_logger],

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


In [75]:
%%time
trainer.fit(model)

Set SLURM handle signals.

  | Name         | Type       | Params
--------------------------------------------
0 | node_encoder | Sequential | 768   
1 | edge_network | Sequential | 50 K  
2 | node_network | Sequential | 50 K  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…



HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…


CPU times: user 4.76 s, sys: 2.3 s, total: 7.06 s
Wall time: 7.89 s


1

In [76]:
torch.cuda.max_memory_allocated() / 1024**3

9.939926147460938

### Mixed Precision Model

In [158]:
from pytorch_lightning.utilities import AMPType
from LightningModules.GNN.utils import load_dataset

class MixedGNNBase(LightningModule):

    def __init__(self, hparams):
        super().__init__()
        '''
        Initialise the Lightning Module that can scan over different GNN training regimes
        '''
        # Assign hyperparameters
        self.hparams = hparams
        self.hparams["posted_alert"] = False
        
        # Handle any subset of [train, val, test] data split, assuming that ordering
        input_dirs = [None, None, None]
        input_dirs[:len(hparams["datatype_names"])] = [os.path.join(hparams["input_dir"], datatype) for datatype in hparams["datatype_names"]]
        self.trainset, self.valset, self.testset = [load_dataset(input_dir, hparams["datatype_split"][i], hparams["pt_min"]) for i, input_dir in enumerate(input_dirs)]
        print("Data processed")
        
    def train_dataloader(self):
        if self.trainset is not None:
            return DataLoader(self.trainset, batch_size=1, num_workers=1)
        else:
            return None

    def val_dataloader(self):
        if self.valset is not None:
            return DataLoader(self.valset, batch_size=1, num_workers=1)
        else:
            return None

    def test_dataloader(self):
        if self.testset is not None:
            return DataLoader(self.testset, batch_size=1, num_workers=1)
        else:
            return None
        
    def configure_optimizers(self):
        optimizer = [torch.optim.AdamW(self.parameters(), lr=(self.hparams["lr"]), betas=(0.9, 0.999), eps=1e-08, amsgrad=True)]
#         scheduler = [
#             {
#                 'scheduler': torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer[0], factor=self.hparams["factor"], patience=self.hparams["patience"]),
#                 'monitor': 'val_loss',
#                 'interval': 'epoch',
#                 'frequency': 1
#             }
#         ]
        scheduler = [
            {
                'scheduler': torch.optim.lr_scheduler.StepLR(optimizer[0], step_size=self.hparams["patience"], gamma=self.hparams["factor"]),
                'interval': 'epoch',
                'frequency': 1
            }
        ]
        return optimizer, scheduler
        
        
    def training_step(self, batch, batch_idx):
        
        weight = (torch.tensor(self.hparams["weight"]) if ("weight" in self.hparams)
                      else torch.tensor((~batch.y_pid.bool()).sum() / batch.y_pid.sum()))
        x = batch.x
        print("X out pre:", x.type())
        with torch.cuda.amp.autocast():
            print("X out post:", x.type())
            output = (self(torch.cat([batch.cell_data, batch.x], axis=-1), 
                           batch.edge_index.int()).squeeze()
                      if ('ci' in self.hparams["regime"])
                      else self(batch.x, batch.edge_index.int()).squeeze())

            if ('pid' in self.hparams["regime"]):
                y_pid = (batch.pid[batch.edge_index[0]] == batch.pid[batch.edge_index[1]]).float()
                loss = F.binary_cross_entropy_with_logits(output, y_pid.float(), pos_weight = weight)
            else:
                loss = F.binary_cross_entropy_with_logits(output, batch.y, pos_weight = weight)
            
            self.log('train_loss', loss)

#         self.manual_backward(loss, opt_a)
#         self.manual_optimizer_step(opt_a)

            return loss

    def shared_evaluation(self, batch, batch_idx):

        weight = (torch.tensor(self.hparams["weight"]) if ("weight" in self.hparams)
                      else torch.tensor((~batch.y_pid.bool()).sum() / batch.y_pid.sum()))

        output = (self(torch.cat([batch.cell_data, batch.x], axis=-1), batch.edge_index).squeeze()
                  if ('ci' in self.hparams["regime"])
                  else self(batch.x, batch.edge_index).squeeze())
        
        truth = (batch.pid[batch.edge_index[0]] == batch.pid[batch.edge_index[1]]).float() if 'pid' in self.hparams["regime"] else batch.y

        loss = F.binary_cross_entropy_with_logits(output, truth.float(), pos_weight = weight)

        #Edge filter performance
        preds = F.sigmoid(output) > self.hparams["edge_cut"]
        edge_positive = preds.sum().float()

        edge_true = truth.sum().float()
        edge_true_positive = (truth.bool() & preds).sum().float()
    
        eff = torch.tensor(edge_true_positive/edge_true)
        pur = torch.tensor(edge_true_positive/edge_positive)
        
        if (eff > 0.99) and (pur > 0.99) and not self.hparams["posted_alert"] and self.hparams["slack_alert"]:
            self.logger.experiment.alert(title="High Performance", 
                        text="Efficiency and purity have both cracked 99%. Great job, Dan! You're having a great Thursday, and I think you've earned a celebratory beer.",
                        wait_duration=timedelta(minutes=60))
            self.hparams["posted_alert"] = True
        
        current_lr = self.optimizers().param_groups[0]['lr']
        self.log_dict({'val_loss': loss, 'eff': eff, 'pur': pur, "current_lr": current_lr})
        
        return {"loss": loss, "preds": preds.cpu().numpy(), "truth": truth.cpu().numpy()}

    def validation_step(self, batch, batch_idx):
        
        outputs = self.shared_evaluation(batch, batch_idx)
            
        return outputs["loss"]
    
    def test_step(self, batch, batch_idx):

        outputs = self.shared_evaluation(batch, batch_idx)
        
        return outputs
    
    
#     def optimizer_step(self, epoch, batch_idx, optimizer, optimizer_idx, optimizer_closure=None, on_tpu=False, using_native_amp=False, using_lbfgs=False):
#         # warm up lr
#         if (self.hparams["warmup"] is not None) and (self.trainer.global_step < self.hparams["warmup"]):
#             lr_scale = min(1., float(self.trainer.global_step + 1) / self.hparams["warmup"])
#             for pg in optimizer.param_groups:
#                 pg['lr'] = lr_scale * self.hparams["lr"]
        
#         # update params
#         print(self.trainer.scaler.get_scale())
#         if self.trainer.amp_backend == AMPType.NATIVE:
#             # native amp does not yet support closures.
#             # TODO: pass the closure to the step ASAP
#             optimizer_closure()
#             self.trainer.scaler.step(optimizer)
#             self.trainer.scaler.update()
#         else:
#             optimizer.step(closure=optimizer_closure)
#         optimizer.zero_grad()




class MixedCheckpointedResAGNN(MixedGNNBase):

    def __init__(self, hparams):
        super().__init__(hparams)
        '''
        Initialise the Lightning Module that can scan over different GNN training regimes
        '''

        # Setup input network
        self.node_encoder = make_mlp(hparams["in_channels"], [hparams["hidden"]],
                                      output_activation=hparams["hidden_activation"],
                                      layer_norm=hparams["layernorm"])
   
        # The edge network computes new edge features from connected nodes
        self.edge_network = make_mlp(2*(hparams["in_channels"] + hparams["hidden"]),
                                     [hparams["hidden"]]*hparams["nb_edge_layer"]+[1],
                                     layer_norm=hparams["layernorm"],
                                     output_activation=None,
                                     hidden_activation = hparams["hidden_activation"])

        # The node network computes new node features
        self.node_network = make_mlp((hparams["in_channels"] + hparams["hidden"])*2,
                                     [hparams["hidden"]]*hparams["nb_node_layer"],
                                     layer_norm=hparams["layernorm"],
                                     output_activation=None,
                                     hidden_activation = hparams["hidden_activation"])

    def forward(self, x, edge_index):
        
        # Encode the graph features into the hidden space
        input_x = x
        print("X type:", input_x.type())
        print("Edge type pre:", edge_index.type())
        edge_index = edge_index.long()
        print("Edge type post:", edge_index.type())
        x = self.node_encoder(x)
        x = torch.cat([x, input_x], dim=-1)
        
        start, end = edge_index

        # Loop over iterations of edge and node networks
        for i in range(self.hparams["n_graph_iters"]):
            # Previous hidden state
            x0 = x

            # Compute new edge score
            edge_inputs = torch.cat([x[start], x[end]], dim=1)
            e = checkpoint(self.edge_network, edge_inputs)
            e = torch.sigmoid(e)
            
            # Sum weighted node features coming into each node
#             weighted_messages_in = scatter_add(e * x[start], end, dim=0, dim_size=x.shape[0])
#             weighted_messages_out = scatter_add(e * x[end], start, dim=0, dim_size=x.shape[0])
            
            weighted_messages = scatter_add(e * x[start], end, dim=0, dim_size=x.shape[0]) + scatter_add(e * x[end], start, dim=0, dim_size=x.shape[0])

            # Compute new node features
#             node_inputs = torch.cat([x, weighted_messages_in, weighted_messages_out], dim=1)
            node_inputs = torch.cat([x, weighted_messages], dim=1)
            x = checkpoint(self.node_network, node_inputs)

            # Residual connection
            x = torch.cat([x, input_x], dim=-1)
            x = x + x0

        # Compute final edge scores; use original edge directions only
        clf_inputs = torch.cat([x[start], x[end]], dim=1)
        return checkpoint(self.edge_network, clf_inputs).squeeze(-1)

### Mixed Precision Lightning Train

In [159]:
torch.cuda.reset_max_memory_allocated()

In [160]:
with open("../configs/dev_gnn.yaml") as f:
        config = yaml.load(f, Loader=yaml.FullLoader)

In [161]:
# torch.cuda.reset_max_memory_allocated()
# wandb_logger = WandbLogger(project=config["project"], group="SubgraphSampling", log_model=True, save_dir = config["wandb_save_dir"])
model = MixedCheckpointedResAGNN(config)
# wandb_logger.log_hyperparams(config)

Data processed


In [162]:
trainer = Trainer(gpus=1, max_epochs=1, amp_level='O2', precision=16)
# trainer = Trainer(max_epochs=1) #callbacks=[lr_logger],

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
Using native 16bit precision.


In [163]:
%%time
trainer.fit(model)

Set SLURM handle signals.

  | Name         | Type       | Params
--------------------------------------------
0 | node_encoder | Sequential | 768   
1 | edge_network | Sequential | 50 K  
2 | node_network | Sequential | 50 K  


HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validation sanity check', layout=Layout…

X type: torch.cuda.FloatTensor
Edge type pre: torch.cuda.LongTensor
Edge type post: torch.cuda.LongTensor
X type: torch.cuda.FloatTensor
Edge type pre: torch.cuda.LongTensor
Edge type post: torch.cuda.LongTensor






HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Training', layout=Layout(flex='2'), max…

X out pre: torch.cuda.FloatTensor
X out post: torch.cuda.FloatTensor
X type: torch.cuda.FloatTensor
Edge type pre: torch.cuda.IntTensor
Edge type post: torch.cuda.LongTensor
X out pre: torch.cuda.FloatTensor
X out post: torch.cuda.FloatTensor
X type: torch.cuda.FloatTensor
Edge type pre: torch.cuda.IntTensor
Edge type post: torch.cuda.LongTensor
X out pre: torch.cuda.FloatTensor
X out post: torch.cuda.FloatTensor
X type: torch.cuda.FloatTensor
Edge type pre: torch.cuda.IntTensor
Edge type post: torch.cuda.LongTensor
X out pre: torch.cuda.FloatTensor
X out post: torch.cuda.FloatTensor
X type: torch.cuda.FloatTensor
Edge type pre: torch.cuda.IntTensor
Edge type post: torch.cuda.LongTensor
X out pre: torch.cuda.FloatTensor
X out post: torch.cuda.FloatTensor
X type: torch.cuda.FloatTensor
Edge type pre: torch.cuda.IntTensor
Edge type post: torch.cuda.LongTensor
X out pre: torch.cuda.FloatTensor
X out post: torch.cuda.FloatTensor
X type: torch.cuda.FloatTensor
Edge type pre: torch.cuda.IntT

HBox(children=(FloatProgress(value=1.0, bar_style='info', description='Validating', layout=Layout(flex='2'), m…

X type: torch.cuda.FloatTensor
Edge type pre: torch.cuda.LongTensor
Edge type post: torch.cuda.LongTensor
X type: torch.cuda.FloatTensor
Edge type pre: torch.cuda.LongTensor
Edge type post: torch.cuda.LongTensor
X type: torch.cuda.FloatTensor
Edge type pre: torch.cuda.LongTensor
Edge type post: torch.cuda.LongTensor
X type: torch.cuda.FloatTensor
Edge type pre: torch.cuda.LongTensor
Edge type post: torch.cuda.LongTensor
X type: torch.cuda.FloatTensor
Edge type pre: torch.cuda.LongTensor
Edge type post: torch.cuda.LongTensor

CPU times: user 5.16 s, sys: 2.11 s, total: 7.27 s
Wall time: 8.1 s


1

In [145]:
torch.cuda.max_memory_allocated() / 1024**3

9.865897178649902

## Mixed Precision Base Testing

In [166]:
from torch.cuda.amp import autocast

In [172]:
a_float32 = torch.rand((8, 8), device="cuda")
b_float32 = torch.rand((8, 8), device="cuda")
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")
print(a_float32.type())

with autocast():
    # torch.mm is on autocast's list of ops that should run in float16.
    # Inputs are float32, but the op runs in float16 and produces float16 output.
    # No manual casts are required.
    e_float16 = torch.mm(a_float32, b_float32)
    print(e_float16.type())
    print(a_float32.type())
    # Also handles mixed input types
    f_float16 = torch.mm(d_float32, e_float16)
    print(f_float16.type())

# After exiting autocast, calls f_float16.float() to use with d_float32
g_float32 = torch.mm(d_float32, f_float16.float())
print(g_float32.type())

torch.cuda.FloatTensor
torch.cuda.HalfTensor
torch.cuda.FloatTensor
torch.cuda.HalfTensor
torch.cuda.FloatTensor


In [184]:
batch = model.trainset[0]

In [202]:
x = batch.x

In [210]:
x = x.to(device)

In [211]:
x.type()

'torch.cuda.FloatTensor'

In [212]:
with autocast(enabled=True):
    enc = node_encoder(x)
    print(enc.type())
    x2 = x + x
    print(x2.type())

torch.cuda.FloatTensor
torch.cuda.FloatTensor


In [213]:
@staticmethod
@custom_fwd
def forward(a, b):
    return node_encoder(a)

NameError: name 'custom_fwd' is not defined

In [201]:
# The node network computes new node features
hparams = model.hparams
node_network = make_mlp((hparams["in_channels"] + hparams["hidden"])*2,
                             [hparams["hidden"]]*hparams["nb_node_layer"],
                             layer_norm=hparams["layernorm"],
                             output_activation=None,
                             hidden_activation = hparams["hidden_activation"]).to(device)
node_encoder = make_mlp(hparams["in_channels"], [hparams["hidden"]],
                                      output_activation=hparams["hidden_activation"],
                                      layer_norm=hparams["layernorm"]).to(device)

## Testing

In [5]:
run_label = "rm3047br"
wandb_dir = "/global/cscratch1/sd/danieltm/ExaTrkX/wandb_data"
best_run_path = get_best_run(run_label,wandb_dir)

In [6]:
chkpnt = torch.load(best_run_path)
model = CheckpointedResAGNN.load_from_checkpoint(best_run_path)
model = model.to(device)

In [23]:
model.eval()
with torch.no_grad():
    edge_total_positive, edge_total_true, edge_total_true_positive, edge_total_true_ground = 0, 0, 0, 0
    for i, batch in enumerate(model.val_dataloader().dataset[:5]):
        data = batch.to(device)

        output = (model(torch.cat([data.cell_data, data.x], axis=-1), data.edge_index).squeeze()
                  if ('ci' in model.hparams["regime"])
                  else model(data.x, data.edge_index).squeeze())

        #Edge filter performance
        preds = F.sigmoid(output) > 0.9 #Maybe send to CPU??
        edge_positive = preds.sum().float()

        if ('pid' in model.hparams["regime"]):
            y_pid = data.pid[data.edge_index[0]] == data.pid[data.edge_index[1]]
            edge_true = y_pid.sum().float()
            edge_true_positive = (y_pid & preds).sum().float()
        else:
            edge_true = data.y.sum()
            edge_true_ground = data.layerless_true_edges.shape[1]
            edge_true_positive = (data.y.bool() & preds).sum().float()
            
        edge_total_positive += edge_positive
        edge_total_true_positive += edge_true_positive
        edge_total_true += edge_true
        edge_total_true_ground += edge_true_ground
        
        print(i)

    edge_eff = (edge_total_true_positive / max(edge_total_true, 1))
    edge_ground_eff = (edge_total_true_positive / max(edge_total_true_ground, 1))
    edge_pur = (edge_total_true_positive / max(edge_total_positive, 1))

0
1
2
3
4


In [24]:
print("Eff:", edge_eff, "Pur:", edge_pur, "Ground eff:", edge_ground_eff)

Eff: tensor(0.8070, device='cuda:0') Pur: tensor(0.9520, device='cuda:0') Ground eff: tensor(0.6865, device='cuda:0')


## Truth Debugging

In [29]:
sample = model.val_dataloader().dataset[0]

In [51]:
sample[0]

Data(cell_data=[71081, 9], edge_index=[2, 987461], event_file=/global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/event000001193, hid=[71081], layerless_true_edges=[2, 85024], layers=[71081], pid=[71081], x=[71081, 3], y=[987461], y_pid=[987461])

In [219]:
sample.layerless_true_edges

tensor([[ 2534,  7691, 13022,  ..., 79374, 89338, 93484],
        [ 7691, 13022, 19078,  ..., 89338, 93484, 96979]])

In [220]:
sample.edge_index

tensor([[     0,    323,    527,  ..., 103304, 103122, 103304],
        [   287,      0,      0,  ..., 103120, 103304, 103131]])

In [221]:
sample.y.sum()

tensor(117942.)

In [222]:
sample.edge_index[:,sample.y.bool()].shape

torch.Size([2, 117942])

In [223]:
sample.layerless_true_edges.shape

torch.Size([2, 123429])

In [224]:
sample.edge_index[:,sample.y.bool()].shape[1]/sample.layerless_true_edges.shape[1]

0.9555452932455095

## TrackML Debugging

### Ground Truth Level

In [205]:
sample = torch.load("/global/cscratch1/sd/danieltm/ExaTrkX/trackml_processed/filter_processed/0_pt_cut_endcaps_connected_high_eff/train/1000", map_location="cpu")

In [206]:
event_file = '/global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/event000001000'
hits, particles, truth = trackml.dataset.load_event(
        event_file, parts=['hits', 'particles', 'truth'])

In [207]:
# Remove noise and assign track_id
hits = hits.merge(truth[['hit_id', 'weight', 'particle_id']], on='hit_id')
hits = hits[hits.particle_id != 0]
hids = sample.hid.cpu().numpy()

In [208]:
truth_graph = (sample.layerless_true_edges).cpu().numpy()
truth_graph = hids[truth_graph]
truth_graph = np.hstack([truth_graph, truth_graph[::-1]])

In [209]:
truth_graph_sp = sp.sparse.coo_matrix(([0.1]*truth_graph.shape[1], (truth_graph[0], truth_graph[1])), shape=(truth_graph.max()+1, truth_graph.max()+1))

In [210]:
clustering = DBSCAN(eps=0.1, metric="precomputed", min_samples=1).fit_predict(truth_graph_sp)

In [211]:
track_list = np.vstack([np.unique(truth_graph), clustering[np.unique(truth_graph)]])
track_list = pd.DataFrame(track_list.T)
track_list.columns = ["hit_id", "track_id"]
score_event(hits, track_list)

1.0000000080094655

### Filter Truth Level

In [212]:
sample = torch.load("/global/cscratch1/sd/danieltm/ExaTrkX/trackml_processed/filter_processed/0_pt_cut_endcaps_connected_high_eff/train/1000", map_location="cpu")

In [213]:
sample

Data(cell_data=[103305, 9], edge_index=[2, 1831684], event_file=/global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/event000001000, hid=[103305], layerless_true_edges=[2, 123429], layers=[103305], pid=[103305], x=[103305, 3], y=[1831684], y_pid=[1831684])

In [214]:
truth_graph = (sample.edge_index[:,sample.y_pid.bool()]).cpu().numpy()
truth_graph = hids[truth_graph]
truth_graph = np.hstack([truth_graph, truth_graph[::-1]])

In [215]:
truth_graph_sp = sp.sparse.coo_matrix(([0.1]*truth_graph.shape[1], (truth_graph[0], truth_graph[1])), shape=(truth_graph.max()+1, truth_graph.max()+1))

From clustering

In [216]:
clustering = DBSCAN(eps=0.1, metric="precomputed", min_samples=1).fit_predict(truth_graph_sp)

In [217]:
track_list = np.vstack([np.unique(truth_graph), clustering[np.unique(truth_graph)]])
track_list = pd.DataFrame(track_list.T)
track_list.columns = ["hit_id", "track_id"]
score_event(hits, track_list)

0.9349574380163364

From ground truth, IF they have a true edge in the filtered set

In [228]:
truth_graph = (sample.edge_index[:,sample.y_pid.bool()]).cpu().numpy()

In [229]:
track_list = np.vstack([hids[np.unique(truth_graph)], sample.pid[np.unique(truth_graph)]])
track_list = pd.DataFrame(track_list.T)
track_list.columns = ["hit_id", "track_id"]
score_event(hits, track_list)

0.9592538354264974

### Noise Robustness

In [148]:
sample = torch.load("/global/cscratch1/sd/danieltm/ExaTrkX/trackml_processed/filter_processed/0_pt_cut_endcaps_connected_high_eff/train/1000", map_location="cpu")

In [149]:
event_file = '/global/cscratch1/sd/danieltm/ExaTrkX/trackml/train_all/event000001000'
hits, particles, truth = trackml.dataset.load_event(
        event_file, parts=['hits', 'particles', 'truth'])

In [150]:
# Assign track_id
hits = hits.merge(truth[['hit_id', 'weight', 'particle_id']], on='hit_id')
hids = sample.hid.cpu().numpy()

In [151]:
truth_graph = (sample.layerless_true_edges).cpu().numpy()
truth_graph = hids[truth_graph]
truth_graph = np.hstack([truth_graph, truth_graph[::-1]])

In [152]:
truth_graph_sp = sp.sparse.coo_matrix(([0.1]*truth_graph.shape[1], (truth_graph[0], truth_graph[1])), shape=(truth_graph.max()+1, truth_graph.max()+1))

In [153]:
clustering = DBSCAN(eps=0.1, metric="precomputed", min_samples=1).fit_predict(truth_graph_sp)

In [154]:
track_list = np.vstack([np.unique(truth_graph), clustering[np.unique(truth_graph)]])
track_list = pd.DataFrame(track_list.T)
track_list.columns = ["hit_id", "track_id"]
score_event(hits, track_list)

1.0000000080094655

In [175]:
noise_idx = hits.hit_id[~hits.hit_id.isin(track_list.hit_id)]

In [170]:
random_noise = np.random.choice(hits.particle_id, (len(noise_idx),))

In [179]:
noise_idx = pd.DataFrame(noise_idx).assign(track_id=random_noise)

In [184]:
noise_joined = pd.concat([track_list, noise_idx])

In [185]:
score_event(hits, noise_joined)

1.0000000080094655

In [186]:
noise_joined

Unnamed: 0,hit_id,track_id
0,2,2
1,4,4
2,5,5
3,6,6
4,7,7
...,...,...
120908,120909,0
120912,120913,801644513243168768
120925,120926,589977461060534272
120928,120929,396324463789998080
