# First object condensation model training

* **Requirements**: You need to have graphs constructed, e.g., with `010_build_graphs.ipynb`.

In [1]:
from pathlib import Path

from gnn_tracking.postprocessing.dbscanscanner import dbscan_scan

from gnn_tracking.models.track_condensation_networks import GraphTCN
from gnn_tracking.training.tcn_trainer import TCNTrainer
from gnn_tracking.metrics.losses import (
    EdgeWeightFocalLoss,
    PotentialLoss,
    BackgroundLoss,
)
from gnn_tracking.utils.loading import get_loaders, TrackingDataset

In [2]:
# Directory containing the constructed graphs (files such as data21000_s0.pt).
# Adjust this path to your local setup.
# A raw string is required here: in a regular string literal, sequences like
# "\D" and "\P" are invalid escapes (SyntaxWarning on Python 3.12+), and a
# folder name starting with e.g. "n" or "t" would silently become a control
# character and produce a wrong path.
graph_dir = Path(r"D:\Devdoot\Princeton RSE\dataset\graph constructed")
print(graph_dir)
assert graph_dir.is_dir(), f"Graph directory not found: {graph_dir}"

D:\Devdoot\Princeton RSE\dataset\graph constructed


In [3]:
# Split the available graphs into a training set (the first N_TRAIN_GRAPHS)
# and a validation set (the rest, up to N_TOTAL_GRAPHS). Deriving both bounds
# from the same constants keeps the two splits adjacent and non-overlapping.
N_TRAIN_GRAPHS = 810
N_TOTAL_GRAPHS = 900
datasets = {
    "train": TrackingDataset(graph_dir, stop=N_TRAIN_GRAPHS),
    "val": TrackingDataset(graph_dir, start=N_TRAIN_GRAPHS, stop=N_TOTAL_GRAPHS),
}
loaders = get_loaders(datasets, batch_size=1)

[32m[20:30:01] INFO: DataLoader will load 810 graphs (out of 900 available).[0m
[36m[20:30:01] DEBUG: First graph is D:\Devdoot\Princeton RSE\dataset\graph constructed\data21000_s0.pt, last graph is D:\Devdoot\Princeton RSE\dataset\graph constructed\data21909_s0.pt[0m
[32m[20:30:01] INFO: DataLoader will load 90 graphs (out of 900 available).[0m
[36m[20:30:01] DEBUG: First graph is D:\Devdoot\Princeton RSE\dataset\graph constructed\data21910_s0.pt, last graph is D:\Devdoot\Princeton RSE\dataset\graph constructed\data21999_s0.pt[0m
[36m[20:30:01] DEBUG: Parameters for data loader 'train': {'batch_size': 1, 'num_workers': 1, 'sampler': <torch.utils.data.sampler.RandomSampler object at 0x000001EEFFACBA00>, 'pin_memory': True, 'shuffle': None}[0m
[36m[20:30:01] DEBUG: Parameters for data loader 'val': {'batch_size': 1, 'num_workers': 1, 'sampler': None, 'pin_memory': True, 'shuffle': False}[0m


In [4]:
# Each entry maps a loss name to a (loss function, weight) pair. The potential
# loss has two components, so its weight is itself a mapping keyed by component.
loss_functions = dict(
    edge=(EdgeWeightFocalLoss(gamma=5, alpha=0.95), 500.0),
    potential=(
        PotentialLoss(q_min=0.01),
        dict(attractive=500, repulsive=5),
    ),
    background=(BackgroundLoss(sb=1), 0.05),
)

The values after the loss functions are the loss weights. The potential loss is a special case: it returns a dictionary with two values, `attractive` and `repulsive`. Therefore, it also takes two loss weights, given as a dictionary with the same keys.

In [5]:
# The input dimensions are read from the training dataset so that the model
# matches the node/edge features of the constructed graphs.
model = GraphTCN(
    node_indim=datasets["train"].num_node_features,
    edge_indim=datasets["train"].num_edge_features,
    # Architecture hyperparameters
    h_dim=10,
    e_dim=10,
    h_outdim=10,
    hidden_dim=128,
    L_ec=5,
    L_hc=2,
)

In [6]:
# Clustering scans to run, keyed by the name they are reported under.
clustering_functions = dict(dbscan=dbscan_scan)

In [7]:
# The trainer ties together the model, the data loaders, the weighted loss
# functions and the clustering functions defined in the cells above.
trainer = TCNTrainer(
    model=model,
    loaders=loaders,
    cluster_functions=clustering_functions,
    loss_functions=loss_functions,
    lr=0.005,
)

[32m[20:32:54 TCNTrainer] INFO: Using device cuda[0m


In [8]:
n_epochs = 5  # number of training epochs
trainer.train(epochs=n_epochs)

[36m[20:33:36 TCNTrainer] DEBUG: Epoch 1 (    0/810): Total=4455.71240, edge=   2.01060, potential_attractive=   0.13194, potential_repulsive=4453.52002, background=   0.04986 (weighted)[0m
[36m[20:33:45 TCNTrainer] DEBUG: Epoch 1 (   10/810): Total=   6.64047, edge=   1.14008, potential_attractive=   0.01184, potential_repulsive=   5.43948, background=   0.04907 (weighted)[0m
[36m[20:33:48 TCNTrainer] DEBUG: Epoch 1 (   20/810): Total=   6.82506, edge=   0.87288, potential_attractive=   0.01487, potential_repulsive=   5.88735, background=   0.04996 (weighted)[0m
[36m[20:33:50 TCNTrainer] DEBUG: Epoch 1 (   30/810): Total=   4.30061, edge=   0.92406, potential_attractive=   0.64947, potential_repulsive=   2.67709, background=   0.05000 (weighted)[0m
[36m[20:33:52 TCNTrainer] DEBUG: Epoch 1 (   40/810): Total=   4.69129, edge=   0.99371, potential_attractive=   0.17424, potential_repulsive=   3.47334, background=   0.05000 (weighted)[0m
[36m[20:33:54 TCNTrainer] DEBUG: Epoch 