In [1]:
from pytorch_lightning.core.mixins import HyperparametersMixin
from torch_geometric.data import Data
from pytorch_lightning import Trainer

from gnn_tracking.metrics.losses import PotentialLoss, BackgroundLoss
import torch
from functools import partial
from gnn_tracking.training.tc import TCModule
from gnn_tracking.utils.loading import TrackingDataModule



## 1. Configure data

In [8]:
# Training and validation graphs are read from the same directory and are
# separated only by index range, so the path and the split boundary are each
# defined once: changing one number can no longer make the two sets overlap.
GRAPH_DIR = "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/lst_data_v0/first/datasets"
N_TRAIN_GRAPHS = 150  # training uses graphs [0, N_TRAIN_GRAPHS)
N_VAL_GRAPHS = 5  # validation uses the N_VAL_GRAPHS graphs that follow

dm = TrackingDataModule(
    train=dict(
        dirs=[GRAPH_DIR],
        stop=N_TRAIN_GRAPHS,
    ),
    val=dict(
        dirs=[GRAPH_DIR],
        # Validation starts exactly where training stops.
        start=N_TRAIN_GRAPHS,
        stop=N_TRAIN_GRAPHS + N_VAL_GRAPHS,
    ),
    # could also configure a 'test' set here
)

## 2. Configure model

In [9]:
from gnn_tracking.models.resin import ResIN
from torch import nn, Tensor
from gnn_tracking.models.track_condensation_networks import ModularGraphTCN


class LSGraphTCN(nn.Module, HyperparametersMixin):
    """Track-condensation network that uses a ``ResIN`` stack as its core.

    Thin wrapper: it builds a ``ResIN`` and hands it to ``ModularGraphTCN`` as
    the ``hc_in`` component. All constructor arguments are recorded through
    ``save_hyperparameters``.
    """

    def __init__(
        self,
        *,
        node_indim: int,
        edge_indim: int,
        h_dim: int = 5,
        e_dim: int = 4,
        h_outdim: int = 2,
        hidden_dim: int = 40,
        L_hc: int = 3,
        alpha_hc: float = 0.5,
    ):
        """Build the ``ResIN`` core and the surrounding ``ModularGraphTCN``.

        Args:
            node_indim: Forwarded to ``ModularGraphTCN`` as ``node_indim``
                (presumably the number of input node features -- confirm).
            edge_indim: Forwarded to ``ModularGraphTCN`` as ``edge_indim``
                (presumably the number of input edge features -- confirm).
            h_dim: Used as ``node_dim`` of the ``ResIN`` and as ``h_dim`` of
                ``ModularGraphTCN``.
            e_dim: Used as ``edge_dim`` of the ``ResIN`` and as ``e_dim`` of
                ``ModularGraphTCN``.
            h_outdim: Forwarded to ``ModularGraphTCN`` as ``h_outdim``.
            hidden_dim: Used for both ``object_hidden_dim`` and
                ``relational_hidden_dim`` of the ``ResIN`` and as
                ``hidden_dim`` of ``ModularGraphTCN``.
            L_hc: Number of ``ResIN`` layers (``n_layers``).
            alpha_hc: Forwarded to the ``ResIN`` as ``alpha``.
        """
        super().__init__()
        # Must be called directly from __init__ so the constructor arguments
        # of this class are the ones that get captured.
        self.save_hyperparameters()
        hc_in = ResIN(
            node_dim=h_dim,
            edge_dim=e_dim,
            object_hidden_dim=hidden_dim,
            relational_hidden_dim=hidden_dim,
            alpha=alpha_hc,
            n_layers=L_hc,
        )
        self._gtcn = ModularGraphTCN(
            hc_in=hc_in,
            node_indim=node_indim,
            edge_indim=edge_indim,
            h_dim=h_dim,
            e_dim=e_dim,
            h_outdim=h_outdim,
            hidden_dim=hidden_dim,
        )

    def forward(
        self,
        data: Data,
    ) -> dict[str, Tensor]:
        """Run the wrapped ``ModularGraphTCN`` on one graph.

        Args:
            data: Input graph.

        Returns:
            Whatever the wrapped ``ModularGraphTCN`` returns for ``data``.
        """
        # Call the module rather than ``.forward`` so that hooks registered on
        # the wrapped network are executed.
        return self._gtcn(data=data)

