In [None]:
import numpy as np
import torch

from gnn_tracking.utils.loading import TrackingDataset
from gnn_tracking.utils.loading import get_loaders
from gnn_tracking.metrics.losses import GraphConstructionHingeEmbeddingLoss
from gnn_tracking.models.graph_construction import GraphConstructionFCNN
from gnn_tracking_hpo.trainable import GCTrainer

%load_ext autoreload
%autoreload 2

In [11]:
# Seed torch's global RNG before building the model so that the randomly
# initialised weights are identical across kernel restarts; without this,
# comparing loss settings between runs mixes in initialisation noise.
torch.manual_seed(42)

model = GraphConstructionFCNN(
    in_dim=7,  # NOTE(review): presumably the number of per-hit input features in the point clouds — confirm
    hidden_dim=256,
    out_dim=12,  # dimension of the learned embedding space
    depth=6,
    beta=0.4,
)

In [12]:
from gnn_tracking_hpo.trainable import MetricLearningGraphConstruction

# model = MetricLearningGraphConstruction(
#     node_indim=7,
#     hidden_dim=256,
#     h_outdim=12,
#     L_gc=6,
# )

In [13]:


POINT_CLOUD_DIR = "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v3"

# Training uses parts 1-8; validation uses only the first 5 graphs (stop=5) of part 9.
train_dirs = [f"{POINT_CLOUD_DIR}/part_{part}" for part in range(1, 9)]
ds = TrackingDataset(train_dirs)
val_ds = TrackingDataset(f"{POINT_CLOUD_DIR}/part_9", stop=5)

loaders = get_loaders(
    {"train": ds, "val": val_ds},
    batch_size=1,
    max_sample_size=100,
)

[32m[15:49:14] INFO: DataLoader will load 7743 graphs (out of 7743 available).[0m
[36m[15:49:14] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v3/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v3/part_8/data28999_s0.pt[0m
[32m[15:49:14] INFO: DataLoader will load 5 graphs (out of 1000 available).[0m
[36m[15:49:14] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v3/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v3/part_9/data29004_s0.pt[0m
[36m[15:49:14] DEBUG: Parameters for data loader 'train': {'batch_size': 1, 'num_workers': 1, 'sampler': <torch.utils.data.sampler.RandomSampler object at 0x15214152e7a0>, 'pin_memory': True, 'shuffle': None}[0m
[36m[15:49:14] DEBUG: Parameters for data loader 'val': {'batch_size': 1, 'num_workers': 1, 'sampler': None, 'pin_memory': True, '

In [23]:

# One loss term, "potential": a hinge-embedding loss whose repulsive part is
# weighted 10x relative to its attractive part.
losses = dict(
    potential=(
        GraphConstructionHingeEmbeddingLoss(r_emb=1),
        dict(attractive=1.0, repulsive=10.0),
    ),
)

In [24]:

# Graph-construction trainer wired to the model, loaders and loss dict above.
trainer = GCTrainer(
    model=model,
    loaders=loaders,
    loss_functions=losses,
    lr=1e-3,
)

[32m[15:52:53 TCNTrainer] INFO: Using device cuda[0m


In [25]:
# NOTE(review): redundant directly after constructing the trainer with
# `loss_functions=losses`; presumably kept so an edited `losses` cell can be
# swapped in without rebuilding the trainer. Confirm the trainer constructor
# does not preprocess `loss_functions` in a way this bypasses.
trainer.loss_functions = losses

In [26]:
import collections


def get_loss_avg(trainer):
    """Average every validation loss component over ``trainer.val_loader``.

    Each validation batch is moved to ``trainer.device``, passed through
    ``trainer.evaluate_model(..., mask_pids_reco=False)``, and scored with
    ``trainer.get_batch_losses``. Evaluation runs under ``torch.no_grad()``.
    If the trainer exposes a ``torch.nn.Module`` as ``trainer.model``, it is
    put in eval mode for the duration and restored to training mode afterwards
    if it was training before.

    Note: the per-component entries are the second element returned by
    ``get_batch_losses``, while ``"total"`` is its first element. In this
    notebook's output the total equals attractive + 10 * repulsive, i.e. the
    components are reported unweighted but the total is weighted.

    Args:
        trainer: a GCTrainer-like object providing ``val_loader``, ``device``,
            ``evaluate_model`` and ``get_batch_losses``.

    Returns:
        dict mapping each loss name (plus ``"total"``) to its mean over the
        validation batches, as plain Python floats. The dict is also printed.

    Raises:
        AssertionError: if the trainer has no validation loader.
    """
    loader = trainer.val_loader
    assert loader is not None

    # Switch to eval mode (dropout/batch-norm, if any) only when we can find the model.
    model = getattr(trainer, "model", None)
    is_module = isinstance(model, torch.nn.Module)
    was_training = is_module and model.training
    if is_module:
        model.eval()

    collected = collections.defaultdict(list)
    try:
        # Validation needs no gradients; building the autograd graph only costs memory.
        with torch.no_grad():
            for data in loader:
                data = data.to(trainer.device)  # noqa: PLW2901
                model_output = trainer.evaluate_model(data, mask_pids_reco=False)
                batch_loss, batch_losses, _loss_weights = trainer.get_batch_losses(
                    model_output
                )
                for key, value in batch_losses.items():
                    collected[key].append(value.item())
                collected["total"].append(batch_loss.item())
    finally:
        # Leave the model in the mode we found it in, even if evaluation fails.
        if was_training:
            model.train()

    averages = {key: float(np.mean(values)) for key, values in collected.items()}
    print(averages)
    return averages

In [27]:
get_loss_avg(trainer=trainer)

  storage = elem.storage()._new_shared(numel)


{'potential_attractive': 0.0036825352348387242, 'potential_repulsive': 126.6882110595703, 'total': 1266.8857788085938}


In [28]:
# Call the trainer's train_step 100 times in a row; interrupt the kernel to stop early.
for _step in range(100):
    trainer.train_step()

  storage = elem.storage()._new_shared(numel)
[36m[15:52:57 TCNTrainer] DEBUG: Epoch 0 (    0/100): Total= 818.31104, potential_attractive=   0.00231, potential_repulsive= 818.30872 (weighted)[0m
[36m[15:52:58 TCNTrainer] DEBUG: Epoch 0 (   10/100): Total= 418.91165, potential_attractive=   0.12046, potential_repulsive= 418.79120 (weighted)[0m
[36m[15:52:59 TCNTrainer] DEBUG: Epoch 0 (   20/100): Total= 206.15721, potential_attractive=   0.16402, potential_repulsive= 205.99319 (weighted)[0m
[36m[15:53:00 TCNTrainer] DEBUG: Epoch 0 (   30/100): Total= 203.25117, potential_attractive=   0.45396, potential_repulsive= 202.79721 (weighted)[0m
[36m[15:53:01 TCNTrainer] DEBUG: Epoch 0 (   40/100): Total= 108.69900, potential_attractive=   0.76063, potential_repulsive= 107.93837 (weighted)[0m
[36m[15:53:02 TCNTrainer] DEBUG: Epoch 0 (   50/100): Total=  56.11842, potential_attractive=   1.18358, potential_repulsive=  54.93484 (weighted)[0m
[36m[15:53:03 TCNTrainer] DEBUG: Epoch 0 

KeyboardInterrupt: 