In [None]:
%load_ext autoreload
%autoreload 2

import os
import sys

import numpy as np
import pandas as pd
import torch
from numpy.random import uniform
from torch_geometric.loader import DataLoader
from torch_geometric.loader import DataLoader

In [None]:
from gnn_tracking.utils.plotting import EventPlotter

# Settings shared by the cells below.
n_evts, n_sectors = 20, 32   # events passed to pc_builder.process; sectors per event
savefig = False              # forwarded as `savefig=` to every plotting call
# Raw event directory. Defaults to the original cluster path; set the
# GNN_TRACKING_INDIR environment variable to run the notebook elsewhere.
indir = os.environ.get('GNN_TRACKING_INDIR', '/tigress/jdezoort/codalab/train_1')

# Display one complete event (evtid 21289).
event_plotter = EventPlotter(indir=indir)
event_plotter.plot_ep_rv_uv(evtid=21289, savefig=savefig,
                            filename='../plots/full_event.pdf')

In [None]:
from gnn_tracking.preprocessing.point_cloud_builder import PointCloudBuilder

# Build point clouds for each sector, using the pixel layers only.
# Reuse `indir` (the raw-data directory set in the event-display cell) rather
# than repeating the path literal, so the two cells cannot drift apart.
# NOTE(review): measurement_mode=False here, yet the next cell calls
# get_measurements() — confirm measurements are still populated in this mode.
pc_builder = PointCloudBuilder(
    indir=indir,
    outdir='../point_clouds/',
    n_sectors=n_sectors,
    pixel_only=True,
    redo=False,
    measurement_mode=False,
    sector_di=0,
    sector_ds=1.3,
    thld=0.9,
)
pc_builder.process(n=n_evts, verbose=False)

In [None]:
# `data_list` holds the built point clouds, each a PyG Data object.
# Despite the singular name, `point_cloud` refers to the whole list.
point_cloud = pc_builder.data_list
pc_measurements = pc_builder.get_measurements()
pc_measurements

In [None]:
from gnn_tracking.utils.plotting import PointCloudPlotter

# Show all sectors of event 21289, then sector 18 drawn with its overlapped
# ('extended') boundary, using the builder's sector_di / sector_ds settings.
pc_plotter = PointCloudPlotter(
    '../point_clouds',
    n_sectors=pc_builder.n_sectors,
)
pc_plotter.plot_ep_rv_uv_all_sectors(
    21289, savefig=savefig, filename='../plots/point_cloud.pdf'
)
boundary_args = (21289, 18, pc_builder.sector_di, pc_builder.sector_ds)
pc_plotter.plot_ep_rv_uv_with_boundary(
    *boundary_args, savefig=savefig, filename='../plots/point_cloud_extended.pdf'
)

In [None]:
# we can build graphs on the point clouds using geometric cuts
from gnn_tracking.graph_construction.graph_builder import GraphBuilder

# Geometric cut values handed to the graph builder.
graph_cuts = {'phi_slope_max': 0.0035, 'z0_max': 200, 'dR_max': 2.3}
graph_builder = GraphBuilder(
    indir='../point_clouds/',
    outdir='../graphs',
    redo=False,
    measurement_mode=False,
    **graph_cuts,
)
# presumably one point-cloud file per (event, sector) — hence n_evts * n_sectors
graph_builder.process(n=n_evts * n_sectors, verbose=True)
graph_builder.get_measurements()

In [None]:
from gnn_tracking.utils.plotting import GraphPlotter

# Plot the first constructed graph; the plotter shows the true and false edges
# produced by the graph builder.
graph_plotter = GraphPlotter(indir='../graphs')
graph = graph_builder.data_list[0]
print(graph)
evtid = graph.evtid.item()
s = graph.s.item()
graph_plotter.plot_ep_rz_uv(
    graph,
    sector=s,
    name=f'data{evtid}_s{s}',
    savefig=savefig,
    filename='../plots/graphs.pdf',
)

