In [1]:
%load_ext autoreload
%autoreload 2

from pytorch_lightning.core.mixins import HyperparametersMixin
from torch_geometric.data import Data
from pytorch_lightning import Trainer

from gnn_tracking.metrics.losses import PotentialLoss, BackgroundLoss
import torch
from functools import partial
from gnn_tracking.training.tc import TCModule
from gnn_tracking.utils.loading import TrackingDataModule

## 1. Configure data

In [38]:
# Training and validation graphs are read from the same directory of processed
# files; the two splits differ only in their start/stop bounds.
PROCESSED_DATA_DIR = (
    "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/lst_data_v0/processed/"
)
# Split sizes. Validation starts exactly where training stops, so deriving both
# bounds from these two numbers keeps the splits adjacent and non-overlapping.
N_TRAIN_GRAPHS = 150
N_VAL_GRAPHS = 5

dm = TrackingDataModule(
    train=dict(
        dirs=[PROCESSED_DATA_DIR],
        stop=N_TRAIN_GRAPHS,
        # NOTE(review): presumably caps the size of each training sample;
        # confirm the exact meaning against TrackingDataModule.
        max_sample_size=800,
    ),
    val=dict(
        dirs=[PROCESSED_DATA_DIR],
        start=N_TRAIN_GRAPHS,
        stop=N_TRAIN_GRAPHS + N_VAL_GRAPHS,
    ),
    cpus=3,
    # could also configure a 'test' set here
)

In [30]:
from gnn_tracking.models.resin import ResIN
from torch import nn, Tensor
from gnn_tracking.models.track_condensation_networks import ModularGraphTCN


class LSGraphTCN(nn.Module, HyperparametersMixin):
    """Graph track-condensation model built from a ``ResIN`` and a ``ModularGraphTCN``.

    The constructor creates a ``ResIN`` network and hands it to
    ``ModularGraphTCN`` as its ``hc_in`` component. The forward pass is
    delegated entirely to that ``ModularGraphTCN`` instance.
    """

    def __init__(
        self,
        *,
        node_indim: int,
        edge_indim: int,
        h_dim: int = 5,
        e_dim: int = 4,
        h_outdim: int = 2,
        hidden_dim: int = 40,
        L_hc: int = 3,
        alpha_hc: float = 0.5,
    ):
        """Build the ``ResIN`` component and the wrapping ``ModularGraphTCN``.

        All arguments are keyword-only and are recorded through
        ``save_hyperparameters()``.

        Args:
            node_indim: Forwarded to ``ModularGraphTCN`` as ``node_indim``.
            edge_indim: Forwarded to ``ModularGraphTCN`` as ``edge_indim``.
            h_dim: Used as ``node_dim`` of the ``ResIN`` and as ``h_dim`` of
                ``ModularGraphTCN``.
            e_dim: Used as ``edge_dim`` of the ``ResIN`` and as ``e_dim`` of
                ``ModularGraphTCN``.
            h_outdim: Forwarded to ``ModularGraphTCN`` as ``h_outdim``.
            hidden_dim: Used for both ``object_hidden_dim`` and
                ``relational_hidden_dim`` of the ``ResIN`` and as
                ``hidden_dim`` of ``ModularGraphTCN``.
            L_hc: Number of ``ResIN`` layers (passed as ``n_layers``).
            alpha_hc: Passed to the ``ResIN`` as ``alpha``.
        """
        super().__init__()
        self.save_hyperparameters()
        hc_in = ResIN(
            node_dim=h_dim,
            edge_dim=e_dim,
            object_hidden_dim=hidden_dim,
            relational_hidden_dim=hidden_dim,
            alpha=alpha_hc,
            n_layers=L_hc,
        )
        self._gtcn = ModularGraphTCN(
            hc_in=hc_in,
            node_indim=node_indim,
            edge_indim=edge_indim,
            h_dim=h_dim,
            e_dim=e_dim,
            h_outdim=h_outdim,
            hidden_dim=hidden_dim,
        )

    def forward(
        self,
        data: Data,
    ) -> dict[str, Tensor]:
        """Run the wrapped ``ModularGraphTCN`` on ``data`` and return its output.

        Args:
            data: Input graph, passed through unchanged.

        Returns:
            Whatever the wrapped model returns (annotated as a mapping from
            names to tensors).
        """
        # Call the module itself rather than ``.forward`` so that any hooks
        # registered on the wrapped module are executed.
        return self._gtcn(data=data)

