In [None]:
import numpy as np
from sklearn import metrics

import torch
from numpy.random import uniform
from torch_geometric.loader import DataLoader
from gnn_tracking.preprocessing.point_cloud_builder import PointCloudBuilder
from gnn_tracking.utils.plotting import GraphPlotter
from gnn_tracking.utils.plotting import PointCloudPlotter
from gnn_tracking.models.track_condensation_networks import GraphTCN
from pathlib import Path
from gnn_tracking.postprocessing.dbscanscanner import DBSCANHyperParamScanner

%load_ext autoreload
%autoreload 2

In [None]:
from gnn_tracking.utils.plotting import EventPlotter

# 10 events x 64 sectors -> n_evts * n_sectors = 640 graphs
n_evts = 10
n_sectors = 64
# raw input directory for the point-cloud builder
indir = "/tigress/jdezoort/codalab/train_1"
# indir='/home/kl5675/Documents/22/git_sync/gnn_tracking/src/gnn_tracking/test_data'
# event_plotter = EventPlotter(indir=indir)
# event_plotter.plot_ep_rv_uv(evtid=21289)

In [None]:
# build point clouds for each sector in the pixel layers only
# pc_builder = PointCloudBuilder(indir=indir, outdir=str(Path("~/data/gnn_tracking/point_clouds").expanduser()),
#                                n_sectors=n_sectors, pixel_only=True, redo=False, measurement_mode=False, thld=0.9)
# pc_builder.process(n=10, verbose=False)

In [None]:
# each point cloud is a PyG Data object
# point_cloud = pc_builder.data_list[0]
# good = ((point_cloud.sector>-1) & (point_cloud.particle_id>0) &
#         (point_cloud.pt > 0.5))

In [None]:
# point_cloud

In [None]:
# visualize the sectors in each event and an overlapped ('extended') sector
# pc_plotter = PointCloudPlotter(str(Path("~/data/gnn_tracking/point_clouds").expanduser()),
#                                n_sectors=pc_builder.n_sectors)
# pc_plotter.plot_ep_rv_uv_all_sectors(21289)
# pc_plotter.plot_ep_rv_uv_with_boundary(21289, 18,
#                                        pc_builder.sector_di,
#                                        pc_builder.sector_ds)

In [None]:
# Create the graph output directory. Idempotent (no error if it already exists,
# so Restart & Run All works) and user-independent, matching the
# ~/data/gnn_tracking paths used by the builder cells.
Path("~/data/gnn_tracking/graphs").expanduser().mkdir(parents=True, exist_ok=True)

In [None]:
# we can build graphs on the point clouds using geometric cuts

from gnn_tracking.graph_construction.graph_builder import GraphBuilder

point_cloud_dir = Path("~/data/gnn_tracking/point_clouds").expanduser()
graph_dir = Path("~/data/gnn_tracking/graphs").expanduser()
# build (or, with redo=False, presumably reuse) graphs for every point cloud
graph_builder = GraphBuilder(str(point_cloud_dir), str(graph_dir), redo=False)
graph_builder.process(n=None)

In [None]:
! ls  /home/kl5675/data/gnn_tracking/

In [None]:
# the graph plotter shows the true and false edges constructed by the builder

# graph_plotter = GraphPlotter()
# graph = graph_builder.data_list[0]
# print(graph)
# evtid, s = graph.evtid.item(), graph.s.item()

# takes a minute to run, but cool visual!
# graph_plotter.plot_rz(graph_builder.data_list[0],
#          f'event{evtid}_s{s}')

In [None]:
from gnn_tracking.training.graph_tcn_trainer import GraphTCNTrainer

# use cuda (gpu) if possible, otherwise fallback to cpu
use_cuda = torch.cuda.is_available()
device = torch.device("cuda" if use_cuda else "cpu")
print(f"Utilizing {device}")

# use reference graph to get relevant dimensions
g = graph_builder.data_list[0]
node_indim = g.x.shape[1]
edge_indim = g.edge_attr.shape[1]
hc_outdim = 2  # output dim of latent space

# Partition graphs into train / test / val splits (expected fractions ~70% / 20% / 10%).
# A dedicated, seeded generator keeps the split reproducible across kernel
# restarts; the global seeds are only set in a later cell, so relying on the
# global numpy RNG here made the split differ on every run.
SPLIT_SEED = 0
graphs = graph_builder.data_list
n_graphs = len(graphs)
rand_array = np.random.default_rng(SPLIT_SEED).uniform(low=0, high=1, size=n_graphs)
# Compute each mask once (previously the full comparison was re-evaluated per graph).
train_mask = rand_array <= 0.7
test_mask = (rand_array > 0.7) & (rand_array <= 0.9)
val_mask = rand_array > 0.9
train_graphs = [graph for graph, keep in zip(graphs, train_mask) if keep]
test_graphs = [graph for graph, keep in zip(graphs, test_mask) if keep]
val_graphs = [graph for graph, keep in zip(graphs, val_mask) if keep]

# build graph loaders
params = {"batch_size": 1, "shuffle": True, "num_workers": 1}

In [None]:
# training loader over a copy of the train split, configured by `params`
train_loader = DataLoader(
    list(train_graphs),
    **params,
)

In [None]:
# Evaluation loaders: no shuffling, batch size 2. Use a dedicated name so the
# training `params` dict is not overwritten. Rebinding `params` here meant that
# re-running the train-loader cell would silently pick up these eval settings.
eval_params = {"batch_size": 2, "shuffle": False, "num_workers": 2}
test_loader = DataLoader(list(test_graphs), **eval_params)
val_loader = DataLoader(list(val_graphs), **eval_params)
loaders = {"train": train_loader, "test": test_loader, "val": val_loader}
print("Loader sizes:", [(k, len(v)) for k, v in loaders.items()])

