In [0]:
import torch
from gnn_tracking.models.resin import ResIN
from gnn_tracking.models.mlp import MLP
from torch_geometric.data import Data
from torch import nn, Tensor

In [10]:
class GNNEmbedding(nn.Module):
    def __init__(
        self,
        *,
        node_indim: int,
        edge_indim: int,
        interaction_node_dim: int = 5,
        interaction_edge_dim: int = 4,
        out_dim: int,
        hidden_dim: int | float = None,
        L_ec: int = 3,
        alpha: float = 0.5,
        residual_type="skip1",
        residual_kwargs: dict | None = None,
    ):
        """

        Args:
            node_indim: Node feature dim
            edge_indim: Edge feature dim
            interaction_node_dim: Node dimension for interaction networks.
                Defaults to 5 for backward compatibility, but this is probably
                not reasonable.
            interaction_edge_dim: Edge dimension of interaction networks
                Defaults to 4 for backward compatibility, but this is probably
                not reasonable.
            hidden_dim: width of hidden layers in all perceptrons (edge and node
                encoders, hidden dims for MLPs in object and relation networks). If
                None: choose as maximum of input/output dims for each MLP separately
            L_ec: message passing depth for edge classifier
            alpha: strength of residual connection for EC
            residual_type: type of residual connection for EC
            residual_kwargs: Keyword arguments passed to `ResIN`
        """
        super().__init__()
        if residual_kwargs is None:
            residual_kwargs = {}
        residual_kwargs["collect_hidden_edge_embeds"] = False
        self.relu = nn.ReLU()
        self.node_indim = node_indim
        self.edge_indim = edge_indim
        self.ec_node_encoder = MLP(
            node_indim, interaction_node_dim, hidden_dim=hidden_dim, L=2, bias=False
        )
        self.ec_edge_encoder = MLP(
            edge_indim, interaction_edge_dim, hidden_dim=hidden_dim, L=2, bias=False
        )
        self.ec_resin = ResIN(
            node_dim=interaction_node_dim,
            edge_dim=interaction_edge_dim,
            object_hidden_dim=hidden_dim,
            relational_hidden_dim=hidden_dim,
            alpha=alpha,
            n_layers=L_ec,
            residual_type=residual_type,
            residual_kwargs=residual_kwargs,
        )
        self.out_dim = out_dim
        self.latent_decoder = MLP(input_size=interaction_edge_dim, output_size=out_dim, hidden_dim=hidden_dim, L=3)

    def forward(
        self,
        data: Data,
    ) -> dict[str, Tensor]:
        """
        """
        x, edge_index, edge_attr = data.x, data.edge_index, data.edge_attr
        assert x.shape[1] == self.node_indim, x.shape
        assert edge_attr.shape[1] == self.edge_indim, edge_attr.shape
        h_ec = self.relu(self.ec_node_encoder(x))
        edge_attr_ec = self.relu(self.ec_edge_encoder(edge_attr))
        h_ec, _, _ = self.ec_resin(
            h_ec, edge_index, edge_attr_ec
        )
        latent = self.latent_decoder(h_ec)
        return {
            "H": latent
        }


In [4]:
from gnn_tracking_hpo.restore import restore_model
from gnn_tracking_hpo.trainable import GCTrainable

# Pre-trained graph-construction (GC) model from run 7dce6aff of the
# "gc-hinge-sq-sq-cells" Ray Tune experiment, restored with restore_model's
# default `freeze` setting (the next cell restores the same run with
# freeze=False explicitly).
# NOTE(review): `ml` is not used by any of the following visible cells; if
# nothing else needs it, this cell only restores the same checkpoint twice.
ml = restore_model(
    GCTrainable,
    tune_dir="gc-hinge-sq-sq-cells",
    run_hash="7dce6aff",
)

