# Experiments for FD data with new architectures

In [1]:
import os

from gnn_tracking.utils.loading import TrackingDataModule

In [2]:
from functools import partial
import os

import torch

from gnn_tracking.training.tc import TCModule
from gnn_tracking.training.ml import MLModule
from gnn_tracking.models.graph_construction import MLGraphConstructionFromChkpt, GraphConstructionFCNN, NoiseClassifierModel, HeterogeneousFCNN
from gnn_tracking.graph_construction.k_scanner import GraphConstructionKNNScanner
from gnn_tracking.models.track_condensation_networks import GraphTCNForMLGCPipeline
from gnn_tracking.metrics.losses.metric_learning import GraphConstructionHingeEmbeddingLoss
from gnn_tracking.postprocessing.dbscanscanner import DBSCANHyperParamScanner
from pytorch_lightning import Trainer
from pytorch_lightning.callbacks import EarlyStopping, ModelCheckpoint, LearningRateMonitor
from gnn_tracking.utils.loading import TrackingDataModule
from gnn_tracking.training.callbacks import PrintValidationMetrics, ExpandWandbConfig
from gnn_tracking.utils.versioning import assert_version_geq

from torch_geometric.data import Data
from torch import nn

# Version guard imported from gnn_tracking.utils.versioning — presumably fails
# when the installed gnn_tracking is older than 23.12.0; confirm against the
# helper's documentation.
assert_version_geq("23.12.0")

## Configuring the data

In [3]:
# Root directory of the point-cloud dataset; the trailing slash matters
# because entry names are appended to it by plain string concatenation.
data_path_fd = "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/"

# Full path of every entry directly under the root, in lexicographic order
# of the entry names.
data_paths_fd = [f"{data_path_fd}{entry}" for entry in sorted(os.listdir(data_path_fd))]

In [4]:
# Training uses every sorted entry except the first and the last one;
# validation uses only the last entry, restricted with start=0 / stop=4.
# NOTE(review): the exact meaning of `sample_size`, `start` and `stop` is
# defined by TrackingDataModule — confirm against its documentation.
dm_fd = TrackingDataModule(
    train={
        "dirs": data_paths_fd[1:-1],
        "sample_size": 900,
    },
    val={
        "dirs": [data_paths_fd[-1]],
        "start": 0,
        "stop": 4,
    },
    identifier="point_clouds_v10",
)
dm_fd.setup(stage="fit")

[08:21:44] INFO: DataLoader will load 7743 graphs (out of 7743 available).
[08:21:44] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_8/data28999_s0.pt
[08:21:44] INFO: DataLoader will load 4 graphs (out of 1000 available).
[08:21:44] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v10/part_9/data29003_s0.pt


In [7]:
# NOTE(review): this exact constructor call (same arguments) is repeated in the
# first cell of the "GC-Phase" section below, which rebinds `model` before it
# is passed to MLModule — one of the two cells is redundant.
model = HeterogeneousFCNN(14, 256, 8, 6, 14, 256, 8, 6)

# GC-Phase (graph construction)

In [None]:
# Network that is handed to MLModule in the next cell.
# The eight positional arguments repeat the group (14, 256, 8, 6) twice —
# presumably one set of sizes per sub-network of the heterogeneous model;
# TODO confirm meaning and order against the HeterogeneousFCNN signature.
model = HeterogeneousFCNN(14, 256, 8, 6, 14, 256, 8, 6)

In [None]:
# Loss used to train the embedding for graph construction.
hinge_loss_gc = GraphConstructionHingeEmbeddingLoss(
    lw_repulsive=0.06,
    max_num_neighbors=256,
)

# KNN scanner configured with k = 1, 2, ..., 10.
knn_scanner_gc = GraphConstructionKNNScanner(ks=list(range(1, 11)))

# Lightning module bundling the network, its loss, the optimizer factory
# (Adam with the learning rate bound via `partial`) and the scanner.
lmodel = MLModule(
    model=model,
    loss_fct=hinge_loss_gc,
    optimizer=partial(torch.optim.Adam, lr=7 * 1e-4),
    gc_scanner=knn_scanner_gc,
)



In [None]:
from pytorch_lightning.loggers import TensorBoardLogger, WandbLogger
from wandb_osh.lightning_hooks import TriggerWandbSyncLightningCallback
from gnn_tracking.utils.nomenclature import random_trial_name

# A single randomly generated trial name is used as the `version` of both
# loggers, so the TensorBoard run and the W&B run of this training share it.
name = random_trial_name()

# TensorBoard logger rooted at the current working directory.
tb_logger_gc = TensorBoardLogger(".", version=name)

# W&B logger in offline mode (offline=True).
wandb_logger_gc = WandbLogger(
    project="aryaman-gnn-experiments",
    group="noise-classification-gc",
    version=name,
    offline=True,
    tags=["noise-classification-gc-with-scanner"],
)

In [None]:
# Callbacks and loggers are collected first so the Trainer call stays short.
callbacks_gc = [
    TriggerWandbSyncLightningCallback(),
    PrintValidationMetrics(),
]
loggers_gc = [wandb_logger_gc, tb_logger_gc]

# Train on a GPU for at most 100 epochs, logging on every step.
trainer_gc = Trainer(
    max_epochs=100,
    accelerator="gpu",
    log_every_n_steps=1,
    callbacks=callbacks_gc,
    logger=loggers_gc,
)

# Start training of the graph-construction module on the data module above.
trainer_gc.fit(model=lmodel, datamodule=dm_fd)