A notebook to run the entire pipeline using the ground-truth noise classifier during the graph-construction phase.

In [1]:
from functools import partial
from pathlib import Path

import torch

from gnn_tracking.training.tc import TCModule
from gnn_tracking.models.graph_construction import MLGraphConstructionFromChkpt
from gnn_tracking.models.track_condensation_networks import GraphTCNForMLGCPipeline
from gnn_tracking.metrics.losses.metric_learning import GraphConstructionHingeEmbeddingLoss
from gnn_tracking.postprocessing.dbscanscanner import DBSCANHyperParamScanner
from pytorch_lightning import Trainer
from gnn_tracking.utils.loading import TrackingDataModule
from gnn_tracking.training.callbacks import PrintValidationMetrics
from gnn_tracking.utils.versioning import assert_version_geq

from torch_geometric.data import Data
from torch import nn

# Guard against running this notebook with an older gnn_tracking release;
# assert_version_geq presumably raises if the installed version is below this.
MIN_GNN_TRACKING_VERSION = "23.12.0"
assert_version_geq(MIN_GNN_TRACKING_VERSION)

# Configure the data

In [2]:
# Directory holding the v8 point-cloud graphs (part 1). Defaults to the shared
# cluster copy; set the GNN_TRACKING_DATA_DIR environment variable to use a
# different location, e.g. the repository's small sample under
# test-data/data/point_clouds/v8.
import os

data_dir = Path(
    os.environ.get(
        "GNN_TRACKING_DATA_DIR",
        "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_1",
    )
)
if not data_dir.is_dir():
    # Explicit error instead of `assert`: asserts are stripped under `python -O`
    # and would not say which path was tried.
    raise NotADirectoryError(f"Data directory not found: {data_dir}")

In [3]:
# Training and validation read disjoint index ranges of the same directory:
# training takes everything before index 1, validation takes [1, 2).
# Presumably one graph file each -- confirm against TrackingDataModule.
_train_cfg = {"dirs": [data_dir], "stop": 1}
_val_cfg = {"dirs": [data_dir], "start": 1, "stop": 2}
dm = TrackingDataModule(
    train=_train_cfg,
    val=_val_cfg,
    identifier="point_clouds_v8",
)

In [None]:
# Track-condensation GNN applied to the graphs built by the metric-learning
# graph-construction step (`preproc` below). node_indim/edge_indim presumably
# have to match the node/edge features that step produces -- TODO confirm.
model = GraphTCNForMLGCPipeline(
    node_indim=22,
    edge_indim=44,
    h_dim=192,
    e_dim=192,
    hidden_dim=192,
    h_outdim=24,
    L_hc=5,
    alpha_latent=0.5,
    # Keyword is `n_embedding_coords` in the GraphTCNForMLGCPipeline signature;
    # the former spelling `n_embedding_cords` is not a parameter of it.
    n_embedding_coords=8,
)

In [None]:
# Graph construction driven by a trained metric-learning model loaded from a
# checkpoint. Set `ml_chkpt_path` to that checkpoint before running.
ml_chkpt_path = ""
if not ml_chkpt_path:
    # Fail here with an actionable message instead of an obscure error from
    # checkpoint loading with an empty path.
    raise ValueError(
        "ml_chkpt_path is empty: set it to the metric-learning checkpoint "
        "used for graph construction."
    )
preproc = MLGraphConstructionFromChkpt(
    ml_chkpt_path=ml_chkpt_path,
    max_num_neighbors=10,
    max_radius=1,
    use_embedding_features=True,
    build_edge_features=True,
)

In [None]:
# Hyperparameters for the hinge-embedding loss handed to TCModule below.
# NOTE(review): this class is imported from
# gnn_tracking.metrics.losses.metric_learning (graph-construction stage);
# confirm TCModule is meant to receive it rather than a condensation loss.
_hinge_kwargs = {
    "lw_repulsive": 0.05,
    "pt_thld": 0.9,
    "max_num_neighbors": 256,
    "p_attr": 2,
    "p_rep": 2,
    "r_emb": 1,
}
loss_fct = GraphConstructionHingeEmbeddingLoss(**_hinge_kwargs)

In [None]:
# DBSCAN hyperparameter scanner passed to TCModule: 60 trials on 6 jobs;
# keep_best=30 presumably bounds how many top trials are retained.
cluster_scanner = DBSCANHyperParamScanner(n_trials=60, n_jobs=6, keep_best=30)

# Optimizer and LR scheduler are given as factories (partials), not instances:
# torch.optim.Adam needs the parameters and LinearLR needs the optimizer as
# their first argument, neither of which exists yet here. Presumably TCModule
# supplies them when it sets up training -- confirm against TCModule.
#
# lr is 8e-4. The previous `8*10^-4` used `^` (bitwise XOR in Python) and
# evaluated to -84.
optimizer = partial(torch.optim.Adam, lr=8e-4)

# Decay the LR linearly from 1x to 0.1x of the base LR over the first 50
# scheduler steps; LinearLR then holds it at 0.1x.
scheduler = partial(
    torch.optim.lr_scheduler.LinearLR,
    start_factor=1,
    end_factor=0.1,
    total_iters=50,
)

In [None]:
# Bundle the network, graph-construction preprocessing, loss, cluster scanner
# and optimizer/scheduler into the Lightning module trained below.
_tc_components = dict(
    model=model,
    preproc=preproc,
    loss_fct=loss_fct,
    cluster_scanner=cluster_scanner,
    optimizer=optimizer,
    scheduler=scheduler,
)
lmodel = TCModule(**_tc_components)

# Training

In [None]:
# Lightning trainer: up to 1000 epochs on GPU, logging every step, with the
# PrintValidationMetrics callback attached.
_callbacks = [PrintValidationMetrics()]
trainer = Trainer(
    accelerator="gpu",
    max_epochs=1000,
    log_every_n_steps=1,
    callbacks=_callbacks,
)
trainer.fit(datamodule=dm, model=lmodel)