# Edge Classifier Development

In [29]:
from pathlib import Path

from gnn_tracking.postprocessing.dbscanscanner import dbscan_scan

from gnn_tracking.models.track_condensation_networks import GraphTCN
from gnn_tracking.training.tcn_trainer import TCNTrainer
from gnn_tracking.metrics.losses import (
    EdgeWeightFocalLoss,
    PotentialLoss,
    BackgroundLoss,
)
from gnn_tracking.utils.loading import get_loaders, TrackingDataset

In [30]:
# Location of the pre-built graphs; each "part_*" sub-directory is used as one split.
graphs_root = Path("/") / "scratch" / "gpfs" / "IOJALVO" / "gnn-tracking" / "object_condensation" / "graphs_v1"
train_dir = graphs_root / "part_4"
test_dir = graphs_root / "part_5"

# Fail early, and name the offending path, if either directory is missing.
# (The previous check referenced an undefined `graph_dir`, which raised a
# NameError on a fresh kernel and validated neither of the real paths.)
assert train_dir.is_dir(), f"Training graph directory not found: {train_dir}"
assert test_dir.is_dir(), f"Validation graph directory not found: {test_dir}"

In [31]:
# Build one dataset per split, each constructed with stop=500, and wrap them
# in data loaders.
datasets = {
    split: TrackingDataset(split_dir, stop=500)
    for split, split_dir in (("train", train_dir), ("val", test_dir))
}
loaders = get_loaders(datasets)