In [31]:
# Instantiate the network. ``hidden_dim`` and ``alpha_hc`` are not passed here,
# so they keep the defaults declared by ``LSGraphTCN.__init__``.
model = LSGraphTCN(
    node_indim=9,
    edge_indim=3,
    h_dim=128,
    e_dim=128,
    h_outdim=12,
    L_hc=3,
)

  rank_zero_warn(


## 3. Configure loss functions and weights

In [35]:
from gnn_tracking.postprocessing.dbscanscanner import DBSCANHyperParamScanner


def n_trials(epoch: int) -> int:
    """Return the number of trials to run for ``epoch``.

    Six trials are requested whenever ``epoch`` is divisible by three
    (``epoch == 0`` included); every other epoch gets zero.
    """
    return 6 if epoch % 3 == 0 else 0


# Build the pieces of the training module one by one, then assemble them.
# "TC" stands for track condensation.
potential_loss = PotentialLoss(radius_threshold=1.0)
background_loss = BackgroundLoss()
# Optimizer factory: Adam with the learning rate pre-bound.
optimizer_factory = partial(torch.optim.Adam, lr=7.5e-4)
# DBSCAN hyperparameter scanner; the number of trials per epoch is decided by
# the ``n_trials`` callback, and ``min_samples`` is pinned to the range (1, 1).
cluster_scanner = DBSCANHyperParamScanner(
    n_trials=n_trials,
    n_jobs=3,
    min_samples_range=(1, 1),
)

lmodel = TCModule(
    model=model,
    potential_loss=potential_loss,
    background_loss=background_loss,
    lw_repulsive=1.0,
    lw_background=0.1,
    optimizer=optimizer_factory,
    cluster_scanner=cluster_scanner,
)

## 4. Train the model

In [36]:
from pytorch_lightning.loggers import WandbLogger

# Weights & Biases logger created with offline=True, filed under project
# "lst_oc" and run group "first".
wl = WandbLogger(project="lst_oc", group="first", offline=True)

  rank_zero_warn(


In [37]:
# GPU trainer that logs on every step through the W&B logger, with the
# progress bar disabled. No epoch/step limit is passed here; the recorded
# run below was stopped manually with a KeyboardInterrupt.
trainer = Trainer(
    accelerator="gpu",
    logger=wl,
    log_every_n_steps=1,
    enable_progress_bar=False,
)
trainer.fit(model=lmodel, datamodule=dm)

  rank_zero_warn(
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[32m[21:28:26] INFO: DataLoader will load 150 graphs (out of 175 available).[0m
[36m[21:28:26] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/lst_data_v0/processed/0000.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/lst_data_v0/processed/0149.pt[0m
[32m[21:28:26] INFO: DataLoader will load 5 graphs (out of 175 available).[0m
[36m[21:28:26] DEBUG: First graph is /scratch/gpfs/IOJALVO/g

[3m            Validation epoch=1             [0m
┏━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━┳━━━━━━━━━┓
┃[1m [0m[1mMetric             [0m[1m [0m┃[1m [0m[1m  Value[0m[1m [0m┃[1m [0m[1m  Error[0m[1m [0m┃
┡━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━╇━━━━━━━━━┩
│[1;95m [0m[1;95mattractive         [0m[1;95m [0m│[1;95m [0m[1;95m0.00433[0m[1;95m [0m│[1;95m [0m[1;95m0.00087[0m[1;95m [0m│
│ attractive_weighted │ 0.00433 │ 0.00087 │
│ background          │ 0.81130 │ 0.00765 │
│ background_weighted │ 0.08113 │ 0.00077 │
│[1;95m [0m[1;95mrepulsive          [0m[1;95m [0m│[1;95m [0m[1;95m0.00115[0m[1;95m [0m│[1;95m [0m[1;95m0.00013[0m[1;95m [0m│
│ repulsive_weighted  │ 0.00115 │ 0.00013 │
│ total               │ 0.08661 │ 0.00002 │
└─────────────────────┴─────────┴─────────┘



  rank_zero_warn(
[36m[21:29:14 ClusterHP] DEBUG: Starting from params: {}[0m
[32m[21:29:14 ClusterHP] INFO: Starting hyperparameter scan for clustering[0m
[36m[21:29:22 ClusterHP] DEBUG: Evaluated {'eps': 0.6866929660849241, 'min_samples': 1}: 0.004317712790379806[0m
[36m[21:29:24 ClusterHP] DEBUG: Evaluated {'eps': 0.853287073120374, 'min_samples': 1}: 0.0027049188066643414[0m
[36m[21:29:26 ClusterHP] DEBUG: Evaluated {'eps': 0.977435769363506, 'min_samples': 1}: 0.0007317472616462487[0m
[36m[21:29:29 ClusterHP] DEBUG: Evaluated {'eps': 0.5933909030045587, 'min_samples': 1}: 0.005684136257668483[0m
[36m[21:29:31 ClusterHP] DEBUG: Evaluated {'eps': 0.623383450898183, 'min_samples': 1}: 0.006401511644990403[0m
[36m[21:29:34 ClusterHP] DEBUG: Evaluated {'eps': 0.675201361516184, 'min_samples': 1}: 0.004987103370599986[0m
[36m[21:29:34 ClusterHP] DEBUG: Evaluating all metrics for best clustering[0m
[36m[21:29:41 ClusterHP] DEBUG: Evaluating metrics took 7.304349 second

[3m                    Validation epoch=1                     [0m
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃[1m [0m[1mMetric                        [0m[1m [0m┃[1m [0m[1m     Value[0m[1m [0m┃[1m [0m[1m    Error[0m[1m [0m┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ adjusted_rand                  │   -0.00919 │   0.00828 │
│[1;95m [0m[1;95mattractive                    [0m[1;95m [0m│[1;95m [0m[1;95m   0.00593[0m[1;95m [0m│[1;95m [0m[1;95m  0.00058[0m[1;95m [0m│
│ attractive_weighted            │    0.00593 │   0.00058 │
│ background                     │    0.73458 │   0.01528 │
│ background_weighted            │    0.07346 │   0.00153 │
│ best_dbscan_eps                │    0.62338 │       nan │
│ best_dbscan_min_samples        │    1.00000 │       nan │
│ completeness                   │    0.35017 │   0.01575 │
│ fowlkes_mallows                │    0.41323 │   0.06298 │
│ homogeneity                    │    0.2

[36m[21:30:29 ClusterHP] DEBUG: Starting from params: {'eps': 0.623383450898183, 'min_samples': 1}[0m
[32m[21:30:29 ClusterHP] INFO: Starting hyperparameter scan for clustering[0m
[32m[21:30:29 ClusterHP] INFO: Clustering hyperparameter scan & metric evaluation took 0.00 seconds[0m


[3m                    Validation epoch=2                     [0m
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━┳━━━━━━━━━━━┓
┃[1m [0m[1mMetric                        [0m[1m [0m┃[1m [0m[1m     Value[0m[1m [0m┃[1m [0m[1m    Error[0m[1m [0m┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━╇━━━━━━━━━━━┩
│ adjusted_rand                  │   -0.00919 │   0.00828 │
│[1;95m [0m[1;95mattractive                    [0m[1;95m [0m│[1;95m [0m[1;95m   0.00656[0m[1;95m [0m│[1;95m [0m[1;95m  0.00050[0m[1;95m [0m│
│ attractive_weighted            │    0.00656 │   0.00050 │
│ background                     │    0.72970 │   0.01016 │
│ background_weighted            │    0.07297 │   0.00102 │
│ best_dbscan_eps                │    0.62338 │       nan │
│ best_dbscan_min_samples        │    1.00000 │       nan │
│ completeness                   │    0.35017 │   0.01575 │
│ fowlkes_mallows                │    0.41323 │   0.06298 │
│ homogeneity                    │    0.2

  rank_zero_warn("Detected KeyboardInterrupt, attempting graceful shutdown...")


In [39]:
# NOTE(review): looks like a leftover scratch cell (prints a fixed string and
# nothing else) — consider deleting it before sharing the notebook.
print("test")

test


In [40]:
# Shell escape: print the kernel's current working directory.
# NOTE(review): the saved output exposes a user-specific absolute path —
# consider clearing it or removing this scratch cell before sharing.
! pwd

/home/kl5675/Documents/23/git_sync/hpo/slurm