[32m[15:41:07 HPO] INFO: Initializing pre-trained model[0m
[36m[15:41:07 HPO] DEBUG: Loading config from /home/kl5675/ray_results/gc-hinge-sq-sq-cells/GCTrainable_7dce6aff_24_val_batch_size=1,adam_amsgrad=False,adam_beta1=0.9000,adam_beta2=0.9990,adam_eps=0.0000,adam_weight_decay=_2023-06-08_13-32-02/params.json[0m
[32m[15:41:07 HPO] INFO: I'm running on a node with job ID=48416495[0m
[32m[15:41:07 HPO] INFO: The ID of my dispatcher is 0[0m
[36m[15:41:07 SlurmControl] DEBUG: Refreshing control config from /home/kl5675/ray_slurm_control.yaml[0m
[36m[15:41:07 HPO] DEBUG: Got config
┌───────────────────────────────┬──────────────────────────────────────────┐
│ _no_data                      │ True                                     │
│ _val_batch_size               │ 1                                        │
│ adam_amsgrad                  │ False                                    │
│ adam_beta1                    │ 0.9                                      │
│ adam_beta2     

In [29]:
# Same pre-trained GC run as `ml`, but restored with freeze=False so its
# parameters remain trainable (it is fine-tuned by `ml_trainer` further down).
ml_trainable = restore_model(
    GCTrainable,
    tune_dir="gc-hinge-sq-sq-cells",
    run_hash="7dce6aff",
    freeze=False,
)

[32m[15:58:51 HPO] INFO: Initializing pre-trained model[0m
[36m[15:58:51 HPO] DEBUG: Loading config from /home/kl5675/ray_results/gc-hinge-sq-sq-cells/GCTrainable_7dce6aff_24_val_batch_size=1,adam_amsgrad=False,adam_beta1=0.9000,adam_beta2=0.9990,adam_eps=0.0000,adam_weight_decay=_2023-06-08_13-32-02/params.json[0m
[32m[15:58:51 HPO] INFO: I'm running on a node with job ID=48416495[0m
[32m[15:58:51 HPO] INFO: The ID of my dispatcher is 0[0m
[36m[15:58:51 SlurmControl] DEBUG: Refreshing control config from /home/kl5675/ray_slurm_control.yaml[0m
[36m[15:58:51 HPO] DEBUG: Got config
┌───────────────────────────────┬──────────────────────────────────────────┐
│ _no_data                      │ True                                     │
│ _val_batch_size               │ 1                                        │
│ adam_amsgrad                  │ False                                    │
│ adam_beta1                    │ 0.9                                      │
│ adam_beta2     

In [33]:
from gnn_tracking.models.graph_construction import MLGraphConstruction
from gnn_tracking.training.tcn_trainer import TCNTrainer

# Graph construction on the fly from the learned embedding of `ml_trainable`
# (presumably a radius graph: neighbours within max_radius, at most
# max_num_neighbors each).
# NOTE(review): judging by the execution counts, `ml_trainable` had presumably
# already been fine-tuned in place by `ml_trainer.train(1)` (In [32]) when this
# cell ran (In [33]), so the constructed graphs depend on execution order.
# Consider the frozen `ml` if a fixed graph-construction model is intended.
# NOTE(review): max_num_neighbors is 64 here but 65 in the hinge loss below —
# confirm whether the mismatch is intentional.
gc = MLGraphConstruction(
    ml=ml_trainable,
    max_radius=0.8,
    max_num_neighbors=64,
    use_embedding_features=True,
    build_edge_features=True,
)


class MyTCNTrainer(TCNTrainer):
    """TCNTrainer that replaces every input graph by the output of the
    module-level ``gc`` graph-construction step before it reaches the model."""

    def data_preproc(self, data: Data) -> Data:
        constructed = gc(data)
        return constructed

In [34]:
from gnn_tracking.utils.loading import TrackingDataset, get_loaders

# NOTE(review): absolute cluster path; consider making it configurable.
point_cloud_dir = "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5"

# Training uses parts 1-8; validation uses only the first 5 graphs of part 9.
ds = TrackingDataset([f"{point_cloud_dir}/part_{i}" for i in range(1, 9)])
val_ds = TrackingDataset(f"{point_cloud_dir}/part_9", stop=5)
# max_sample_size=100 appears to cap the number of training batches per epoch
# (the training logs count "(.../100)") — TODO confirm.
loaders = get_loaders({"train": ds, "val": val_ds}, batch_size=1, max_sample_size=100)

[32m[16:02:04] INFO: DataLoader will load 7743 graphs (out of 7743 available).[0m
[36m[16:02:04] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_8/data28999_s0.pt[0m
[32m[16:02:04] INFO: DataLoader will load 5 graphs (out of 1000 available).[0m
[36m[16:02:04] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29004_s0.pt[0m
[36m[16:02:04] DEBUG: Parameters for data loader 'train': {'batch_size': 1, 'num_workers': 1, 'sampler': <torch.utils.data.sampler.RandomSampler object at 0x14d04da7dde0>, 'pin_memory': True, 'shuffle': None}[0m
[36m[16:02:04] DEBUG: Parameters for data loader 'val': {'batch_size': 1, 'num_workers': 1, 'sampler': None, 'pin_memory': True, '

In [35]:
from gnn_tracking.metrics.losses import GraphConstructionHingeEmbeddingLoss

# loss name -> (loss function, weight of each of its components); the logs
# report the weighted "potential_attractive" / "potential_repulsive" terms.
# attr_pt_thld presumably restricts the attractive term to particles above a
# pT threshold — TODO confirm units.
# NOTE(review): max_num_neighbors is 65 here but 64 in `gc` — confirm whether
# the mismatch is intentional.
losses = {
    "potential": (
        GraphConstructionHingeEmbeddingLoss(
            r_emb=1,
            max_num_neighbors=65,
            attr_pt_thld=0.9,
            p_attr=2,
            p_rep=2,
        ),
        {"attractive": 1, "repulsive": 1e-4},
    ),
}

In [36]:
# Node feature width 14 + 8: presumably the raw hit features plus the GC
# embedding coordinates appended by `gc` (use_embedding_features=True) — TODO
# confirm. Edge features are twice that width, presumably the concatenated
# endpoint features from build_edge_features=True — TODO confirm.
n_node_features = 14 + 8
trainer = MyTCNTrainer(
    model=GNNEmbedding(
        node_indim=n_node_features,
        edge_indim=2 * n_node_features,
        interaction_node_dim=128,
        interaction_edge_dim=128,
        hidden_dim=128,
        L_ec=3,
        alpha=0.35,
        out_dim=8,
    ),
    loss_functions=losses,
    loaders=loaders,
)

[32m[16:02:06 TCNTrainer] INFO: Using device cuda[0m


In [37]:
# Train the embedding GNN for one epoch ("Epoch 1" in the log). The recorded run
# was stopped manually after the first checkpoint (KeyboardInterrupt below).
trainer.train(1)

  storage = elem.storage()._new_shared(numel)
[36m[16:02:09 TCNTrainer] DEBUG: Epoch 1 (    0/100): Total=   0.02466, potential_attractive=   0.01731, potential_repulsive=   0.00735 (weighted)[0m
[32m[16:02:19 TCNTrainer] INFO: Saving checkpoint to 230611_160219_model.pt[0m


KeyboardInterrupt: 

In [30]:
# Fine-tune the unfrozen pre-trained GC model directly on the same hinge loss.
# NOTE(review): this trains `ml_trainable` in place — the same object `gc`
# uses for graph construction — so the MyTCNTrainer results depend on whether
# this trainer ran first.
ml_trainer = TCNTrainer(model=ml_trainable, loss_functions=losses, loaders=loaders)

[32m[15:58:56 TCNTrainer] INFO: Using device cuda[0m


In [32]:
# One epoch of fine-tuning for the pre-trained GC model; this updates
# `ml_trainable` (and therefore `gc`) in place.
ml_trainer.train(1)

  storage = elem.storage()._new_shared(numel)
[36m[15:59:01 TCNTrainer] DEBUG: Epoch 1 (    0/100): Total=   0.11372, potential_attractive=   0.11006, potential_repulsive=   0.00366 (weighted)[0m
[36m[15:59:02 TCNTrainer] DEBUG: Epoch 1 (   10/100): Total=   0.18861, potential_attractive=   0.18473, potential_repulsive=   0.00388 (weighted)[0m
[36m[15:59:03 TCNTrainer] DEBUG: Epoch 1 (   20/100): Total=   0.04259, potential_attractive=   0.03918, potential_repulsive=   0.00342 (weighted)[0m
[36m[15:59:04 TCNTrainer] DEBUG: Epoch 1 (   30/100): Total=   0.02023, potential_attractive=   0.01734, potential_repulsive=   0.00290 (weighted)[0m
[36m[15:59:05 TCNTrainer] DEBUG: Epoch 1 (   40/100): Total=   0.01403, potential_attractive=   0.01095, potential_repulsive=   0.00308 (weighted)[0m
[36m[15:59:06 TCNTrainer] DEBUG: Epoch 1 (   50/100): Total=   0.01403, potential_attractive=   0.01075, potential_repulsive=   0.00328 (weighted)[0m
[36m[15:59:06 TCNTrainer] DEBUG: Epoch 1 