# One shot object condensation

This notebook shows how to implement a model that goes directly from point-cloud data to object condensation, i.e., in a single ("one-shot") step.

In [63]:
from functools import partial
from pathlib import Path

import torch
from pytorch_lightning import Trainer
from pytorch_lightning import seed_everything
from pytorch_lightning.core.mixins.hparams_mixin import HyperparametersMixin
from torch import nn
from torch_geometric.data import Data
from torch_geometric.nn.conv import GravNetConv

from gnn_tracking.metrics.losses.oc import CondensationLossTiger
from gnn_tracking.postprocessing.dbscanscanner import DBSCANHyperParamScanner
from gnn_tracking.training.callbacks import PrintValidationMetrics
from gnn_tracking.training.tc import TCModule
from gnn_tracking.utils.loading import TrackingDataModule
from gnn_tracking.utils.versioning import assert_version_geq

assert_version_geq("23.12.0")

# Seed the Python, NumPy and torch random number generators so that weight
# initialization and training give the same results under "Restart & Run All".
SEED = 42
seed_everything(SEED)

## 1. Configure data

In [64]:
# The test data is expected in a `test-data` checkout located two directories
# above the current working directory.
data_dir = (
    Path.cwd().resolve().parent.parent / "test-data" / "data" / "point_clouds" / "v8"
)
assert data_dir.is_dir(), f"Point cloud test data not found at {data_dir}"

In [65]:
# Train on the first graph of the directory and validate on the second one
# (`start`/`stop` select a slice of the available graphs).
dm = TrackingDataModule(
    train={"dirs": [data_dir], "stop": 1},
    val={"dirs": [data_dir], "start": 1, "stop": 2},
    identifier="point_clouds_v8",
    # could also configure a 'test' set here
)

## 2. Write a model

