## Training an XGBoost model for noise classification 

In [1]:
from functools import partial
import os

import torch
import xgboost

from gnn_tracking.training.tc import TCModule
from gnn_tracking.training.ml import MLModule
from gnn_tracking.training.classification import NodeClassifierModule
from gnn_tracking.models.graph_construction import MLGraphConstructionFromChkpt, GraphConstructionFCNN
from gnn_tracking.models.track_condensation_networks import GraphTCNForMLGCPipeline
from gnn_tracking.graph_construction.k_scanner import GraphConstructionKNNScanner
from gnn_tracking.metrics.losses.metric_learning import GraphConstructionHingeEmbeddingLoss
from gnn_tracking.metrics.losses.classification import CEL
from gnn_tracking.postprocessing.dbscanscanner import DBSCANHyperParamScanner
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from gnn_tracking.utils.loading import TrackingDataModule
from gnn_tracking.training.callbacks import PrintValidationMetrics, ExpandWandbConfig
from gnn_tracking.utils.versioning import assert_version_geq

from torch_geometric.data import Data
from torch import nn

assert_version_geq("23.12.0")

In [2]:
data_path_pixel = "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/"
data_paths_pixel = os.listdir(data_path_pixel)
data_paths_pixel.sort()
data_paths_pixel = list(map(lambda x: data_path_pixel + x, data_paths_pixel))

In [3]:
dm = TrackingDataModule(
    train=dict(
        dirs=data_paths_pixel[1:-1],
        sample_size=900
    ),
    val=dict(
        dirs=[data_paths_pixel[-1]],
        start=0,
        stop=4,
    ),
    identifier="point_clouds_v8",
)
dm.setup(stage='fit')

