# First object condensation model training

* **Requirements**: You need to have graphs constructed, e.g., with `010_build_graphs.ipynb`. This notebook reads them from `~/data/gnn_tracking/graphs`.

In [1]:
from pathlib import Path

import torch

from gnn_tracking.metrics.losses import (
    BackgroundLoss,
    EdgeWeightFocalLoss,
    PotentialLoss,
)
from gnn_tracking.models.track_condensation_networks import GraphTCN
from gnn_tracking.postprocessing.dbscanscanner import dbscan_scan
from gnn_tracking.training.tcn_trainer import TCNTrainer
from gnn_tracking.utils.loading import TrackingDataset, get_loaders

In [2]:
# Seed PyTorch's global RNG so re-running the notebook is repeatable.
# Presumably both the model's weight initialization and the shuffling of the
# training loader (logged with shuffle=True) draw from this generator.
# NOTE(review): RNGs used elsewhere (e.g. by the clustering hyperparameter
# scan) are not covered by this seed — TODO confirm.
SEED = 42
torch.manual_seed(SEED)

# Directory holding the preprocessed graphs built by the graph-building notebook.
graph_dir = Path.home() / "data" / "gnn_tracking" / "graphs"
assert graph_dir.is_dir(), (
    f"Graph directory {graph_dir} not found; build the graphs first "
    "(e.g. with 010_build_graphs.ipynb) or point graph_dir elsewhere."
)

In [9]:
# Split the available graphs: the first nine (in the dataset's ordering) for
# training and the tenth for validation. The log below reports 9 and 1 graphs
# loaded out of 10 available.
datasets = dict(
    train=TrackingDataset(graph_dir, stop=9),
    val=TrackingDataset(graph_dir, start=9, stop=10),
)
loaders = get_loaders(datasets)

[32m[18:54:32] INFO: DataLoader will load 9 graphs (out of 10 available).[0m
[36m[18:54:32] DEBUG: First graph is /Users/fuchur/data/gnn_tracking/graphs/data21000_s5.pt, last graph is /Users/fuchur/data/gnn_tracking/graphs/data21008_s10.pt[0m
[32m[18:54:32] INFO: DataLoader will load 1 graphs (out of 10 available).[0m
[36m[18:54:32] DEBUG: First graph is /Users/fuchur/data/gnn_tracking/graphs/data21008_s24.pt, last graph is /Users/fuchur/data/gnn_tracking/graphs/data21008_s24.pt[0m
[36m[18:54:32] DEBUG: Parameters for data loader 'train': {'batch_size': 1, 'num_workers': 1, 'shuffle': True, 'pin_memory': True}[0m
[36m[18:54:32] DEBUG: Parameters for data loader 'val': {'batch_size': 1, 'num_workers': 1, 'shuffle': False, 'pin_memory': True}[0m


In [4]:
# Loss components keyed by name. PotentialLoss appears to contribute two terms
# (attractive and repulsive, as reported in the training log), which is why
# loss_weights carries two `potential_*` entries.
loss_functions = dict(
    edge=EdgeWeightFocalLoss(gamma=5, alpha=0.95),
    potential=PotentialLoss(q_min=0.01),
    background=BackgroundLoss(sb=1),
)

# Scale factor for each loss term; the training log reports every term as
# `<key>_weighted`. The potential loss gets separate attractive/repulsive weights.
loss_weights = dict(
    edge=500,
    potential_attractive=500,
    potential_repulsive=5,
    background=0.05,
)

In [10]:
# Input feature sizes come from the training graphs so the model matches the
# data; everything else is an architecture hyperparameter.
model = GraphTCN(
    node_indim=datasets["train"].num_node_features,
    edge_indim=datasets["train"].num_edge_features,
    hidden_dim=128,
    h_dim=10,
    h_outdim=10,
    e_dim=10,
    L_hc=2,
    L_ec=5,
)

In [6]:
# Clustering scans handed to the trainer (as `cluster_functions`), keyed by name.
clustering_functions = dict(dbscan=dbscan_scan)

In [11]:
# Wire model, data, losses (with their weights) and clustering into one trainer.
trainer = TCNTrainer(
    model=model,
    loaders=loaders,
    loss_functions=loss_functions,
    loss_weights=loss_weights,
    cluster_functions=clustering_functions,
    lr=0.005,
)

[32m[18:55:04 TCNTrainer] INFO: Using device cpu[0m


In [12]:
# Train for two epochs. Per the log below, each epoch reports the weighted loss
# terms and is followed by a clustering hyperparameter scan ("ClusterHP").
# NOTE(review): the numpy nanvar RuntimeWarnings in the output may stem from the
# single-graph validation split — TODO confirm.
trainer.train(epochs=2)

[32m[18:55:06 TCNTrainer] INFO: Epoch  1 (    0/9): edge_weighted=   0.83609, potential_attractive_weighted=   0.00020, potential_repulsive_weighted=2477.93213, background_weighted=   0.05000[0m
[32m[18:55:08 TCNTrainer] INFO: Training for epoch 1 took 3.34 seconds[0m
  var = nanvar(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  return _methods._var(a, axis=axis, dtype=dtype, out=out, ddof=ddof,
  ret = ret.dtype.type(ret / rcount)
[36m[18:55:09 ClusterHP] DEBUG: Starting from params: {'eps': 0.95, 'min_samples': 1}[0m
[32m[18:55:09 ClusterHP] INFO: Starting hyperparameter scan for clustering[0m
[36m[18:55:10 ClusterHP] DEBUG: Evaluating all metrics for best clustering[0m
[36m[18:55:10 ClusterHP] DEBUG: Evaluating metrics took 0.019963 seconds: Clustering time: 0.005406, total metric eval: 0.014266, individual: v_measure: 0.0010732090000118433, homogeneity: 0.0008535000000051696, completeness: 0.0008143750000044747, trk: 0.010325541000000271, adjusted_rand: 0.0006892500000