In [10]:
# Instantiate the network; hidden_dim and alpha_hc keep their default values.
model = LSGraphTCN(
    node_indim=7,
    edge_indim=3,
    h_dim=128,
    e_dim=128,
    h_outdim=12,
    L_hc=3,
)

  rank_zero_warn(


## 3. Configure loss functions and weights

In [11]:
from gnn_tracking.postprocessing.dbscanscanner import DBSCANHyperParamScanner

# TCModule ("TC" = track condensation) bundles the network with its loss
# terms, loss weights, optimizer factory and the DBSCAN hyperparameter scan.
# NOTE(review): the recorded run of this notebook ends with
# "AttributeError: 'GlobalStorage' object has no attribute 'pt'". The logged
# PotentialLoss hyperparameters include attr_pt_thld, so the loss presumably
# needs a `pt` attribute on every graph -- confirm the dataset provides it.
lmodel = TCModule(
    model=model,
    potential_loss=PotentialLoss(radius_threshold=1.0),
    background_loss=BackgroundLoss(),
    # Weights of the repulsive and background loss terms.
    lw_repulsive=2.0,
    lw_background=0.1,
    # Adam with the learning rate pre-bound; the remaining arguments are
    # supplied by whoever calls this partial.
    optimizer=partial(torch.optim.Adam, lr=1e-4),
    cluster_scanner=DBSCANHyperParamScanner(n_trials=5, n_jobs=1),
    # preproc=PtCut(),
)

[36m[21:33:28] DEBUG: Got obj of type <class '__main__.LSGraphTCN'>, assuming I have to save hyperparameters[0m
[36m[21:33:28] DEBUG: Saving hyperperameters {'class_path': '__main__.LSGraphTCN', 'init_args': {'node_indim': 7, 'edge_indim': 3, 'h_dim': 128, 'e_dim': 128, 'h_outdim': 12, 'hidden_dim': 40, 'L_hc': 3, 'alpha_hc': 0.5}}[0m
[36m[21:33:28] DEBUG: Got obj of type <class 'gnn_tracking.metrics.losses.PotentialLoss'>, assuming I have to save hyperparameters[0m
[36m[21:33:28] DEBUG: Saving hyperperameters {'class_path': 'gnn_tracking.metrics.losses.PotentialLoss', 'init_args': {'q_min': 0.01, 'radius_threshold': 1.0, 'attr_pt_thld': 0.9}}[0m
[36m[21:33:28] DEBUG: Got obj of type <class 'gnn_tracking.metrics.losses.BackgroundLoss'>, assuming I have to save hyperparameters[0m
[36m[21:33:28] DEBUG: Saving hyperperameters {'class_path': 'gnn_tracking.metrics.losses.BackgroundLoss', 'init_args': {'sb': 0.1}}[0m
[36m[21:33:28] DEBUG: Got obj of type <class 'gnn_tracking.pos

## 4. Train the model

In [12]:
# Short CPU-only demonstration run: one epoch, logging after every step.
trainer = Trainer(
    max_epochs=1,
    accelerator="cpu",
    log_every_n_steps=1,
)
# Train the Lightning module on the data module configured above.
trainer.fit(model=lmodel, datamodule=dm)

GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[32m[21:33:29] INFO: DataLoader will load 150 graphs (out of 175 available).[0m
[36m[21:33:29] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/lst_data_v0/first/datasets/0000.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/lst_data_v0/first/datasets/0149.pt[0m
[32m[21:33:29] INFO: DataLoader will load 5 graphs (out of 175 available).[0m
[36m[21:33:29] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/lst_data_v0/first/datasets/0150.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/lst_data_v0/first/datasets/0154.pt[0m

  | Name            | Type           | Params
---------------------------------------------------
0 | model           | LSGraphTCN     | 143 K 
1 | potential_loss  | PotentialLoss  | 0     
2 | background_loss | Back

Sanity Checking: 0it [00:00, ?it/s]

  storage = elem.storage()._new_shared(numel)


AttributeError: 'GlobalStorage' object has no attribute 'pt'