In [1]:
from gnn_tracking_hpo.util.paths import find_checkpoints, add_scripts_path

add_scripts_path()

In [2]:
from gnn_tracking_hpo.util.paths import get_config

checkpoint_path = find_checkpoints("ec-s9", "009d")[-1]
config = get_config("ec-s9", "009d")

In [3]:
config.update({"n_graphs_train": 1, "n_graphs_val": 100, "n_graphs_test": 1})

In [4]:
from tune_ec_sectorized import ECTrainable

In [5]:
trainable = ECTrainable(config)

[36mDEBUG: Got config
{'batch_size': 1,
 'focal_alpha': 0.7397820322968228,
 'focal_gamma': 4,
 'gnn_tracking_experiments_hash': '1c4385064cb8472d7070c92d38d2958ab96e7485',
 'gnn_tracking_hash': 'd1903e7319ef1dc27f2632d27212ce802d1273e2',
 'lr': 0.00010008424753725798,
 'lw_edge': 1.0,
 'm_L_ec': 7,
 'm_alpha_ec': 0.5082980468439962,
 'm_e_dim': 4,
 'm_h_dim': 5,
 'm_hidden_dim': 118,
 'm_interaction_edge_hidden_dim': 85,
 'm_interaction_node_hidden_dim': 94,
 'n_graphs_test': 1,
 'n_graphs_train': 1,
 'n_graphs_val': 100,
 'optimizer': 'adam',
 'scheduler': None,
 'sector': 9,
 'test': False,
 'training_pt_thld': 0.0,
 'training_without_noise': False,
 'training_without_non_reconstructable': False}[0m
[32mINFO: Loading data to cpu memory[0m
[32mINFO: Loading 102 graphs (out of 371 available).[0m
[36mDEBUG: Parameters for data loaders: {'batch_size': 1, 'num_workers': 12}[0m
[32mINFO: Using device cpu[0m


In [6]:
trainable.load_checkpoint(checkpoint_path, device="cpu")

In [7]:
ec = trainable.trainer.model

In [8]:
from gnn_tracking.models.edge_classifier import TrainableThldEC


In [9]:
for param in ec.parameters():
    param.requires_grad = False

In [20]:
from torch import Tensor
import torch.nn as nn
import torch
from torch_geometric.data import Data

class TrainableThldEC(nn.Module):
    def __init__(self, ec: nn.Module):
        """Edge classifier with a trainable threshold based on an existing
        classifier. If the parameters of the existing classifier are fixed, the
        threshold is the only trainable parameter.

        Args:
            ec: Edge classifier
        """
        super().__init__()
        self.ec = ec
        #: The threshold to use for the edge classifier
        self.threshold = nn.parameter.Parameter(torch.tensor(0.5), requires_grad=True)

    def _evaluate_ec(self, data) -> Tensor:
        r = self.ec(data)
        if isinstance(r, dict):
            return r["W"]
        else:
            return r

    def forward(self, data: Data) -> Tensor:
        # v = torch.minimum(self._evaluate_ec(data) + self.threshold, torch.tensor([1.])).float()  # type: ignore
        v = (self._evaluate_ec(data) > self.threshold).float()  # type: ignore
        # print(v, type(v), v.dtype)
        return v


In [21]:
ttec = TrainableThldEC(ec)

In [22]:
from gnn_tracking.metrics.losses import EdgeWeightBCELoss
from gnn_tracking.models.track_condensation_networks import PreTrainedECGraphTCN
from gnn_tracking.utils.dictionaries import subdict_with_prefix_stripped
from typing import Any
from gnn_tracking_hpo.trainable import TCNTrainable
from torch import nn


class PretrainedECTrainable(TCNTrainable):
    def __init__(self, config: dict[str, Any], ec: nn.Module, **kwargs):
        self.ec = ec
        super().__init__(config=config, **kwargs)

    def get_loss_functions(self) -> dict[str, Any]:
        return {
            "potential": self.get_potential_loss_function(),
            "background": self.get_background_loss_function(),
            "edge": self.get_edge_loss_function(),
        }

    def get_edge_loss_function(self):
        return EdgeWeightBCELoss()

    def get_cluster_functions(self) -> dict[str, Any]:
        return {}

    def get_model(self) -> nn.Module:
        # todo: add config for ec
        return PreTrainedECGraphTCN(self.ec,
            node_indim=6, edge_indim=4, **subdict_with_prefix_stripped(self.tc, "m_")
        )

In [23]:
from gnn_tracking_hpo.trainable import suggest_default_values
from gnn_tracking_hpo.config import get_metadata

config = {
    "lr": 0.0005655795153563859,
    "sb": 0.12120230680126508,
    "q_min": 0.3611768519294592,
    "m_L_hc": 3,
    "sector": 9,
    "m_e_dim": 5,
    "m_h_dim": 7,
    "optimizer": "adam",
    "scheduler": None,
    "batch_size": 1,
    "m_alpha_hc": .9,
    "m_h_outdim": 2,
    "attr_pt_thld": 0.5654455552047115,
    "m_hidden_dim": 116,
    "n_graphs_val": 69,
    "n_graphs_test": 1,
    "n_graphs_train": 300,
    "training_pt_thld": 0.9,
    "training_without_noise": True,
    "lw_potential_repulsive": 1e1,
    "lw_potential_attractive": 1e6,
    "lw_potential_background": 1e-2,
    "m_interaction_node_hidden_dim": 64,
    "m_interaction_edge_hidden_dim": 64,
    "repulsive_radius_threshold": 2.,
}
config.update(get_metadata())
suggest_default_values(config, None, perfect_ec=True)
# del config["m_L_ec"]
# del config["m_alpha_ec"]
# del config["m_feed_edge_weights"]
del config["m_ec_tpr"]
del config["m_ec_tnr"]

In [24]:
pt = PretrainedECTrainable(config, ec=ttec)

[36mDEBUG: Got config
{'attr_pt_thld': 0.5654455552047115,
 'batch_size': 1,
 'gnn_tracking_experiments_hash': '782c32667357aed7156545ac0fb9ea63aebdcd9c',
 'gnn_tracking_hash': '45ed36884f494d32894adb25e9d437910f1276da',
 'lr': 0.0005655795153563859,
 'lw_potential_attractive': 1000000.0,
 'lw_potential_background': 0.01,
 'lw_potential_repulsive': 10.0,
 'm_L_hc': 3,
 'm_alpha_hc': 0.9,
 'm_e_dim': 5,
 'm_h_dim': 7,
 'm_h_outdim': 2,
 'm_hidden_dim': 116,
 'm_interaction_edge_hidden_dim': 64,
 'm_interaction_node_hidden_dim': 64,
 'n_graphs_test': 1,
 'n_graphs_train': 300,
 'n_graphs_val': 69,
 'optimizer': 'adam',
 'q_min': 0.3611768519294592,
 'repulsive_radius_threshold': 2.0,
 'sb': 0.12120230680126508,
 'scheduler': None,
 'sector': 9,
 'test': False,
 'training_pt_thld': 0.9,
 'training_without_noise': True,
 'training_without_non_reconstructable': False}[0m
[32mINFO: Loading data to cpu memory[0m
[32mINFO: Loading 370 graphs (out of 371 available).[0m
[36mDEBUG: Paramet

In [29]:
pt.ec.threshold

Parameter containing:
tensor(0.5000, requires_grad=True)

In [26]:
# for p in pt.trainer.model.parameters():
#     try:
#         l = len(p)
#     except TypeError:
#         print(p)
#         continue
#     if l == 1:
#         print(p)

In [27]:
# import torch
#
# trainer = pt.trainer
# sel = 13
# with torch.no_grad():
#     loader = trainer.val_loader
#     for idx, data in enumerate(loader):
#         if idx < sel:
#             continue
#         model_output = trainer.evaluate_model(data, mask_pids_reco=False)
#         if idx == sel:
#             break
#
# mo = pt.trainer.evaluate_model(data)

In [28]:
for i in range(1):
    pt.trainer.train_step(max_batches=None)

[32mINFO: Epoch  0 (    0/300): background_weighted=   0.56950, edge_weighted=   4.12262, potential_attractive_weighted=   0.78956, potential_repulsive_weighted= 629.80835[0m
[32mINFO: Epoch  0 (   10/300): background_weighted=   0.63197, edge_weighted=   4.88311, potential_attractive_weighted=   0.05217, potential_repulsive_weighted= 400.09098[0m
[32mINFO: Epoch  0 (   20/300): background_weighted=   0.69406, edge_weighted=   3.98551, potential_attractive_weighted=   0.48126, potential_repulsive_weighted= 387.51053[0m
[32mINFO: Epoch  0 (   30/300): background_weighted=   0.78658, edge_weighted=   5.60773, potential_attractive_weighted=   4.15615, potential_repulsive_weighted= 157.75028[0m
[32mINFO: Epoch  0 (   40/300): background_weighted=   0.87497, edge_weighted=   4.43476, potential_attractive_weighted=   7.98965, potential_repulsive_weighted= 162.86625[0m
[32mINFO: Epoch  0 (   50/300): background_weighted=   0.93234, edge_weighted=   5.09066, potential_attractive_wei