[32m[06:45:32] INFO: DataLoader will load 7743 graphs (out of 7743 available).[0m
[36m[06:45:32] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_8/data28999_s0.pt[0m
[32m[06:45:32] INFO: DataLoader will load 4 graphs (out of 1000 available).[0m
[36m[06:45:32] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_9/data29003_s0.pt[0m


In [4]:
model = GraphConstructionFCNN(in_dim=14, out_dim=2, depth=6, hidden_dim=256, classification=True)

In [5]:
model(dm.datasets["train"][0])["H"].shape

  out = softmax(out)


torch.Size([66114, 2])

In [6]:
torch.vstack([dm.datasets["train"][0].particle_id == 0, dm.datasets["train"][0].particle_id != 0])

tensor([[False, False, False,  ..., False,  True, False],
        [ True,  True,  True,  ...,  True, False,  True]])

In [7]:
torch.tensor(list(zip(dm.datasets["train"][0].particle_id == 0, dm.datasets["train"][0].particle_id != 0))).type(torch.LongTensor)

tensor([[0, 1],
        [0, 1],
        [0, 1],
        ...,
        [0, 1],
        [1, 0],
        [0, 1]])

In [7]:
CEL(
        weight=torch.tensor([0.9348, 0.0652]).to('cuda')
)

  self._loss_fct = CrossEntropyLoss(weight=torch.tensor(weight))


CEL(
  (_loss_fct): CrossEntropyLoss()
)

In [7]:
lmodel = NodeClassifierModule(
    model=model,
    loss_fct=CEL(
        weight=torch.tensor([0.9348, 0.0652]).to('cuda')
    ),
    optimizer=partial(torch.optim.Adam, lr=1*1e-3),
)

  self._loss_fct = CrossEntropyLoss(weight=torch.tensor(weight))


In [22]:
lmodel = NodeClassifierModule.load_from_checkpoint("/home/aj2239/aryaman-gnn-tracking-experiments/notebooks/lightning_logs/version_55064933/checkpoints/epoch=1-step=1800.ckpt")

[36m[07:35:22] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction[0m
[36m[07:35:22] DEBUG: Getting class CEL from module gnn_tracking.metrics.losses.classification[0m
  self._loss_fct = CrossEntropyLoss(weight=torch.tensor(weight))


In [23]:
trainer = Trainer(
    max_epochs=5,
    accelerator="gpu",
    log_every_n_steps=1,
    callbacks=[
        PrintValidationMetrics(),
    ],
)
trainer.fit(model=lmodel, datamodule=dm)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[32m[07:36:14] INFO: DataLoader will load 7743 graphs (out of 7743 available).[0m
[36m[07:36:14] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_8/data28999_s0.pt[0m
[32m[07:36:14] INFO: DataLoader will load 4 graphs (out of 1000 available).[0m
[36m[07:36:14] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_9/data29003_s0.pt[0m
/scratch/gpfs/aj2239/micromamba/envs/gnn/lib/python3.11/site-packages/pytorch_lightning/callbacks/model_checkpoint.py:639: Checkpoint directory /home/aj2239/aryaman-gnn-tracking-experiments/not

Sanity Checking: |          | 0/? [00:00<?, ?it/s]

/scratch/gpfs/aj2239/micromamba/envs/gnn/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.
  out = softmax(out)
/scratch/gpfs/aj2239/micromamba/envs/gnn/lib/python3.11/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=3` in the `DataLoader` to improve performance.


Training: |          | 0/? [00:00<?, ?it/s]

Validation: |          | 0/? [00:00<?, ?it/s]

[3m         Validation epoch=0          [0m
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┓
┃[1m [0m[1mMetric         [0m[1m [0m┃[1m [0m[1m  Value[0m[1m [0m┃[1m [0m[1mError[0m[1m [0m┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━┩
│ fp_pt           │ 0.00000 │   nan │
│ total_train     │ 0.06516 │   nan │
│ total_val       │ 0.06598 │   nan │
│ total_val_epoch │ 0.06598 │   nan │
└─────────────────┴─────────┴───────┘



Validation: |          | 0/? [00:00<?, ?it/s]

[3m         Validation epoch=1          [0m
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┓
┃[1m [0m[1mMetric         [0m[1m [0m┃[1m [0m[1m  Value[0m[1m [0m┃[1m [0m[1mError[0m[1m [0m┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━┩
│ fp_pt           │ 0.00000 │   nan │
│ total_train     │ 0.05996 │   nan │
│ total_val       │ 0.06308 │   nan │
│ total_val_epoch │ 0.06308 │   nan │
└─────────────────┴─────────┴───────┘



Validation: |          | 0/? [00:00<?, ?it/s]

[3m         Validation epoch=2          [0m
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┓
┃[1m [0m[1mMetric         [0m[1m [0m┃[1m [0m[1m  Value[0m[1m [0m┃[1m [0m[1mError[0m[1m [0m┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━┩
│ fp_pt           │ 0.00000 │   nan │
│ total_train     │ 0.06132 │   nan │
│ total_val       │ 0.06132 │   nan │
│ total_val_epoch │ 0.06132 │   nan │
└─────────────────┴─────────┴───────┘



Validation: |          | 0/? [00:00<?, ?it/s]

[3m         Validation epoch=3          [0m
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┓
┃[1m [0m[1mMetric         [0m[1m [0m┃[1m [0m[1m  Value[0m[1m [0m┃[1m [0m[1mError[0m[1m [0m┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━┩
│ fp_pt           │ 0.00000 │   nan │
│ total_train     │ 0.06190 │   nan │
│ total_val       │ 0.06040 │   nan │
│ total_val_epoch │ 0.06040 │   nan │
└─────────────────┴─────────┴───────┘



Validation: |          | 0/? [00:00<?, ?it/s]

`Trainer.fit` stopped: `max_epochs=5` reached.


[3m         Validation epoch=4          [0m
┏━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━┓
┃[1m [0m[1mMetric         [0m[1m [0m┃[1m [0m[1m  Value[0m[1m [0m┃[1m [0m[1mError[0m[1m [0m┃
┡━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━┩
│ fp_pt           │ 0.00000 │   nan │
│ total_train     │ 0.06239 │   nan │
│ total_val       │ 0.05839 │   nan │
│ total_val_epoch │ 0.05839 │   nan │
└─────────────────┴─────────┴───────┘



In [27]:
from sklearn.metrics import jaccard_score, accuracy_score, roc_curve, roc_auc_score

data = dm.datasets["val"][0]
y = torch.tensor(list(zip(data.particle_id == 0, data.particle_id != 0))).type(torch.LongTensor)
roc_auc_score(model(data)["H"].detach() > 0.5, y.detach())

  out = softmax(out)


0.5588074027539947

In [18]:
model(data)['H'] > 0.5

tensor([[False,  True],
        [False,  True],
        [False,  True],
        ...,
        [ True, False],
        [ True, False],
        [ True, False]])

In [16]:
y.type(torch.BoolTensor)

tensor([[False,  True],
        [False,  True],
        [False,  True],
        ...,
        [False,  True],
        [False,  True],
        [False,  True]])

In [10]:
model = NodeClassifierModule.load_from_checkpoint("/home/aj2239/aryaman-gnn-tracking-experiments/notebooks/lightning_logs/version_55045227/checkpoints/epoch=1-step=1800.ckpt")

[36m[10:27:54] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction[0m
[36m[10:27:54] DEBUG: Getting class CEL from module gnn_tracking.metrics.losses.classification[0m
  self._loss_fct = CrossEntropyLoss(weight=torch.tensor(weight))


In [20]:
model = model.to('cuda')
(model(dm.datasets['val'][0].to('cuda'))["H"][:, 0] == 1).shape
torch.sum(model(dm.datasets['val'][0].to('cuda'))["H"][:, 0] == 1)

torch.Size([59357])

In [22]:
torch.sum(dm.datasets['val'][0].particle_id == 0)

tensor(3650)

In [None]:
total_particles = list(map(lambda x: x.num_nodes, dm_list))

In [None]:
total_noise = torch.tensor(total_noise)
total_particles = torch.tensor(total_particles)

In [None]:
torch.sum(total_noise)/torch.sum(total_particles)

In [None]:
model = torch.load("/home/aj2239/aryaman-gnn-tracking-experiments/notebooks/lightning_logs/version_55027043/checkpoints/epoch=99-step=90000.ckpt")