[32m[20:03:09] INFO: DataLoader will load 500 graphs (out of 32000 available).[0m
[36m[20:03:09] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part_4/data24000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part_4/data24015_s26.pt[0m
[32m[20:03:09] INFO: DataLoader will load 500 graphs (out of 32000 available).[0m
[36m[20:03:09] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part_5/data25000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1/part_5/data25015_s26.pt[0m
[36m[20:03:09] DEBUG: Parameters for data loader 'train': {'batch_size': 1, 'num_workers': 1, 'sampler': <torch.utils.data.sampler.RandomSampler object at 0x14ba4745a200>, 'pin_memory': True}[0m
[36m[20:03:09] DEBUG: Parameters for data loader 'val': {'batch_size': 1, 'num_workers': 1, 'sampler': None, 'pin_memory': True}[0m


In [56]:
import torch.nn
from torch.nn import Module, Sequential, Linear, ReLU, Sigmoid

class EC(Module):
    """Edge classifier: a small fully connected network applied to edge features.

    Each edge's attribute vector is passed through two hidden ReLU layers and a
    final sigmoid unit, yielding one value in [0, 1] per edge.
    """

    def __init__(
        self, node_in_dim: int, edge_in_dim: int, hidden_dim: int
    ):
        """Set up the network.

        Args:
            node_in_dim: Number of node features. Stored on the instance only;
                no layer of this model consumes node features.
            edge_in_dim: Number of features per edge (input width of the
                first linear layer).
            hidden_dim: Width of the two hidden layers.
        """
        super().__init__()
        self.node_in_dim = node_in_dim
        self.edge_in_dim = edge_in_dim
        self.hidden_dim = hidden_dim

        self.fcnn = Sequential(
            Linear(edge_in_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, 1),
            Sigmoid(),
        )

    def forward(self, data):
        """Compute one weight per edge.

        Args:
            data: Object exposing an ``edge_attr`` tensor whose last dimension
                equals ``edge_in_dim`` (assumed shape ``(num_edges,
                edge_in_dim)`` -- confirm against the loader).

        Returns:
            Dict with the single key ``"W"`` holding the sigmoid outputs with
            the trailing size-1 feature dimension removed.
        """
        edge_weights = self.fcnn(data.edge_attr)
        # Drop only the trailing feature dimension. A bare ``squeeze()`` would
        # also collapse the edge dimension when a graph has exactly one edge,
        # returning a 0-d tensor instead of a length-1 vector.
        return {"W": edge_weights.squeeze(-1)}
    
# Instantiate the classifier and smoke-test it on a single training batch;
# the returned dict is shown as the cell output.
model = EC(node_in_dim=7, edge_in_dim=4, hidden_dim=128)
model(next(iter(loaders["train"])))

{'W': tensor([0.5062, 0.5061, 0.5062,  ..., 0.5065, 0.5066, 0.5065],
        grad_fn=<SqueezeBackward0>)}

In [57]:
# Only an edge loss is configured. Each entry is presumably a
# (loss, weight) pair -- confirm against TCNTrainer.
edge_loss = EdgeWeightFocalLoss(gamma=5, alpha=0.95)
loss_functions = {"edge": (edge_loss, 500.0)}

# No clustering functions are registered for this run.
clustering_functions = {}

trainer = TCNTrainer(
    model=model,
    loaders=loaders,
    loss_functions=loss_functions,
    lr=0.005,
    cluster_functions=clustering_functions,
)
trainer.train(epochs=2)

[32m[20:14:15TCNTrainer] INFO: Using device cuda[0m
[32m[20:14:15] INFO: Using device cuda[0m
[32m[20:14:15TCNTrainer] INFO: Epoch 1 (    0/500): Total=   0.76871, edge=   0.76871 (weighted)[0m
[32m[20:14:15] INFO: Epoch 1 (    0/500): Total=   0.76871, edge=   0.76871 (weighted)[0m
[32m[20:14:18TCNTrainer] INFO: Epoch 1 (   10/500): Total=   0.55160, edge=   0.55160 (weighted)[0m
[32m[20:14:18] INFO: Epoch 1 (   10/500): Total=   0.55160, edge=   0.55160 (weighted)[0m
[32m[20:14:19TCNTrainer] INFO: Epoch 1 (   20/500): Total=   0.52931, edge=   0.52931 (weighted)[0m
[32m[20:14:19] INFO: Epoch 1 (   20/500): Total=   0.52931, edge=   0.52931 (weighted)[0m
[32m[20:14:19TCNTrainer] INFO: Epoch 1 (   30/500): Total=   0.41746, edge=   0.41746 (weighted)[0m
[32m[20:14:19] INFO: Epoch 1 (   30/500): Total=   0.41746, edge=   0.41746 (weighted)[0m
[32m[20:14:19TCNTrainer] INFO: Epoch 1 (   40/500): Total=   0.45907, edge=   0.45907 (weighted)[0m
[32m[20:14:19] INFO: Ep

[32m[20:14:27TCNTrainer] INFO: Epoch 1 (  420/500): Total=   0.33317, edge=   0.33317 (weighted)[0m
[32m[20:14:27] INFO: Epoch 1 (  420/500): Total=   0.33317, edge=   0.33317 (weighted)[0m
[32m[20:14:27TCNTrainer] INFO: Epoch 1 (  430/500): Total=   0.34867, edge=   0.34867 (weighted)[0m
[32m[20:14:27] INFO: Epoch 1 (  430/500): Total=   0.34867, edge=   0.34867 (weighted)[0m
[32m[20:14:27TCNTrainer] INFO: Epoch 1 (  440/500): Total=   0.32580, edge=   0.32580 (weighted)[0m
[32m[20:14:27] INFO: Epoch 1 (  440/500): Total=   0.32580, edge=   0.32580 (weighted)[0m
[32m[20:14:27TCNTrainer] INFO: Epoch 1 (  450/500): Total=   0.34651, edge=   0.34651 (weighted)[0m
[32m[20:14:27] INFO: Epoch 1 (  450/500): Total=   0.34651, edge=   0.34651 (weighted)[0m
[32m[20:14:28TCNTrainer] INFO: Epoch 1 (  460/500): Total=   0.38244, edge=   0.38244 (weighted)[0m
[32m[20:14:28] INFO: Epoch 1 (  460/500): Total=   0.38244, edge=   0.38244 (weighted)[0m
[32m[20:14:28TCNTrainer] INFO

[32m[20:16:04TCNTrainer] INFO: Epoch 2 (    0/500): Total=   0.31109, edge=   0.31109 (weighted)[0m
[32m[20:16:04] INFO: Epoch 2 (    0/500): Total=   0.31109, edge=   0.31109 (weighted)[0m
[32m[20:16:04TCNTrainer] INFO: Epoch 2 (   10/500): Total=   0.38763, edge=   0.38763 (weighted)[0m
[32m[20:16:04] INFO: Epoch 2 (   10/500): Total=   0.38763, edge=   0.38763 (weighted)[0m
[32m[20:16:04TCNTrainer] INFO: Epoch 2 (   20/500): Total=   0.41895, edge=   0.41895 (weighted)[0m
[32m[20:16:04] INFO: Epoch 2 (   20/500): Total=   0.41895, edge=   0.41895 (weighted)[0m
[32m[20:16:04TCNTrainer] INFO: Epoch 2 (   30/500): Total=   0.42117, edge=   0.42117 (weighted)[0m
[32m[20:16:04] INFO: Epoch 2 (   30/500): Total=   0.42117, edge=   0.42117 (weighted)[0m
[32m[20:16:04TCNTrainer] INFO: Epoch 2 (   40/500): Total=   0.33819, edge=   0.33819 (weighted)[0m
[32m[20:16:04] INFO: Epoch 2 (   40/500): Total=   0.33819, edge=   0.33819 (weighted)[0m
[32m[20:16:05TCNTrainer] INFO

[32m[20:16:12] INFO: Epoch 2 (  420/500): Total=   0.34228, edge=   0.34228 (weighted)[0m
[32m[20:16:12TCNTrainer] INFO: Epoch 2 (  430/500): Total=   0.38258, edge=   0.38258 (weighted)[0m
[32m[20:16:12] INFO: Epoch 2 (  430/500): Total=   0.38258, edge=   0.38258 (weighted)[0m
[32m[20:16:12TCNTrainer] INFO: Epoch 2 (  440/500): Total=   0.30713, edge=   0.30713 (weighted)[0m
[32m[20:16:12] INFO: Epoch 2 (  440/500): Total=   0.30713, edge=   0.30713 (weighted)[0m
[32m[20:16:13TCNTrainer] INFO: Epoch 2 (  450/500): Total=   0.33832, edge=   0.33832 (weighted)[0m
[32m[20:16:13] INFO: Epoch 2 (  450/500): Total=   0.33832, edge=   0.33832 (weighted)[0m
[32m[20:16:13TCNTrainer] INFO: Epoch 2 (  460/500): Total=   0.33636, edge=   0.33636 (weighted)[0m
[32m[20:16:13] INFO: Epoch 2 (  460/500): Total=   0.33636, edge=   0.33636 (weighted)[0m
[32m[20:16:13TCNTrainer] INFO: Epoch 2 (  470/500): Total=   0.37093, edge=   0.37093 (weighted)[0m
[32m[20:16:13] INFO: Epoch 2 

[32m[20:17:48TCNTrainer] INFO: Saving checkpoint to 230421_201748_model.pt[0m
[32m[20:17:48] INFO: Saving checkpoint to 230421_201748_model.pt[0m
