# First object condensation model training

* **Requirements*: You need to have graphs constructed, e.g., with `010_build_graphs.ipynb`

In [1]:
from pathlib import Path

from gnn_tracking.postprocessing.dbscanscanner import dbscan_scan

from gnn_tracking.models.track_condensation_networks import GraphTCN
from gnn_tracking.training.tcn_trainer import TCNTrainer
from gnn_tracking.metrics.losses import (
    EdgeWeightFocalLoss,
    PotentialLoss,
    BackgroundLoss,
)
from gnn_tracking.utils.loading import get_loaders, TrackingDataset

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
graph_dir = Path.home() / "Desktop" / "jian-gnn-tracking-experiments" / "data" / "part_1_0"
assert graph_dir.is_dir()

In [3]:
datasets = {
    "train": TrackingDataset(graph_dir, stop=80),
    "val": TrackingDataset(graph_dir, start=80, stop=90),
}
loaders = get_loaders(datasets)

[32m[07:21:54] INFO: DataLoader will load 80 graphs (out of 90 available).[0m
[36m[07:21:54] DEBUG: First graph is C:\Users\Jian\Desktop\jian-gnn-tracking-experiments\data\part_1_0\data21025_s0.pt, last graph is C:\Users\Jian\Desktop\jian-gnn-tracking-experiments\data\part_1_0\data21878_s0.pt[0m
[32m[07:21:54] INFO: DataLoader will load 10 graphs (out of 90 available).[0m
[36m[07:21:54] DEBUG: First graph is C:\Users\Jian\Desktop\jian-gnn-tracking-experiments\data\part_1_0\data21887_s0.pt, last graph is C:\Users\Jian\Desktop\jian-gnn-tracking-experiments\data\part_1_0\data21997_s0.pt[0m
[36m[07:21:54] DEBUG: Parameters for data loader 'train': {'batch_size': 1, 'num_workers': 1, 'sampler': <torch.utils.data.sampler.RandomSampler object at 0x0000017867D27F50>, 'pin_memory': True, 'shuffle': None}[0m
[36m[07:21:54] DEBUG: Parameters for data loader 'val': {'batch_size': 1, 'num_workers': 1, 'sampler': None, 'pin_memory': True, 'shuffle': False}[0m


In [4]:
loss_functions = {
    "edge": (EdgeWeightFocalLoss(gamma=5, alpha=0.95), 500.0),
    "potential": (PotentialLoss(q_min=0.01), {"attractive": 500, "repulsive": 5}),
    "background": (BackgroundLoss(sb=1), 0.05),
}

The values after the loss functions are the loss weights. The potential loss is a special case, because it returns a dictionary two values: `attractive` and `repulsive`. Therefore, there are also two loss weights.

In [5]:
import torch_geometric

In [6]:
model = torch_geometric.nn.conv.GravNetConv(in_channels=7, out_channels=7, space_dimensions=4, propagate_dimensions=16, k=16)

In [7]:
clustering_functions = {"dbscan": dbscan_scan}

In [8]:
trainer = TCNTrainer(
    model=model,
    loaders=loaders,
    loss_functions=loss_functions,
    lr=0.005,
    cluster_functions=clustering_functions,
)

[32m[07:22:02 TCNTrainer] INFO: Using device cuda[0m


In [9]:
trainer.train(epochs=2)

AttributeError: 'GlobalStorage' object has no attribute 'dim'