# Edge Classifier Development

In [None]:
from pathlib import Path

from gnn_tracking.postprocessing.dbscanscanner import dbscan_scan

from gnn_tracking.models.track_condensation_networks import GraphTCN
from gnn_tracking.training.tcn_trainer import TCNTrainer
from gnn_tracking.metrics.losses import (
    EdgeWeightFocalLoss,
    PotentialLoss,
    BackgroundLoss,
)
from gnn_tracking.utils.loading import get_loaders, TrackingDataset

In [None]:
# All graph partitions share one parent directory on the cluster scratch
# filesystem; part_4 feeds the "train" dataset and part_5 the "val" dataset.
graphs_dir = Path("/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/graphs_v1")
train_dir = graphs_dir / "part_4"
test_dir = graphs_dir / "part_5"

# Explicit check instead of `assert`, which is silently skipped under
# `python -O`; the message also says which directory is missing.
for _split_dir in (train_dir, test_dir):
    if not _split_dir.is_dir():
        raise NotADirectoryError(f"Graph directory not found: {_split_dir}")

In [None]:
# Training graphs come from train_dir, validation graphs from test_dir;
# `stop` presumably caps how many graphs are loaded from each — TODO confirm.
datasets = dict(
    train=TrackingDataset(train_dir, stop=500),
    val=TrackingDataset(test_dir, stop=100),
)
loaders = get_loaders(datasets)

In [None]:
import torch.nn
from torch.nn import Module, Sequential, Linear, ReLU, Sigmoid


class EC(Module):
    """Minimal edge classifier: an MLP that scores each edge from its features.

    Only ``data.edge_attr`` is fed through the network; node features and
    connectivity are ignored.

    Args:
        node_in_dim: Node feature dimension. Stored on the instance but not
            used by the network.
        edge_in_dim: Number of features per edge (last dim of ``edge_attr``).
        hidden_dim: Width of the two hidden layers.
    """

    def __init__(self, node_in_dim: int, edge_in_dim: int, hidden_dim: int):
        super().__init__()
        self.node_in_dim = node_in_dim
        self.edge_in_dim = edge_in_dim
        self.hidden_dim = hidden_dim

        # Two hidden ReLU layers, then a single sigmoid unit per edge.
        self.fcnn = Sequential(
            Linear(edge_in_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, hidden_dim),
            ReLU(),
            Linear(hidden_dim, 1),
            Sigmoid(),
        )

    def forward(self, data):
        """Compute a sigmoid score for every edge.

        Args:
            data: Graph object exposing ``edge_attr`` whose last dimension is
                ``edge_in_dim`` (assumed ``(n_edges, edge_in_dim)``, e.g. a
                PyG batch — TODO confirm).

        Returns:
            Dict with key ``"W"``: per-edge scores with the trailing size-1
            output dimension removed, i.e. shape ``(n_edges,)``.
        """
        # squeeze(-1) drops only the trailing output dimension; a bare
        # squeeze() would also collapse the edge dimension when n_edges == 1.
        return {"W": self.fcnn(data.edge_attr).squeeze(-1)}


# 7 node features and 4 edge features per graph (assumed to match the
# stored graphs — TODO confirm), 128 hidden units.
model = EC(node_in_dim=7, edge_in_dim=4, hidden_dim=128)

# Smoke test: push one training batch through the untrained model.
first_batch = next(iter(loaders["train"]))
model(first_batch)

In [None]:
# Only the edge-weight focal loss is used. Each value appears to be a
# (loss_fn, weight) pair — TODO confirm against TCNTrainer.
loss_functions = dict(
    edge=(EdgeWeightFocalLoss(gamma=5, alpha=0.95), 500.0),
)

# Intentionally empty: no cluster functions are handed to the trainer.
clustering_functions = dict()


class MyTCNTrainer(TCNTrainer):
    """TCNTrainer with ``_log_losses`` disabled and a narrow printed-results filter."""

    def _log_losses(self, *args, **kwargs):
        """No-op override; presumably silences the base class's loss logging."""

    def printed_results_filter(self, key):
        """Return True only for the two whitelisted metric keys."""
        wanted = ("max_mcc_pt0.9", "tpr_eq_tnr_pt0.9")
        return key in wanted


# Run hyperparameters.
learning_rate = 0.005
n_epochs = 10

trainer = MyTCNTrainer(
    model=model,
    loaders=loaders,
    loss_functions=loss_functions,
    cluster_functions=clustering_functions,
    lr=learning_rate,
)
trainer.train(epochs=n_epochs)