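# Making old trained models compatible with PyTorch Lightning

Load the weights of models trained with the old (pre-Lightning) training code, wrap them in a Lightning module, and re-save them as Lightning checkpoints.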

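Original wandb runs of the models to convert:

- TC (track condensation): https://wandb.ai/gnn_tracking/gnn_tracking/runs/04b2e3ce/
- GC (graph construction): https://wandb.ai/gnn_tracking/gnn_tracking_gc/runs/7dce6aff
- EC (edge classification): https://wandb.ai/gnn_tracking/gnn_tracking_ec/runs/ddff435e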

In [1]:
import json

from gnn_tracking.models.graph_construction import GraphConstructionFCNN



In [30]:
from gnn_tracking.utils.loading import TrackingDataModule

# Root directory holding the `part_*/data*_s0.pt` graphs.
# NOTE: absolute cluster path -- adjust when running elsewhere.
POINT_CLOUD_DIR = "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5"

dm = TrackingDataModule(
    train=dict(
        dirs=[f"{POINT_CLOUD_DIR}/part_1/"],
        # If you run into memory issues, presumably pass `stop=N` here as well
        # to load only the first N graphs (as done for `val` below).
    ),
    val=dict(
        dirs=[f"{POINT_CLOUD_DIR}/part_9/"],
        # Load only the first 5 graphs; no real training happens in this
        # notebook, so validation data only feeds Lightning's sanity check.
        stop=5,
    ),
)

## GC (graph construction)

In [3]:
# Architecture of the original GC run. These values must match the checkpoint
# whose weights are loaded next, or the strict `load_state_dict` call fails.
model = GraphConstructionFCNN(
    in_dim=14,
    hidden_dim=512,
    out_dim=8,
    depth=6,
    beta=0.4,
)


In [14]:
import torch

# Checkpoint of the original GC trial (run id 7dce6aff), written by the old
# Ray-based training code.
ckpt_path = "/home/kl5675/ray_results/gc-hinge-sq-sq-cells/GCTrainable_7dce6aff_24_val_batch_size=1,adam_amsgrad=False,adam_beta1=0.9000,adam_beta2=0.9990,adam_eps=0.0000,adam_weight_decay=_2023-06-08_13-32-02/checkpoint_000009/checkpoint.pt"
# NOTE(review): torch.load unpickles the file -- only use it on trusted checkpoints.
ckpt = torch.load(ckpt_path, map_location="cpu")

# The current GraphConstructionFCNN's parameter names carry a leading "_" that
# the old checkpoint's keys lack, so prefix every key before loading.
state_dct = {}
for key, value in ckpt["model_state_dict"].items():
    state_dct[f"_{key}"] = value

# Strict loading (the default) reports whether every key lined up.
model.load_state_dict(state_dct)

<All keys matched successfully>

In [27]:
from functools import partial

import torch

from gnn_tracking.metrics.losses import GraphConstructionHingeEmbeddingLoss
from gnn_tracking.training.ml import MLModule

# Wrap the restored model in a Lightning module. The loss configuration and
# its weight end up in `lmodel.hparams` (displayed at the end of the notebook).
lmodel = MLModule(
    model=model,
    loss_fct=GraphConstructionHingeEmbeddingLoss(
        max_num_neighbors=256, r_emb=1, attr_pt_thld=0.9
    ),
    # Weight of the repulsive loss term. NOTE(review): presumably copied from
    # the config of the original GC run (7dce6aff) -- confirm there.
    lw_repulsive=0.001953029788887701,
    # The trainer below uses `max_epochs=0`, so the optimizer is never stepped
    # here. It presumably only matters if this checkpoint is trained further.
    optimizer=partial(torch.optim.Adam, lr=1e-3),
)

In [31]:
# Early check that the train/val directories resolve (see the log output).
# NOTE: `trainer.fit` runs `setup` again itself (the same log lines reappear
# there), so this cell is optional.
dm.setup(stage="fit")

[32m[18:18:54] INFO: DataLoader will load 900 graphs (out of 900 available).[0m
[36m[18:18:54] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/data21999_s0.pt[0m
[32m[18:18:54] INFO: DataLoader will load 5 graphs (out of 1000 available).[0m
[36m[18:18:54] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29004_s0.pt[0m


In [35]:
from pytorch_lightning import Trainer

# Zero epochs: `fit` only attaches the model and datamodule to the trainer
# (plus Lightning's sanity check) without taking any optimizer step. After
# that, `trainer.save_checkpoint` can write a Lightning-format checkpoint.
trainer = Trainer(
    accelerator="cpu",
    max_epochs=0,
)

  rank_zero_warn(
GPU available: False, used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [36]:
# No-op "training" run (max_epochs=0) whose only purpose is to connect
# `lmodel` and `dm` to the trainer before saving a checkpoint.
trainer.fit(model=lmodel, datamodule=dm)

  rank_zero_warn(
[32m[18:22:36] INFO: DataLoader will load 900 graphs (out of 900 available).[0m
[36m[18:22:36] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/data21999_s0.pt[0m
[32m[18:22:36] INFO: DataLoader will load 5 graphs (out of 1000 available).[0m
[36m[18:22:36] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29004_s0.pt[0m

  | Name     | Type                                | Params
-----------------------------------------------------------------
0 | model    | GraphConstructionFCNN               | 1.3 M 
1 | loss_fct | GraphConstructionHingeEmbeddingLoss | 0     
-----------------------------------------------------------------
1.3 M     Traina

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=0` reached.


In [37]:
# Write the converted model as a Lightning checkpoint. The file name reuses the
# id of the original GC wandb run (7dce6aff) for traceability.
trainer.save_checkpoint("gc-7dce6aff.ckpt")


In [39]:
# Hyperparameters recorded by the MLModule (class paths + init args for the
# model and loss). NOTE(review): presumably these are stored in the checkpoint
# above, so it can be reloaded without re-specifying the architecture -- verify.
lmodel.hparams

"loss_fct":     {'class_path': 'gnn_tracking.metrics.losses.GraphConstructionHingeEmbeddingLoss', 'init_args': {'r_emb': 1, 'max_num_neighbors': 256, 'attr_pt_thld': 0.9, 'p_attr': 1, 'p_rep': 1}}
"lw_repulsive": 0.001953029788887701
"model":        {'class_path': 'gnn_tracking.models.graph_construction.GraphConstructionFCNN', 'init_args': {'in_dim': 14, 'hidden_dim': 512, 'out_dim': 8, 'depth': 6, 'beta': 0.4}}
"preproc":      None