# First object condensation model training

* **Requirements**: You need to have graphs constructed, e.g., with `010_build_graphs.ipynb`

In [1]:
from pathlib import Path

from gnn_tracking.postprocessing.dbscanscanner import dbscan_scan

from gnn_tracking.models.track_condensation_networks import GraphTCN
from tcn_trainer import TCNTrainer
from gnn_tracking.metrics.losses import (
    EdgeWeightFocalLoss,
    PotentialLoss,
    BackgroundLoss,
)
from gnn_tracking.utils.loading import get_loaders, TrackingDataset

In [2]:
# Directory holding the pre-built graph files (*.pt) that are loaded below.
# NOTE: this is a machine-specific absolute path; adjust it to your own setup.
# A raw string is used because the Windows path contains backslash sequences
# (\D, \P, \d, \g) that are invalid escapes in a regular string literal.
graph_dir = Path(r"D:\Devdoot\Princeton RSE\dataset\graph constructed")
print(graph_dir)
# Fail early, naming the offending path, if the graphs have not been built yet.
assert graph_dir.is_dir(), f"Graph directory not found: {graph_dir}"

D:\Devdoot\Princeton RSE\dataset\graph constructed


In [3]:
# Split the graphs in `graph_dir` by index: everything before 810 is used for
# training, the range [810, 900) for validation.
datasets = dict(
    train=TrackingDataset(graph_dir, stop=810),
    val=TrackingDataset(graph_dir, start=810, stop=900),
)
# One data loader per split, one graph per batch.
loaders = get_loaders(datasets, batch_size=1)

[32m[19:01:53] INFO: DataLoader will load 810 graphs (out of 900 available).[0m
[36m[19:01:53] DEBUG: First graph is D:\Devdoot\Princeton RSE\dataset\graph constructed\data21000_s0.pt, last graph is D:\Devdoot\Princeton RSE\dataset\graph constructed\data21909_s0.pt[0m
[32m[19:01:53] INFO: DataLoader will load 90 graphs (out of 900 available).[0m
[36m[19:01:53] DEBUG: First graph is D:\Devdoot\Princeton RSE\dataset\graph constructed\data21910_s0.pt, last graph is D:\Devdoot\Princeton RSE\dataset\graph constructed\data21999_s0.pt[0m
[36m[19:01:53] DEBUG: Parameters for data loader 'train': {'batch_size': 1, 'num_workers': 1, 'sampler': <torch.utils.data.sampler.RandomSampler object at 0x000001CA91443AF0>, 'pin_memory': True, 'shuffle': None}[0m
[36m[19:01:53] DEBUG: Parameters for data loader 'val': {'batch_size': 1, 'num_workers': 1, 'sampler': None, 'pin_memory': True, 'shuffle': False}[0m


In [4]:
# Every entry maps a loss name to a `(loss function, weight)` pair.
# PotentialLoss yields two components ("attractive" and "repulsive"), so its
# weight is a dict keyed by component name rather than a single number.
main_loss_functions = dict(
    potential=(PotentialLoss(q_min=0.01), dict(attractive=1.0)),
)
constraint_loss_functions = dict(
    edge=(EdgeWeightFocalLoss(gamma=5, alpha=0.95), 0.0002),
    potential=(PotentialLoss(q_min=0.01), dict(repulsive=0.025)),
    background=(BackgroundLoss(sb=1), 0.99),
)

The values after the loss functions are the loss weights. The potential loss is a special case, because it returns a dictionary with two values: `attractive` and `repulsive`. Therefore, there are also two loss weights.

In [5]:
# The input dimensions are read from the training dataset so that the model
# matches the node/edge features stored in the graphs.
train_set = datasets["train"]
model = GraphTCN(
    # input feature dimensions
    node_indim=train_set.num_node_features,
    edge_indim=train_set.num_edge_features,
    # remaining architecture hyperparameters
    h_dim=10, e_dim=10,
    L_ec=5, L_hc=2,
    h_outdim=10,
    hidden_dim=128,
)

In [6]:
# Clustering functions handed to the trainer, keyed by name.
clustering_functions = dict(dbscan=dbscan_scan)

In [7]:
# Bundle the model, the data loaders, the loss definitions and the clustering
# functions into a single trainer object.
trainer = TCNTrainer(
    model=model,
    loaders=loaders,
    lr=0.005,
    # losses: name -> (loss function, weight)
    main_loss_functions=main_loss_functions,
    constraint_loss_functions=constraint_loss_functions,
    # clustering functions: name -> callable
    cluster_functions=clustering_functions,
)

[32m[19:01:53 TCNTrainer] INFO: Using device cuda[0m


In [8]:
# Train for a single epoch; raise `n_epochs` for a longer run.
n_epochs = 1
trainer.train(epochs=n_epochs)

[36m[19:02:02 TCNTrainer] DEBUG: Epoch 1 (    0/810): Total=6294497.00000, potential_attractive=   0.00034, edge=   0.00443, potential_repulsive=1122.03186, background=   0.99727[0m
[36m[19:02:09 TCNTrainer] DEBUG: Epoch 1 (   10/810): Total=5864.24023, potential_attractive=5864.23779, edge=   0.00220, potential_repulsive=   0.00089, background=   0.99999[0m
[36m[19:02:12 TCNTrainer] DEBUG: Epoch 1 (   20/810): Total= 420.91931, potential_attractive= 420.91721, edge=   0.00166, potential_repulsive=   0.00272, background=   1.00000[0m
[36m[19:02:15 TCNTrainer] DEBUG: Epoch 1 (   30/810): Total=  31.02851, potential_attractive=  31.02709, edge=   0.00185, potential_repulsive=   0.00772, background=   1.00000[0m
[36m[19:02:18 TCNTrainer] DEBUG: Epoch 1 (   40/810): Total=  48.55421, potential_attractive=  48.55288, edge=   0.00171, potential_repulsive=   0.00988, background=   1.00000[0m
[36m[19:02:20 TCNTrainer] DEBUG: Epoch 1 (   50/810): Total= 446.45413, potential_attractiv