# First object condensation model training

* **Requirements**: You need to have graphs constructed, e.g., with `010_build_graphs.ipynb`.

In [1]:
from pathlib import Path

from gnn_tracking.postprocessing.dbscanscanner import dbscan_scan

from gnn_tracking.models.track_condensation_networks import GraphTCN
from gnn_tracking.training.tcn_trainer import TCNTrainer
from gnn_tracking.metrics.losses import (
    EdgeWeightFocalLoss,
    PotentialLoss,
    BackgroundLoss,
)
from gnn_tracking.utils.loading import get_loaders, TrackingDataset

In [2]:
# Directory holding the graphs produced by `010_build_graphs.ipynb`.
graph_dir = Path.home() / "data" / "gnn_tracking" / "graphs"
if not graph_dir.is_dir():
    # Explicit check instead of `assert`: asserts are stripped under `python -O`,
    # and a bare AssertionError does not tell the reader what to do next.
    raise FileNotFoundError(
        f"Graph directory {graph_dir} not found. Build the graphs first, "
        "e.g., with `010_build_graphs.ipynb`."
    )

In [3]:
# Split of the available graphs: the first N_TRAIN_GRAPHS are used for training,
# the following N_VAL_GRAPHS for validation. Deriving the validation range from
# the training size keeps the two ranges contiguous and non-overlapping.
N_TRAIN_GRAPHS = 9
N_VAL_GRAPHS = 1

datasets = {
    "train": TrackingDataset(graph_dir, stop=N_TRAIN_GRAPHS),
    "val": TrackingDataset(
        graph_dir, start=N_TRAIN_GRAPHS, stop=N_TRAIN_GRAPHS + N_VAL_GRAPHS
    ),
}
# One loader per split, keyed like `datasets`. Per the log output, the train
# loader uses a RandomSampler and the val loader does not.
loaders = get_loaders(datasets)

[32m[17:28:57] INFO: DataLoader will load 9 graphs (out of 10 available).[0m
[36m[17:28:57] DEBUG: First graph is /Users/fuchur/data/gnn_tracking/graphs/data21000_s5.pt, last graph is /Users/fuchur/data/gnn_tracking/graphs/data21008_s10.pt[0m
[32m[17:28:57] INFO: DataLoader will load 1 graphs (out of 10 available).[0m
[36m[17:28:57] DEBUG: First graph is /Users/fuchur/data/gnn_tracking/graphs/data21008_s24.pt, last graph is /Users/fuchur/data/gnn_tracking/graphs/data21008_s24.pt[0m
[36m[17:28:57] DEBUG: Parameters for data loader 'train': {'batch_size': 1, 'num_workers': 1, 'sampler': <torch.utils.data.sampler.RandomSampler object at 0x11154c4f0>, 'pin_memory': True}[0m
[36m[17:28:57] DEBUG: Parameters for data loader 'val': {'batch_size': 1, 'num_workers': 1, 'sampler': None, 'pin_memory': True}[0m


In [4]:
# Maps each loss name to a (loss function, weight) pair. The potential loss
# yields two components, so its weight is a per-component mapping.
loss_functions = dict(
    edge=(EdgeWeightFocalLoss(gamma=5, alpha=0.95), 500.0),
    potential=(
        PotentialLoss(q_min=0.01),
        dict(attractive=500, repulsive=5),
    ),
    background=(BackgroundLoss(sb=1), 0.05),
)

The values after the loss functions are the loss weights. The potential loss is a special case because it returns a dictionary with two values, `attractive` and `repulsive`. Therefore, it also takes two loss weights, one per component.

In [8]:
# Input feature dimensions come from the training graphs, so the model always
# matches whatever the graph-building step produced.
train_graphs = datasets["train"]
model = GraphTCN(
    node_indim=train_graphs.num_node_features,
    edge_indim=train_graphs.num_edge_features,
    # Latent sizes.
    h_dim=10,
    e_dim=10,
    h_outdim=10,
    hidden_dim=128,
    # NOTE(review): presumably layer counts for the edge-classification and
    # hidden-condensation stages — confirm against GraphTCN's docstring.
    L_ec=5,
    L_hc=2,
)

In [9]:
# Clustering functions handed to the trainer, keyed by name (the training log
# shows a hyperparameter scan over DBSCAN's `eps` and `min_samples`).
clustering_functions = dict(dbscan=dbscan_scan)

In [10]:
learning_rate = 0.005

# Bundle model, data, weighted losses and clustering into one trainer object.
trainer = TCNTrainer(
    model=model,
    loaders=loaders,
    loss_functions=loss_functions,
    cluster_functions=clustering_functions,
    lr=learning_rate,
)

[32m[17:30:33 TCNTrainer] INFO: Using device cpu[0m


In [11]:
# Short smoke-test run; increase for real training.
n_epochs = 2
trainer.train(epochs=n_epochs)

[32m[17:30:37 TCNTrainer] INFO: Epoch  1 (    0/9): edge_weighted=   0.91777, potential_attractive_weighted=   0.00905, potential_repulsive_weighted=1862.03873, background_weighted=   0.04993[0m
[32m[17:30:46 TCNTrainer] INFO: Training for epoch 1 took 10.92 seconds[0m
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  return _methods._var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  ret = ret.dtype.type(ret / rcount)
[36m[17:30:53 ClusterHP] DEBUG: Starting from params: {'eps': 0.95, 'min_samples': 1}[0m
[32m[17:30:53 ClusterHP] INFO: Starting hyperparameter scan for clustering[0m
[36m[17:30:54 ClusterHP] DEBUG: Evaluating all metrics for best clustering[0m
[36m[17:30:54 ClusterHP] DEBUG: Evaluating metrics took 0.027590 seconds: Clustering time: 0.004927, total metric eval: 0.022464, individual: v_measure: 0.0008469170000040549, homogeneity: 0.0007916249999908587, completeness: 0.0007492499999983693, trk: 0.018884791999994377, adjusted_rand: 0.000683875000