In [None]:
from gnn_tracking.models.track_condensation_networks import PointCloudTCN, GraphTCN
from gnn_tracking.training.graph_tcn_trainer import GraphTCNTrainer
from gnn_tracking.utils.losses import (
    EdgeWeightLoss,
    PotentialLoss,
    BackgroundLoss,
    ObjectLoss,
)

# Run on the GPU when PyTorch can see one, otherwise on the CPU.
use_cuda = torch.cuda.is_available()
device = torch.device('cuda') if use_cuda else torch.device('cpu')
print(f'Utilizing {device}')

# Read input feature widths off the first graph; all graphs are assumed to
# share them — TODO confirm for the full data set.
g = graph_builder.data_list[0]
node_indim, edge_indim = g.x.shape[1], g.edge_attr.shape[1]
# output dimension of the latent space
hc_outdim = 2

# Partition graphs into train / test / val splits (~70% / 20% / 10%).
# A seeded generator keeps the split identical across Restart & Run All.
split_seed = 42
train_cut, test_cut = 0.7, 0.9   # r <= train_cut -> train; <= test_cut -> test; else val
graphs = graph_builder.data_list
n_graphs = len(graphs)
rand_array = np.random.default_rng(split_seed).uniform(low=0, high=1, size=n_graphs)
# Build each boolean mask once instead of re-evaluating the full comparison
# for every element (which was O(n^2) in the number of graphs).
train_mask = rand_array <= train_cut
test_mask = (rand_array > train_cut) & (rand_array <= test_cut)
val_mask = rand_array > test_cut
train_graphs = [g for g, keep in zip(graphs, train_mask) if keep]
test_graphs = [g for g, keep in zip(graphs, test_mask) if keep]
val_graphs = [g for g, keep in zip(graphs, val_mask) if keep]
assert len(train_graphs) + len(test_graphs) + len(val_graphs) == n_graphs

# Wrap each split in a PyG DataLoader with one graph per batch.
# Only the training loader shuffles; test and val keep a fixed order.
train_loader = DataLoader(list(train_graphs), batch_size=1, shuffle=True, num_workers=2)
params = dict(batch_size=1, shuffle=False, num_workers=2)
test_loader, val_loader = (DataLoader(list(split), **params)
                           for split in (test_graphs, val_graphs))
loaders = dict(train=train_loader, test=test_loader, val=val_loader)
print('Loader sizes:', [(name, len(loader)) for name, loader in loaders.items()])

# Individual loss terms, keyed by name and configured for `device`.
q_min = 0.01
sb = 0.1
loss_functions = dict(
    edge=EdgeWeightLoss().to(device),
    potential=PotentialLoss(q_min=q_min, device=device),
    background=BackgroundLoss(device=device, sb=sb),
    object=ObjectLoss(device=device, mode='efficiency'),
)

# Relative weights of the loss terms. Per the original note, any term not
# listed here gets weight 1 (presumably applied inside the trainer — confirm).
loss_weights = dict(edge=5, potential_repulsive=10, background=10, object=25)

# Set up the model and trainer, and report the model size.
model = GraphTCN(node_indim, edge_indim, hc_outdim, hidden_dim=64)
# A list (not a one-shot `filter` iterator) can be iterated again later
# without silently yielding nothing.
model_parameters = [p for p in model.parameters() if p.requires_grad]
# `numel()` is an exact int per tensor; `np.prod(p.size())` returns float 1.0
# for 0-dim tensors, which would turn the total into a float.
n_params = sum(p.numel() for p in model_parameters)
print('number trainable params:', n_params)
trainer = GraphTCNTrainer(model=model, loaders=loaders, loss_functions=loss_functions,
                          loss_weights=loss_weights, device=device)
print(trainer.loss_functions)

In [None]:
import warnings

# Both debug settings are scoped to this training run via context managers,
# so they are restored afterwards instead of leaking into the kernel:
# - autograd anomaly detection helps locate the forward op behind a failing
#   backward pass, but slows execution; set the flag to False for fast runs.
# - all warnings are silenced only while training.
detect_anomaly = True
with torch.autograd.set_detect_anomaly(detect_anomaly), warnings.catch_warnings():
    warnings.simplefilter('ignore')
    train_result = trainer.train()
train_result