# One-shot object condensation

This notebook shows how to implement a model that goes directly from point cloud data to object condensation.

In [1]:
from pytorch_lightning.core.mixins import HyperparametersMixin
from torch import nn
from torch_geometric.nn.conv import GravNetConv
from torch_geometric.data import Data
from pytorch_lightning import Trainer

from gnn_tracking.metrics.losses import PotentialLoss, BackgroundLoss
import torch
from functools import partial
from gnn_tracking.training.tc import TCModule
from gnn_tracking.utils.loading import TrackingDataModule

## 1. Configure data

In [2]:
# Directory holding the input graphs (`*.pt` files). This is the one place to
# adapt when running the notebook on another machine.
DATA_DIR = "/Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/"

# Training and validation read from the same directory but are given
# different `start`/`stop` settings (stop=5 vs. start=5, stop=10).
dm = TrackingDataModule(
    train={
        "dirs": [DATA_DIR],
        "stop": 5,
    },
    val={
        "dirs": [DATA_DIR],
        "start": 5,
        "stop": 10,
    },
    # could also configure a 'test' set here
)

## 2. Write a model

In [3]:
class DemoGravNet(nn.Module, HyperparametersMixin):
    """Minimal GravNet-based model that emits object condensation outputs.

    A stack of ``depth`` ``GravNetConv`` layers maps the node features to a
    latent representation of the same width (``in_dim``). A linear layer
    followed by a sigmoid then reduces every latent vector to a single value.

    Args:
        in_dim: Number of features per node; also the width of the latent
            representation because every layer keeps the channel count.
        depth: Number of stacked ``GravNetConv`` layers.
        k: Value passed as ``k`` to every ``GravNetConv`` layer.
    """

    def __init__(self, in_dim: int = 14, depth: int = 1, k: int = 2):
        super().__init__()
        # Records the constructor arguments so that they can be saved/restored.
        self.save_hyperparameters()
        layers = [
            GravNetConv(
                in_channels=in_dim,
                out_channels=in_dim,
                space_dimensions=3,
                propagate_dimensions=3,
                k=k,
            )
            for _ in range(depth)
        ]
        self._embedding = nn.Sequential(*layers)
        self._beta = nn.Sequential(
            nn.Linear(in_dim, 1),
            nn.Sigmoid(),
        )

    def forward(self, data: Data):
        """Compute the latent representation and the per-node beta values.

        Args:
            data: Graph/point cloud object; only ``data.x`` is read.

        Returns:
            Dictionary with ``"B"`` (sigmoid output, clamped to
            ``[1e-6, 1 - 1e-6]``, one value per row of ``data.x``) and ``"H"``
            (output of the embedding layers).
        """
        # NOTE(review): only `data.x` is forwarded, `data.batch` is not handed
        # to the GravNetConv layers -- confirm that each batch holds a single
        # graph, otherwise neighbours may be picked across graphs.
        latent = self._embedding(data.x)
        # Drop only the trailing singleton dimension produced by the linear
        # layer. A bare `.squeeze()` would also remove the node dimension when
        # there is exactly one node and return a 0-d tensor.
        beta = self._beta(latent).squeeze(-1)
        # Keep beta strictly inside (0, 1).
        eps = 1e-6
        beta = beta.clamp(eps, 1 - eps)
        return {
            "B": beta,
            "H": latent,
        }

In [4]:
# Instantiate the demo network; the constructor defaults are spelled out here
# so that the configuration is visible at a glance.
model = DemoGravNet(in_dim=14, depth=1, k=2)

## 3. Configure loss functions and weights

In [5]:
class MyTCModule(TCModule):
    """``TCModule`` with a tighter pt cut applied in ``data_preproc``.

    The loss functions can be memory hungry. Since this is just a demo,
    ``data_preproc`` is overridden to place a tighter pt cut on the data and
    thereby ease the computation. You can use the original ``TCModule`` class
    if you have enough GPU/CPU memory.
    """

    def data_preproc(self, data: Data):
        """Return the subgraph of ``data`` selected by the mask ``data.pt > 4``."""
        return data.subgraph(data.pt > 4)

In [6]:
from gnn_tracking.postprocessing.dbscanscanner import DBSCANHyperParamScanner

# Build the individual ingredients first, then hand them to the lightning module.
potential_loss = PotentialLoss(radius_threshold=1.0)
background_loss = BackgroundLoss()
# The optimizer is passed as a factory: Adam with a fixed learning rate.
make_optimizer = partial(torch.optim.Adam, lr=1e-4)
cluster_scanner = DBSCANHyperParamScanner(n_trials=5, n_jobs=1)

# TC for track condensation
lmodel = MyTCModule(
    model=model,
    potential_loss=potential_loss,
    background_loss=background_loss,
    lw_repulsive=2.0,
    lw_background=0.1,
    optimizer=make_optimizer,
    cluster_scanner=cluster_scanner,
)

[36m[14:19:21] DEBUG: Got obj of type <class '__main__.DemoGravNet'>, assuming I have to save hyperparameters[0m
[36m[14:19:21] DEBUG: Saving hyperperameters {'class_path': '__main__.DemoGravNet', 'init_args': {'in_dim': 14, 'depth': 1, 'k': 2}}[0m
[36m[14:19:21] DEBUG: Got obj of type <class 'gnn_tracking.metrics.losses.PotentialLoss'>, assuming I have to save hyperparameters[0m
[36m[14:19:21] DEBUG: Saving hyperperameters {'class_path': 'gnn_tracking.metrics.losses.PotentialLoss', 'init_args': {'q_min': 0.01, 'radius_threshold': 1.0, 'attr_pt_thld': 0.9}}[0m
[36m[14:19:21] DEBUG: Got obj of type <class 'gnn_tracking.metrics.losses.BackgroundLoss'>, assuming I have to save hyperparameters[0m
[36m[14:19:21] DEBUG: Saving hyperperameters {'class_path': 'gnn_tracking.metrics.losses.BackgroundLoss', 'init_args': {'sb': 0.1}}[0m
[36m[14:19:21] DEBUG: Got obj of type <class 'gnn_tracking.postprocessing.dbscanscanner.DBSCANHyperParamScanner'>, assuming I have to save hyperparame

## 4. Train the model

In [7]:
# Short demo run: a single epoch on the CPU, logging after every step.
trainer = Trainer(
    max_epochs=1,
    accelerator="cpu",
    log_every_n_steps=1,
)
# Train `lmodel` on the train/val splits configured in the data module `dm`.
trainer.fit(lmodel, datamodule=dm)

GPU available: True (mps), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(
[32m[14:19:21] INFO: DataLoader will load 5 graphs (out of 90 available).[0m
[36m[14:19:21] DEBUG: First graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/data21025_s0.pt, last graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/data21053_s0.pt[0m
[32m[14:19:21] INFO: DataLoader will load 5 graphs (out of 90 available).[0m
[36m[14:19:21] DEBUG: First graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/data21058_s0.pt, last graph is /Users/fuchur/tmp/truth_cut_graphs_for_gsoc/part_1_0/data21094_s0.pt[0m

  | Name            | Type           | Params
---------------------------------------------------
0 | model           | DemoGravNet    | 399   
1 | potential_loss  | PotentialLoss  | 0     
2 | background_loss | BackgroundLoss | 0     
---------------------------------------------

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(


  rank_zero_warn(


Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

[36m[14:19:37 ClusterHP] DEBUG: Starting from params: {}[0m
[32m[14:19:37 ClusterHP] INFO: Starting hyperparameter scan for clustering[0m
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  return a / b
  r

`Trainer.fit` stopped: `max_epochs=1` reached.
