# Experiments for FD data with new architectures

In [1]:
import os

from gnn_tracking.utils.loading import TrackingDataModule

In [2]:
from functools import partial
import os

import torch

from gnn_tracking.training.tc import TCModule
from gnn_tracking.training.ml import MLModule
from gnn_tracking.models.graph_construction import MLGraphConstructionFromChkpt, GraphConstructionFCNN, NoiseClassifierModel, HeterogeneousFCNN
from gnn_tracking.models.track_condensation_networks import GraphTCNForMLGCPipeline
from gnn_tracking.metrics.losses.metric_learning import GraphConstructionHingeEmbeddingLoss
from gnn_tracking.postprocessing.dbscanscanner import DBSCANHyperParamScanner
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from gnn_tracking.utils.loading import TrackingDataModule
from gnn_tracking.training.callbacks import PrintValidationMetrics, ExpandWandbConfig
from gnn_tracking.utils.versioning import assert_version_geq

from torch_geometric.data import Data
from torch import nn

# Version guard imported from gnn_tracking.utils.versioning.
# NOTE(review): presumably raises unless the installed gnn_tracking is >= 23.12.0 — confirm.
assert_version_geq("23.12.0")

## Configuring the data

In [3]:
# Root directory of the point-cloud dataset.
data_path_fd = "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/"
# Full paths of every entry directly under the root, ordered by entry name.
data_paths_fd = [
    os.path.join(data_path_fd, entry) for entry in sorted(os.listdir(data_path_fd))
]

In [4]:
# Training uses every listed directory except the first and the last one;
# validation uses graphs 0..3 (start=0, stop=4) of the last directory.
dm_fd = TrackingDataModule(
    train={
        "dirs": data_paths_fd[1:-1],
        # NOTE(review): presumably the number of graphs sampled per epoch — confirm.
        "sample_size": 900,
    },
    val={
        "dirs": [data_paths_fd[-1]],
        "start": 0,
        "stop": 4,
    },
    identifier="point_clouds_v10",
)
# Build the datasets for the "fit" stage so they can be inspected below.
dm_fd.setup(stage="fit")

[32m[08:21:44] INFO: DataLoader will load 7743 graphs (out of 7743 available).[0m
[36m[08:21:44] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_8/data28999_s0.pt[0m
[32m[08:21:44] INFO: DataLoader will load 4 graphs (out of 1000 available).[0m
[36m[08:21:44] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_9/data29003_s0.pt[0m


In [14]:
# Display the training dataset object; its repr shows the number of graphs it holds.
dm_fd.datasets['train']

TrackingDataset(7743)

In [5]:
# Pull a single training graph (index 324) to use as sample input for the cells below.
d = dm_fd.datasets['train'][324]

In [17]:
# Show the attributes stored on the sample graph and their sizes.
d

Data(x=[120328, 14], edge_index=[2, 537473], y=[0], layer=[120328], particle_id=[120328], pt=[120328], reconstructable=[120328], sector=[120328], eta=[120328], n_hits=[120328], n_layers_hit=[120328])

In [25]:
# Distinct values of `d.layer` together with how many nodes carry each value.
d.layer.unique(return_counts=True)

(tensor([ 0,  1,  2,  3,  4,  5,  6,  7,  8,  9, 10, 11, 12, 13, 14, 15, 16, 17,
         18, 19, 20, 21, 22, 23, 24, 25, 26, 27, 28, 29, 30, 31, 32, 33, 34, 35,
         36, 37, 38, 39, 40, 41, 42, 43, 44, 45, 46, 47]),
 tensor([1583, 1898, 2264, 2633, 2992, 3074, 3150, 8954, 7387, 6580, 5991, 3063,
         3006, 2896, 2569, 2226, 1816, 1529, 1148, 1186, 1241, 1291, 1339, 1388,
         7081, 6774, 6179, 5726, 1380, 1400, 1315, 1258, 1193, 1184,  422,  408,
          399,  407,  499,  577, 5188, 5063,  528,  429,  430,  407,  432,  445]))

In [6]:
# Nodes whose layer index is in 0..17 are routed to the "pixel" network, all
# other nodes to the "strip" network.
# NOTE(review): the cut at 18 is taken over unchanged — confirm it matches the detector layout.
n_pixel_layers = 18
pixel_mask = torch.isin(d.layer, torch.arange(n_pixel_layers))

# Two independent embedding networks with identical hyperparameters.
model_pixel = GraphConstructionFCNN(in_dim=14, out_dim=8, depth=6, hidden_dim=256)
model_strip = GraphConstructionFCNN(in_dim=14, out_dim=8, depth=6, hidden_dim=256)

# Split the sample graph into the two node subsets.
d_pixel = d.subgraph(pixel_mask)
d_strip = d.subgraph(~pixel_mask)

# Embed each subset with its own network; "H" is the key read from the model output.
embed_pixel = model_pixel(d_pixel)["H"]
embed_strip = model_strip(d_strip)["H"]

# Reassemble one embedding per node, in the node order of `d`.
# Plainly stacking [pixel; strip] would only line up with `d` if every pixel
# node came before every strip node, so the rows are written back through the
# mask instead. Shape and dtype are the same as the stacked tensor would have.
embed = embed_pixel.new_zeros((pixel_mask.numel(), embed_pixel.shape[1]))
embed[pixel_mask] = embed_pixel
embed[~pixel_mask] = embed_strip

In [10]:
# Scatter the two partial embeddings into a fresh tensor shaped like `embed`:
# rows selected by `pixel_mask` receive the pixel embeddings, the remaining
# rows receive the strip embeddings.
empty = torch.zeros_like(embed)
empty[pixel_mask], empty[~pixel_mask] = embed_pixel, embed_strip

In [7]:
# The constructor receives the same four positional values twice. Going by the
# hparams printed further down, 14 / 256 / 8 / 6 are in_dim / hidden_dim /
# out_dim / depth, identical for the pixel and the strip network.
fcnn_dims = (14, 256, 8, 6)
model = HeterogeneousFCNN(*fcnn_dims, *fcnn_dims)

In [8]:
# Run the heterogeneous model on the sample graph and check the shape of its "H" output.
model(d)["H"].shape

torch.Size([120328, 8])

In [9]:
# Inspect the hyperparameters recorded on the model.
model.hparams

"alpha_pix":        0.6
"alpha_strip":      0.6
"depth_pix":        6
"depth_strip":      6
"hidden_dim_pix":   256
"hidden_dim_strip": 256
"in_dim_pix":       14
"in_dim_strip":     14
"out_dim_pix":      8
"out_dim_strip":    8

# GC-Phase