# First object condensation model training

* **Requirements**: You need to have graphs constructed, e.g., with `010_build_graphs.ipynb`.

In [10]:
from pathlib import Path

from gnn_tracking.postprocessing.dbscanscanner import dbscan_scan

from gnn_tracking.models.track_condensation_networks import GraphTCN
from gnn_tracking.training.tcn_trainer import TCNTrainer
from gnn_tracking.metrics.losses import (
    EdgeWeightFocalLoss,
    PotentialLoss,
    BackgroundLoss,
)
from gnn_tracking.utils.loading import train_test_val_split, get_loaders
from gnn_tracking.graph_construction.graph_builder import load_graphs

In [14]:
# Directory from which the pre-built graphs are loaded further below.
graph_dir = Path.home() / "data" / "gnn_tracking" / "graphs"
# Fail early with an actionable message instead of a bare AssertionError.
assert graph_dir.is_dir(), (
    f"Graph directory {graph_dir} does not exist; "
    "build the graphs first (e.g., with 010_build_graphs.ipynb)."
)

In [4]:
# Load only a handful of graphs (stop=10) to keep this demo quick.
loaded_graphs = load_graphs(graph_dir, stop=10)
# Partition the loaded graphs; the result is handed to `get_loaders` below.
graph_dict = train_test_val_split(loaded_graphs)
# Data loaders built from the split; these are passed to the trainer later.
loaders = get_loaders(graph_dict)

[32mINFO: Loading 10 graphs (out of 3200 available).[0m
[36mDEBUG: Parameters for data loaders: {'batch_size': 1, 'num_workers': 1}[0m


In [5]:
# Loss terms by name; this mapping is passed to the trainer as `loss_functions`.
loss_functions = dict(
    edge=EdgeWeightFocalLoss(alpha=0.95, gamma=5),
    potential=PotentialLoss(q_min=0.01),
    background=BackgroundLoss(sb=1),
)

# Relative weight of each loss term; passed to the trainer as `loss_weights`.
# NOTE(review): "potential" in `loss_functions` has two weights here
# (attractive / repulsive) -- presumably that loss yields two components; confirm.
loss_weights = dict(
    edge=500,
    potential_attractive=500,
    potential_repulsive=5,
    background=0.05,
)

In [6]:
# Model hyperparameters gathered in one place so they are easy to tweak.
# NOTE(review): meanings of the individual settings are defined by GraphTCN;
# consult its documentation before changing them.
tcn_hparams = {
    "node_indim": 6,
    "edge_indim": 4,
    "h_dim": 10,
    "e_dim": 10,
    "L_ec": 5,
    "L_hc": 2,
    "h_outdim": 10,
    "hidden_dim": 128,
}
model = GraphTCN(**tcn_hparams)

In [11]:
# Clustering routines by name; passed to the trainer as `cluster_functions`.
clustering_functions = dict(dbscan=dbscan_scan)

In [12]:
# Wire the model, data loaders, losses and clustering together in one trainer.
trainer = TCNTrainer(
    # what to train and on which data
    model=model,
    loaders=loaders,
    # objective: loss terms and their relative weights
    loss_functions=loss_functions,
    loss_weights=loss_weights,
    # clustering used in the post-epoch evaluation
    cluster_functions=clustering_functions,
    # learning rate
    lr=0.005,
)

[32mINFO: Using device cpu[0m


In [13]:
# Run a short training of two epochs. As the logged output below shows, each
# epoch is followed by a test step that includes a clustering hyperparameter
# scan and a metric evaluation.
trainer.train(epochs=2)

[32mINFO: Epoch  1 (    0/8): background_weighted=   0.04997, edge_weighted=   0.48573, potential_attractive_weighted=   0.00032, potential_repulsive_weighted=   2.57460[0m
[32mINFO: Training for epoch 1 took 13.08 seconds[0m
[36mDEBUG: Starting from params: {'eps': 0.95, 'min_samples': 1}[0m
[32mINFO: Starting hyperparameter scan for clustering[0m
[32mINFO: Completed 100 trials, pruned 0 trials[0m
[36mDEBUG: Evaluating all metrics for best clustering[0m
[32mINFO: Evaluating all metrics took 0.36 seconds[0m
[32mINFO: Clustering hyperparameter scan & metric evaluation took 4.41 seconds[0m
[32mINFO: Test step for epoch 1 took 5.67 seconds[0m
[32mINFO: Results 1: 
+-----+-------------------------------------+-----------+
|     | Metric                              |     Value |
|     | F1_pt0.9                            |   0.17820 |
|     | F1_pt1.5                            |   0.12676 |
|     | FNR_pt0.9                           |   0.27976 |
|     | FNR_pt1.5    