In [12]:
from gnn_tracking_hpo.util.paths import find_checkpoints, add_scripts_path

add_scripts_path()

In [6]:
from gnn_tracking_hpo.util.paths import get_config

checkpoint_path = find_checkpoints("ec-s9", "009d")[-1]
config = get_config("ec-s9", "009d")

In [9]:
config.update({"n_graphs_train": 1, "n_graphs_val": 100, "n_graphs_test": 1})

In [15]:
from tune_ec_sectorized import ECTrainable

In [16]:
trainable = ECTrainable(config)

[36mDEBUG: Got config
{'batch_size': 1,
 'focal_alpha': 0.7397820322968228,
 'focal_gamma': 4,
 'gnn_tracking_experiments_hash': '1c4385064cb8472d7070c92d38d2958ab96e7485',
 'gnn_tracking_hash': 'd1903e7319ef1dc27f2632d27212ce802d1273e2',
 'lr': 0.00010008424753725798,
 'lw_edge': 1.0,
 'm_L_ec': 7,
 'm_alpha_ec': 0.5082980468439962,
 'm_e_dim': 4,
 'm_h_dim': 5,
 'm_hidden_dim': 118,
 'm_interaction_edge_hidden_dim': 85,
 'm_interaction_node_hidden_dim': 94,
 'n_graphs_test': 1,
 'n_graphs_train': 1,
 'n_graphs_val': 100,
 'optimizer': 'adam',
 'scheduler': None,
 'sector': 9,
 'test': False,
 'training_pt_thld': 0.0,
 'training_without_noise': False,
 'training_without_non_reconstructable': False}[0m
[32mINFO: Loading data to cpu memory[0m
[32mINFO: Loading 102 graphs (out of 371 available).[0m
[36mDEBUG: Parameters for data loaders: {'batch_size': 1, 'num_workers': 12}[0m
[32mINFO: Using device cpu[0m


In [17]:
ec = trainable.trainer.model

In [82]:

from torch_geometric.data import Data
from torch import Tensor
from torch import nn
import torch

class TrainableThldEC(nn.Module):
    def __init__(self, ec: nn.Module):
        """Edge classifier with a trainable threshold based on an existing
        classifier. If the parameters of the existing classifier are fixed, the
        threshold is the only trainable parameter.

        Args:
            ec: Edge classifier
        """
        super().__init__()
        self.ec = ec
        #: The threshold to use for the edge classifier
        self.threshold = nn.parameter.Parameter(torch.tensor(0.5), requires_grad=True)

    def _evaluate_ec(self, data) -> Tensor:
        r = self.ec(data)
        if isinstance(r, dict):
            return r["W"]
        else:
            return r

    def forward(self, data: Data) -> Tensor:
        return self._evaluate_ec(data) > self.threshold  # type: ignore


In [83]:
for param in ec.parameters():
    param.requires_grad = False

In [84]:
ttec = TrainableThldEC(ec)

In [85]:
from gnn_tracking.models.track_condensation_networks import ModularGraphTCN
from gnn_tracking.models.resin import ResIN


class PreTrainedECGraphTCN(nn.Module):
    def __init__(
        self,
        ec,
        *,
        node_indim: int,
        edge_indim: int,
        interaction_node_hidden_dim=5,
        interaction_edge_hidden_dim=4,
        h_dim=5,
        e_dim=4,
        h_outdim=2,
        hidden_dim=40,
        L_hc=3,
        alpha_hc: float = 0.5,
    ):
        """
        """
        super().__init__()
        hc_in = ResIN.identical_in_layers(
            node_indim=h_dim,
            edge_indim=e_dim,
            node_hidden_dim=interaction_node_hidden_dim,
            edge_hidden_dim=interaction_edge_hidden_dim,
            node_outdim=h_dim,
            edge_outdim=e_dim,
            object_hidden_dim=hidden_dim,
            relational_hidden_dim=hidden_dim,
            alpha=alpha_hc,
            n_layers=L_hc,
        )
        self._gtcn = ModularGraphTCN(
            ec=ec,
            hc_in=hc_in,
            node_indim=node_indim,
            edge_indim=edge_indim,
            h_dim=h_dim,
            e_dim=e_dim,
            h_outdim=h_outdim,
            hidden_dim=hidden_dim,
            L_hc=L_hc,
            feed_edge_weights=False,
        )

    def forward(
        self,
        data: Data,
    ) -> dict[str, Tensor]:
        return self._gtcn.forward(data=data)


In [92]:
from gnn_tracking.utils.dictionaries import subdict_with_prefix_stripped
from typing import Any
from gnn_tracking_hpo.trainable import TCNTrainable


class PretrainedECTrainable(TCNTrainable):
    def get_loss_functions(self) -> dict[str, Any]:
        return {
            "potential": self.get_potential_loss_function(),
            "background": self.get_background_loss_function(),
        }

    def get_model(self) -> nn.Module:
        # todo: add config for ec
        return PreTrainedECGraphTCN(ttec,
            node_indim=6, edge_indim=4, **subdict_with_prefix_stripped(self.tc, "m_")
        )

In [93]:
from gnn_tracking_hpo.trainable import suggest_default_values
from gnn_tracking_hpo.config import get_metadata

config = {
    "lr": 0.0005655795153563859,
    "sb": 0.12120230680126508,
    "q_min": 0.3611768519294592,
    "m_L_hc": 3,
    "sector": 9,
    "m_e_dim": 5,
    "m_h_dim": 7,
    "optimizer": "adam",
    "scheduler": None,
    "batch_size": 1,
    "m_alpha_hc": .9,
    "m_h_outdim": 2,
    "attr_pt_thld": 0.5654455552047115,
    "m_hidden_dim": 116,
    "n_graphs_val": 69,
    "n_graphs_test": 1,
    "n_graphs_train": 300,
    "training_pt_thld": 0.9,
    "training_without_noise": True,
    "lw_potential_repulsive": 1e1,
    "lw_potential_attractive": 1e6,
    "lw_potential_background": 1e-2,
    "m_interaction_node_hidden_dim": 64,
    "m_interaction_edge_hidden_dim": 64,
    "repulsive_radius_threshold": 2.,
}
config.update(get_metadata())
suggest_default_values(config, None, perfect_ec=True)
del config["m_ec_tpr"]
del config["m_ec_tnr"]



In [94]:
pt = PretrainedECTrainable(config)

[36mDEBUG: Got config
{'attr_pt_thld': 0.5654455552047115,
 'batch_size': 1,
 'gnn_tracking_experiments_hash': '1c4385064cb8472d7070c92d38d2958ab96e7485',
 'gnn_tracking_hash': 'd1903e7319ef1dc27f2632d27212ce802d1273e2',
 'lr': 0.0005655795153563859,
 'lw_potential_attractive': 1000000.0,
 'lw_potential_background': 0.01,
 'lw_potential_repulsive': 10.0,
 'm_L_hc': 3,
 'm_alpha_hc': 0.9,
 'm_e_dim': 5,
 'm_h_dim': 7,
 'm_h_outdim': 2,
 'm_hidden_dim': 116,
 'm_interaction_edge_hidden_dim': 64,
 'm_interaction_node_hidden_dim': 64,
 'n_graphs_test': 1,
 'n_graphs_train': 300,
 'n_graphs_val': 69,
 'optimizer': 'adam',
 'q_min': 0.3611768519294592,
 'repulsive_radius_threshold': 2.0,
 'sb': 0.12120230680126508,
 'scheduler': None,
 'sector': 9,
 'test': False,
 'training_pt_thld': 0.9,
 'training_without_noise': True,
 'training_without_non_reconstructable': False}[0m
[32mINFO: Loading data to cpu memory[0m
[32mINFO: Loading 370 graphs (out of 371 available).[0m
[36mDEBUG: Paramet

In [95]:
for i in range(2):
    pt.step()

[32mINFO: Epoch  1 (    0/300): background_weighted=   0.56950, potential_attractive_weighted=   0.79266, potential_repulsive_weighted= 629.81876[0m
[32mINFO: Epoch  1 (   10/300): background_weighted=   0.63192, potential_attractive_weighted=   0.05775, potential_repulsive_weighted= 400.27641[0m
[32mINFO: Epoch  1 (   20/300): background_weighted=   0.69248, potential_attractive_weighted=   0.28210, potential_repulsive_weighted= 391.10447[0m
[32mINFO: Epoch  1 (   30/300): background_weighted=   0.77643, potential_attractive_weighted=   3.33007, potential_repulsive_weighted= 163.50866[0m
[32mINFO: Epoch  1 (   40/300): background_weighted=   0.88735, potential_attractive_weighted=   9.87966, potential_repulsive_weighted= 160.21511[0m
[32mINFO: Epoch  1 (   50/300): background_weighted=   0.95290, potential_attractive_weighted=   4.37769, potential_repulsive_weighted= 173.75721[0m
[32mINFO: Epoch  1 (   60/300): background_weighted=   0.96772, potential_attractive_weighted