In [66]:
class DemoGravNet(nn.Module, HyperparametersMixin):
    """Minimal one-shot object condensation model built from GravNet layers.

    The node features ``data.x`` are passed through ``depth`` dimension-preserving
    ``GravNetConv`` layers to obtain the latent clustering coordinates ``H``.
    A linear layer followed by a sigmoid predicts the per-node condensation
    strength ``B`` (beta) from these coordinates.

    Args:
        in_dim: Number of node features in ``data.x``; also the dimension of
            the latent space ``H``.
        depth: Number of stacked ``GravNetConv`` layers.
        k: Number of nearest neighbors used by every ``GravNetConv`` layer.
    """

    # Margin used to keep beta strictly inside the open interval (0, 1).
    _beta_eps = 1e-6

    def __init__(self, in_dim: int = 14, depth: int = 1, k: int = 2):
        super().__init__()
        self.save_hyperparameters()
        # ModuleList (instead of Sequential) so that the batch vector can be
        # forwarded to every layer. Children are still named "0", "1", ...,
        # so state dicts remain compatible with the Sequential version.
        self._embedding = nn.ModuleList(
            [
                GravNetConv(
                    in_channels=in_dim,
                    out_channels=in_dim,
                    space_dimensions=3,
                    propagate_dimensions=3,
                    k=k,
                )
                for _ in range(depth)
            ]
        )
        self._beta = nn.Sequential(
            nn.Linear(in_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, data: Data) -> dict[str, torch.Tensor]:
        """Compute the condensation outputs for a (possibly batched) graph.

        Args:
            data: Graph with node features ``x`` and, when batched, a
                ``batch`` vector assigning nodes to graphs.

        Returns:
            Dictionary with ``"B"``: 1D tensor of beta values (one per node),
            clamped to ``[eps, 1 - eps]``, and ``"H"``: latent coordinates.
        """
        latent = data.x
        # Without the batch vector, GravNetConv's kNN search could connect
        # nodes that belong to different graphs of the same mini-batch.
        batch = getattr(data, "batch", None)
        for layer in self._embedding:
            latent = layer(latent, batch)
        # squeeze(-1) rather than squeeze(): keeps a 1D tensor even when the
        # graph consists of a single node.
        beta = self._beta(latent).squeeze(-1)
        beta = beta.clamp(self._beta_eps, 1 - self._beta_eps)
        return {
            "B": beta,
            "H": latent,
        }

In [67]:
# Instantiate the model with its default hyperparameters, spelled out explicitly.
model = DemoGravNet(in_dim=14, depth=1, k=2)

## 3. Configure preprocessing, loss function and loss weights

In [68]:
# The loss functions can be memory hungry. Here we pass a custom `preproc`
# callable to `TCModule` that places a tighter pt cut on the data to ease
# computation (since this is just a demo).
class PtCut(HyperparametersMixin):
    """Keep only the hits whose ``pt`` is above a threshold.

    Args:
        pt_thld: Hits with ``data.pt > pt_thld`` are kept. The default of 4
            matches the value this demo has always used.
    """

    def __init__(self, pt_thld: float = 4.0):
        super().__init__()
        self.save_hyperparameters()

    def __call__(self, data: Data) -> Data:
        """Return the subgraph of ``data`` induced by the hits passing the cut."""
        # NOTE(review): if noise hits are stored with pt = 0, this cut removes
        # all of them, which may explain `noise_train=nan` in the training
        # output below -- confirm.
        mask = data.pt > self.hparams.pt_thld
        return data.subgraph(mask)

In [71]:
# `TCModule` is the Lightning module for track condensation (TC). It bundles
# the model with the object condensation loss, the optimizer, the clustering
# hyperparameter scanner and the preprocessing step.
lmodel = TCModule(
    model=model,
    # Repulsive loss term is weighted by 2 relative to the default.
    loss_fct=CondensationLossTiger(
        lw_repulsive=2.0,
    ),
    # Passed as a partial: TCModule presumably instantiates the optimizer
    # itself with the model parameters -- confirm against TCModule.
    optimizer=partial(torch.optim.Adam, lr=1e-4),
    # Only a few DBSCAN hyperparameter trials to keep the demo fast.
    cluster_scanner=DBSCANHyperParamScanner(n_trials=5, n_jobs=1),
    preproc=PtCut(),
)

## 4. Train the model

In [72]:
# A single epoch on the CPU is enough for this demo. PrintValidationMetrics
# prints the table of validation metrics shown in the output below.
trainer = Trainer(
    accelerator="cpu",
    max_epochs=1,
    callbacks=[PrintValidationMetrics()],
    log_every_n_steps=1,
)
trainer.fit(lmodel, datamodule=dm)

/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3 ...
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[32m[15:52:19] INFO: DataLoader will load 1 graphs (out of 2 available).[0m
[36m[15:52:19] DEBUG: First graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21000_s0.pt, last graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21000_s0.pt[0m
[32m[15:52:19] INFO: DataLoader will load 1 graphs (out of 2 available).[0m
[36m[15:52:19] DEBUG: First graph is /home/kl5675/Documents/23/git_sync/test-data/data/point_clouds/v8/data21001_s0.p

Sanity Checking: |                                                                                                                                                                                                             | 0/? [00:00<?, ?it/s]

/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|                                                                                                                                                                                            | 0/1 [00:00<?, ?it/s]

No CUDA runtime is found, using CUDA_HOME='/scratch/gpfs/kl5675/micromamba/envs/gnn'


                                                                                                                                                                                                                                                     

/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'train_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=1` in the `DataLoader` to improve performance.


Epoch 0:   0%|                                                                                                                                                                                                                 | 0/1 [00:00<?, ?it/s]



Epoch 0: 100%|█| 1/1 [00:10<00:00,  0.09it/s, v_num=3, attractive_train=9.35e+7, repulsive_train=0.000, coward_train=0.0526, noise_train=nan.0, attractive_weighted_train=9.35e+7, repulsive_weighted_train=0.000, coward_weighted_train=0.000, noise

NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.




NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf fou


[3m                    Validation epoch=0                     [0m                                                                                                                                                                                  
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━┳━━━━━━━┓
┃[1m [0m[1mMetric                        [0m[1m [0m┃[1m [0m[1m         Value[0m[1m [0m┃[1m [0m[1mError[0m[1m [0m┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━╇━━━━━━━┩
│[1;95m [0m[1;95mattractive                    [0m[1;95m [0m│[1;95m [0m[1;95m55245000.00000[0m[1;95m [0m│[1;95m [0m[1;95m  nan[0m[1;95m [0m│
│ attractive_train               │ 93512936.00000 │   nan │
│ attractive_weighted            │ 55245000.00000 │   nan │
│ attractive_weighted_train      │ 93512936.00000 │   nan │
│ best_dbscan_eps                │        0.15979 │   nan │
│ best_dbscan_min_samples        │        4.00000 │   nan │
│ coward                         │        

`Trainer.fit` stopped: `max_epochs=1` reached.


Epoch 0: 100%|█| 1/1 [00:11<00:00,  0.09it/s, v_num=3, attractive_train=9.35e+7, repulsive_train=0.000, coward_train=0.0526, noise_train=nan.0, attractive_weighted_train=9.35e+7, repulsive_weighted_train=0.000, coward_weighted_train=0.000, noise