In [None]:
import random

import numpy as np

# Seed torch, numpy's global RNG, and stdlib random (in that order).
# NOTE(review): the train/test/val split is drawn in an earlier cell, before
# these seeds are set.
SEED = 0
torch.manual_seed(SEED)
np.random.seed(SEED)
random.seed(SEED)

In [None]:
# additional clustering scores reported alongside the guiding metric
extra_metrics = dict(
    homogeneity=metrics.homogeneity_score,
    completeness=metrics.completeness_score,
)

In [None]:
def clustering(graphs, truth, sectors, epoch, *, min_epoch=5, n_trials=100, n_jobs=1):
    """Run a DBSCAN hyperparameter scan, skipping early epochs.

    Passed to the trainer via ``cluster_functions``. The scan is guided by
    ``metrics.v_measure_score`` and also reports the scores in ``extra_metrics``.

    Args:
        graphs: forwarded to ``DBSCANHyperParamScanner(graphs=...)``.
        truth: forwarded to ``DBSCANHyperParamScanner(truth=...)``.
        sectors: forwarded to ``DBSCANHyperParamScanner(sectors=...)``.
        epoch: current epoch; no scan is run while ``epoch < min_epoch``.
        min_epoch: first epoch at which the scan runs (keyword-only, default 5).
        n_trials: number of scan trials (keyword-only, default 100).
        n_jobs: number of parallel jobs for the scan (keyword-only, default 1).

    Returns:
        ``None`` if ``epoch < min_epoch``, otherwise the result of
        ``DBSCANHyperParamScanner.scan``.
    """
    if epoch < min_epoch:
        # skip the scan in early epochs
        return None
    dbss = DBSCANHyperParamScanner(
        graphs=graphs,
        truth=truth,
        sectors=sectors,
        guiding_metric=metrics.v_measure_score,
        extra_metrics=extra_metrics,
    )
    return dbss.scan(n_jobs=n_jobs, n_trials=n_trials)

In [None]:
from gnn_tracking.training.tcn_trainer import TCNTrainer
from gnn_tracking.utils.losses import EdgeWeightBCELoss, PotentialLoss, BackgroundLoss
import optuna

# only show optuna messages at WARNING level and above
optuna.logging.set_verbosity(optuna.logging.WARNING)

# hyperparameters for PotentialLoss (q_min) and BackgroundLoss (sb)
q_min = 0.01
sb = 0.1
loss_functions = dict(
    edge=EdgeWeightBCELoss().to(device),
    potential=PotentialLoss(q_min=q_min, device=device),
    background=BackgroundLoss(device=device, sb=sb),
    # object=ObjectLoss(device=device, mode='efficiency'),
)

# per-component loss weights; anything not listed here will be 1
loss_weights = dict(
    edge=5,
    potential_attractive=10,
    potential_repulsive=1,
    background=1,
    # object=1/250000,
)

# set up the model and report its number of trainable parameters
model = GraphTCN(node_indim, edge_indim, hc_outdim, hidden_dim=64)
# Tensor.numel() returns a plain int (1 for scalar parameters). The previous
# np.prod(p.size()) yields the float 1.0 for an empty shape, which would turn
# the total into a float.
n_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
print("number trainable params:", n_params)

In [None]:
# Restore model weights from the saved checkpoint. map_location remaps stored
# tensors onto the current device, so a checkpoint written on a GPU machine also
# loads on a CPU-only one.
# NOTE: torch.load unpickles the file; only load checkpoints from trusted sources.
checkpoint = torch.load(
    Path("~/data/gnn_tracking/model.pt").expanduser(), map_location=device
)
model.load_state_dict(checkpoint["model_state_dict"])

In [None]:
# clustering hooks handed to the trainer
test_fcts = dict(dbscan=clustering)

trainer = TCNTrainer(
    model=model,
    loaders=loaders,
    device=device,
    loss_functions=loss_functions,
    loss_weights=loss_weights,
    lr=0.0001,
    cluster_functions=test_fcts,
)
print(trainer.loss_functions)

In [None]:
# restore optimizer state and the trainer's epoch counter from the checkpoint
optimizer_state = checkpoint["optimizer_state_dict"]
trainer.optimizer.load_state_dict(optimizer_state)
trainer._epoch = checkpoint["epoch"]

In [None]:
# NOTE(review): `ns` is not referenced elsewhere in the visible notebook and its
# meaning is not stated; confirm its purpose or remove it.
ns = [
    4558, 4265, 4532, 4596, 4314, 4222,
    4888, 4640, 4797, 4565, 4883,
]

In [None]:
import warnings

# Run the test step with warnings silenced for this call only. A bare
# warnings.filterwarnings("ignore") mutates the process-wide filter list and
# would hide warnings from every later cell in the kernel session.
with warnings.catch_warnings():
    warnings.simplefilter("ignore")
    test_results = trainer.test_step()
test_results

In [None]:
# Save model and optimizer state in the layout the loading cells read.
# Store the trainer's epoch counter (restored from the checkpoint earlier)
# instead of a hard-coded value, so the saved epoch stays consistent on reload.
torch.save(
    {
        "epoch": trainer._epoch,
        "model_state_dict": model.state_dict(),
        "optimizer_state_dict": trainer.optimizer.state_dict(),
    },
    Path("~/data/gnn_tracking/model.pt").expanduser(),
)

In [None]:
# print the unique `sector` values contained in each validation batch
with torch.no_grad():
    for data in loaders["val"]:
        print(data.sector.unique())