# Making old trained models compatible with PyTorch Lightning

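- TC (track condensation): https://wandb.ai/gnn_tracking/gnn_tracking/runs/04b2e3ce/
- GC (graph construction): https://wandb.ai/gnn_tracking/gnn_tracking_gc/runs/7dce6aff
- EC (edge classifier): https://wandb.ai/gnn_tracking/gnn_tracking_ec/runs/ddff435e

For each model, the old (pre-Lightning) Ray Tune checkpoint is loaded into the current model class and wrapped in the corresponding Lightning module. It is then re-saved as a Lightning checkpoint under `model_exchange/` and restored again as a check.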

In [4]:
%load_ext autoreload
%autoreload 2

from gnn_tracking.models.graph_construction import GraphConstructionFCNN
from gnn_tracking.utils.loading import TrackingDataModule
import torch
from gnn_tracking.training.ml import MLModule
from functools import partial
from gnn_tracking.metrics.losses import GraphConstructionHingeEmbeddingLoss, \
    HaughtyFocalLoss
from pytorch_lightning import Trainer
from gnn_tracking.models.graph_construction import MLGraphConstruction
from gnn_tracking.training.ec import ECModule
from gnn_tracking.models.edge_classifier import ec_from_chkpt
import pprint
from gnn_tracking.models.graph_construction import MLGraphConstruction
from gnn_tracking.training.tc import TCModule

In [5]:
# Training data: all graphs in part_1; validation: only the first graph of part_9
# (`stop=1`).
dm = TrackingDataModule(
    train={
        "dirs": [
            "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/"
        ],
        # NOTE(review): the original hint "If you run into memory issues, reduce
        # this" refers to a setting that is no longer present here (presumably a
        # `stop`/batch-size option) — confirm what should be reduced.
    },
    val={
        "dirs": [
            "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/"
        ],
        "stop": 1,
    },
)

In [6]:
# Build the train/val datasets now (the log reports how many graphs were found);
# a later cell reads `dm.datasets["val"]` directly.
dm.setup(stage="fit")

[32m[17:03:12] INFO: DataLoader will load 900 graphs (out of 900 available).[0m
[36m[17:03:12] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/data21999_s0.pt[0m
[32m[17:03:12] INFO: DataLoader will load 1 graphs (out of 1000 available).[0m
[36m[17:03:12] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29000_s0.pt[0m


In [7]:
from pathlib import Path

# Shared directory for the converted Lightning checkpoints (later cells use its
# gc/, ec/ and tc/ sub-directories).
model_exchange_path = (
    Path("/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation") / "model_exchange"
)

## ML (graph construction, GC)

In [5]:
# Graph-construction network. The hyperparameters must match the old checkpoint
# loaded in the next cell, otherwise `load_state_dict` fails on the shapes.
# NOTE(review): the name `gc` shadows the stdlib `gc` module if that is ever imported.
gc = GraphConstructionFCNN(
    in_dim=14,
    hidden_dim=512,
    out_dim=8,
    depth=6,
    beta=0.4,
)


In [10]:

# Old Ray Tune checkpoint of GC run 7dce6aff (trusted, self-produced file:
# torch.load unpickles it).
ckpt = torch.load(
    "/home/kl5675/ray_results/gc-hinge-sq-sq-cells/GCTrainable_7dce6aff_24_val_batch_size=1,adam_amsgrad=False,adam_beta1=0.9000,adam_beta2=0.9990,adam_eps=0.0000,adam_weight_decay=_2023-06-08_13-32-02/checkpoint_000009/checkpoint.pt",
    map_location="cpu",
)
# The old keys lack the leading underscore of the current GraphConstructionFCNN
# attribute names (`_encoder`, `_decoder`, `_layers`), so prepend it.
state_dct = {
    "_" + key: value for key, value in ckpt["model_state_dict"].items()
}
gc.load_state_dict(state_dct)

<All keys matched successfully>

In [11]:

# Wrap the restored GC network in the Lightning `MLModule` so that it can be
# written out as a Lightning checkpoint (model + loss hyperparameters).
lmodel = MLModule(
    model=gc,
    optimizer=partial(torch.optim.Adam, lr=1e-3),
    loss_fct=GraphConstructionHingeEmbeddingLoss(
        r_emb=1,
        max_num_neighbors=256,
        attr_pt_thld=0.9,
    ),
    # Weight of the repulsive loss term.
    # NOTE(review): presumably copied from the original run's config — verify.
    lw_repulsive=0.001953029788887701,
)

In [12]:

# Conversion-only trainer: `max_epochs=0` means no training, the trainer is only
# needed to write a Lightning checkpoint. Skip the sanity-check validation pass
# too (pure overhead here), consistent with the EC and TC conversion trainers.
trainer = Trainer(accelerator="cpu", max_epochs=0, num_sanity_val_steps=0)

  rank_zero_warn(
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


In [13]:
# With max_epochs=0 this does not train ("`Trainer.fit` stopped: `max_epochs=0`
# reached"). It attaches the module and data to the trainer, which
# `trainer.save_checkpoint` presumably requires.
trainer.fit(lmodel, dm)

  rank_zero_warn(
[32m[16:29:05] INFO: DataLoader will load 900 graphs (out of 900 available).[0m
[36m[16:29:05] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/data21999_s0.pt[0m
[32m[16:29:05] INFO: DataLoader will load 5 graphs (out of 1000 available).[0m
[36m[16:29:05] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29004_s0.pt[0m

  | Name     | Type                                | Params
-----------------------------------------------------------------
0 | model    | GraphConstructionFCNN               | 1.3 M 
1 | loss_fct | GraphConstructionHingeEmbeddingLoss | 0     
-----------------------------------------------------------------
1.3 M     Traina

Sanity Checking: 0it [00:00, ?it/s]

  rank_zero_warn(
  rank_zero_warn(
`Trainer.fit` stopped: `max_epochs=0` reached.


In [14]:
# Save into the `gc/` sub-directory. That is where every later cell loads the
# GC checkpoint from (`model_exchange_path / "gc" / "gc-7dce6aff.ckpt"`), and it
# matches the `ec/` and `tc/` layout used for the other models.
trainer.save_checkpoint(model_exchange_path / "gc" / "gc-7dce6aff.ckpt")

In [16]:
# Hyperparameters stored with the checkpoint (class paths + init args of model
# and loss); presumably what `from_chkpt` uses to rebuild the objects.
lmodel.hparams

"loss_fct":     {'class_path': 'gnn_tracking.metrics.losses.GraphConstructionHingeEmbeddingLoss', 'init_args': {'r_emb': 1, 'max_num_neighbors': 256, 'attr_pt_thld': 0.9, 'p_attr': 1, 'p_rep': 1}}
"lw_repulsive": 0.001953029788887701
"model":        {'class_path': 'gnn_tracking.models.graph_construction.GraphConstructionFCNN', 'init_args': {'in_dim': 14, 'hidden_dim': 512, 'out_dim': 8, 'depth': 6, 'beta': 0.4}}
"preproc":      None

### Try restoring

In [17]:
# Rebuild the graph construction from the converted Lightning GC checkpoint;
# `ml_freeze=False` leaves its weights unfrozen.
lmodel = MLGraphConstruction.from_chkpt(
    ml_freeze=False,
    ml_chkpt_path=model_exchange_path / "gc" / "gc-7dce6aff.ckpt",
)

[36m[16:32:03] DEBUG: Getting class MLModule from module gnn_tracking.training.ml[0m
[36m[16:32:03] DEBUG: Loading checkpoint /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/model_exchange/gc/gc-7dce6aff.ckpt[0m
[36m[16:32:03] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction[0m
[36m[16:32:03] DEBUG: Getting class GraphConstructionHingeEmbeddingLoss from module gnn_tracking.metrics.losses[0m
[36m[16:32:03] DEBUG: Checkpoint loaded. Model ready to go.[0m


In [18]:
# Smoke test: run the restored graph construction on the first validation graph.
# No gradients are needed here (and the weights were left unfrozen), so disable
# autograd instead of keeping the autograd graph of a pass that builds millions of edges.
with torch.no_grad():
    out = lmodel(dm.datasets["val"][0])

In [19]:
# Inspect the output `Data` object: the input node attributes plus the
# constructed edges (`edge_index`, `edge_attr`, `y`).
out

Data(x=[59357, 14], edge_index=[2, 7837097], edge_attr=[7837097, 28], y=[7837097], pt=[59357], particle_id=[59357], sector=[59357], reconstructable=[59357])

## EC (edge classifier)

In [6]:
from gnn_tracking.models.edge_classifier import ECForGraphTCN

# Edge classifier on top of the GC. Node inputs are the 14 input features plus
# the 8-dim GC embedding; edge inputs are twice that (presumably the features of
# both endpoints concatenated — verify).
model = ECForGraphTCN(
    # input dimensions
    node_indim=14+8,
    edge_indim=(14+8)*2,
    # network size
    hidden_dim=128,
    interaction_node_dim=128,
    interaction_edge_dim=128,
    L_ec=3,
    alpha=0.35,
    # architecture switches
    use_node_embedding=True,
    use_intermediate_edge_embeddings=False,
)


In [7]:
# Old Ray Tune checkpoint of EC run ddff435e (trusted, self-produced file).
ckpt = torch.load(
    "/home/kl5675/ray_results/ds-ef-7dce6aff/GCWithECTrainable_ddff435e_1_val_batch_size=1,adam_amsgrad=False,adam_beta1=0.9000,adam_beta2=0.9990,adam_eps=0.0000,adam_weight_d_2023-06-10_16-10-21/checkpoint_000083/checkpoint.pt",
    map_location="cpu",
)
# Unlike the GC checkpoint, these keys already match the module (strict load,
# no renaming needed).
state_dct = ckpt["model_state_dict"]
model.load_state_dict(state_dct)

<All keys matched successfully>

In [8]:


# Lightning EC module. The converted GC checkpoint is attached as preprocessing,
# with its embeddings used as additional features.
lmodel = ECModule(
    model=model,
    preproc=MLGraphConstruction.from_chkpt(
        use_embedding_features=True,
        ml_chkpt_path=model_exchange_path / "gc" / "gc-7dce6aff.ckpt",
    ),
    loss_fct=HaughtyFocalLoss(gamma=2, alpha=0.5),
)

[36m[16:33:24] DEBUG: Getting class MLModule from module gnn_tracking.training.ml[0m
[36m[16:33:24] DEBUG: Loading checkpoint /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/model_exchange/gc/gc-7dce6aff.ckpt[0m
[36m[16:33:24] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction[0m
[36m[16:33:24] DEBUG: Getting class GraphConstructionHingeEmbeddingLoss from module gnn_tracking.metrics.losses[0m
[36m[16:33:24] DEBUG: Checkpoint loaded. Model ready to go.[0m


In [9]:
# Conversion-only trainer: no epochs and no sanity-check validation pass.
trainer = Trainer(accelerator="cpu", max_epochs=0, num_sanity_val_steps=0)

  rank_zero_warn(
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


In [10]:
# No training (max_epochs=0); attaches the EC module to the trainer before saving.
trainer.fit(lmodel, dm)

  rank_zero_warn(
[32m[16:33:26] INFO: DataLoader will load 900 graphs (out of 900 available).[0m
[36m[16:33:26] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/data21999_s0.pt[0m
[32m[16:33:26] INFO: DataLoader will load 5 graphs (out of 1000 available).[0m
[36m[16:33:26] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29004_s0.pt[0m

  | Name     | Type                | Params
-------------------------------------------------
0 | model    | ECForGraphTCN       | 551 K 
1 | preproc  | MLGraphConstruction | 1.3 M 
2 | loss_fct | HaughtyFocalLoss    | 0     
-------------------------------------------------
551 K     Trainable params
1.3 M     Non-trainable p

In [11]:
# Converted EC checkpoint; the module summary above shows it also contains the
# GC `preproc` sub-module.
trainer.save_checkpoint(model_exchange_path / "ec" / "ec-ddff435e.ckpt")

### Try restoring

In [12]:
# Restore GC and EC together from the two converted checkpoints and display the
# resulting module. `ec_thld` presumably is the EC score cut for keeping edges — verify.
MLGraphConstruction.from_chkpt(
    ec_thld=0.5,
    use_embedding_features=True,
    ml_chkpt_path=model_exchange_path / "gc" / "gc-7dce6aff.ckpt",
    ec_chkpt_path=model_exchange_path / "ec" / "ec-ddff435e.ckpt",
)

[36m[16:33:51] DEBUG: Getting class MLModule from module gnn_tracking.training.ml[0m
[36m[16:33:51] DEBUG: Loading checkpoint /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/model_exchange/gc/gc-7dce6aff.ckpt[0m
[36m[16:33:51] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction[0m
[36m[16:33:51] DEBUG: Getting class GraphConstructionHingeEmbeddingLoss from module gnn_tracking.metrics.losses[0m
[36m[16:33:51] DEBUG: Checkpoint loaded. Model ready to go.[0m
[36m[16:33:51] DEBUG: Getting class ECModule from module gnn_tracking.training.ec[0m
[36m[16:33:51] DEBUG: Loading checkpoint /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/model_exchange/ec/ec-ddff435e.ckpt[0m
[36m[16:33:52] DEBUG: Getting class ECForGraphTCN from module gnn_tracking.models.edge_classifier[0m
[36m[16:33:52] DEBUG: Getting class MLGraphConstruction from module gnn_tracking.models.graph_construction[0m
[36m[16:33:52] DEBUG: Getting class GraphC

MLGraphConstruction(
  (_ml): GraphConstructionFCNN(
    (_encoder): Linear(in_features=14, out_features=512, bias=False)
    (_decoder): Linear(in_features=512, out_features=8, bias=False)
    (_layers): ModuleList(
      (0-4): 5 x Linear(in_features=512, out_features=512, bias=False)
    )
  )
  (_ef): ECForGraphTCN(
    (relu): ReLU()
    (ec_node_encoder): MLP(
      (layers): ModuleList(
        (0): Linear(in_features=22, out_features=128, bias=False)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=128, bias=False)
      )
    )
    (ec_edge_encoder): MLP(
      (layers): ModuleList(
        (0): Linear(in_features=44, out_features=128, bias=False)
        (1): ReLU()
        (2): Linear(in_features=128, out_features=128, bias=False)
      )
    )
    (ec_resin): ResIN(
      (network): Skip1ResidualNetwork(
        (layers): ModuleList(
          (0-2): 3 x InteractionNetwork()
        )
      )
    )
    (W): MLP(
      (layers): ModuleList(
        (0): 

In [13]:

# Restore only the edge-classifier model from the converted EC checkpoint.
ec_restored = ec_from_chkpt(model_exchange_path / "ec" / "ec-ddff435e.ckpt")

[36m[16:34:44] DEBUG: Getting class ECModule from module gnn_tracking.training.ec[0m
[36m[16:34:44] DEBUG: Loading checkpoint /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/model_exchange/ec/ec-ddff435e.ckpt[0m
[36m[16:34:44] DEBUG: Getting class ECForGraphTCN from module gnn_tracking.models.edge_classifier[0m
[36m[16:34:44] DEBUG: Getting class MLGraphConstruction from module gnn_tracking.models.graph_construction[0m
[36m[16:34:44] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction[0m
[36m[16:34:44] DEBUG: Getting class HaughtyFocalLoss from module gnn_tracking.metrics.losses[0m
[36m[16:34:44] DEBUG: Checkpoint loaded. Model ready to go.[0m


## TC (track condensation)

In [9]:
from pytorch_lightning.core.mixins import HyperparametersMixin
from gnn_tracking.models.track_condensation_networks import GraphTCN, \
    PreTrainedECGraphTCN

# Track-condensation network on top of the pre-trained edge classifier, which is
# restored from the converted EC checkpoint.
model = PreTrainedECGraphTCN(
    ec=ec_from_chkpt(model_exchange_path / "ec" / "ec-ddff435e.ckpt"),
    # Same input dimensions as the EC: 14 input features + 8 GC embedding dims
    # per node, twice that per edge (22 and 44).
    node_indim=14 + 8,
    edge_indim=(14 + 8) * 2,
    h_dim=192,
    e_dim=192,
    h_outdim=12,
    hidden_dim=192,
    L_hc=3,
    feed_edge_weights=True,
    ec_threshold=0.2,
    mask_orphan_nodes=False,
    use_ec_embeddings_for_hc=True,
)

[36m[16:41:03] DEBUG: Getting class ECModule from module gnn_tracking.training.ec[0m
[36m[16:41:03] DEBUG: Loading checkpoint /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/model_exchange/ec/ec-ddff435e.ckpt[0m
[36m[16:41:03] DEBUG: Getting class ECForGraphTCN from module gnn_tracking.models.edge_classifier[0m
[36m[16:41:03] DEBUG: Getting class MLGraphConstruction from module gnn_tracking.models.graph_construction[0m
[36m[16:41:03] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction[0m
[36m[16:41:03] DEBUG: Getting class HaughtyFocalLoss from module gnn_tracking.metrics.losses[0m
[36m[16:41:03] DEBUG: Checkpoint loaded. Model ready to go.[0m
  rank_zero_warn(
  rank_zero_warn(


In [10]:
# Check that the nested EC hyperparameters are recorded as well (needed to
# rebuild the EC from the TC checkpoint alone).
model.hparams

"L_hc":                     3
"alpha_hc":                 0.5
"e_dim":                    192
"ec":                       {'class_path': 'gnn_tracking.models.edge_classifier.ECForGraphTCN', 'init_args': {'node_indim': 22, 'edge_indim': 44, 'interaction_node_dim': 128, 'interaction_edge_dim': 128, 'hidden_dim': 128, 'L_ec': 3, 'alpha': 0.35, 'residual_type': 'skip1', 'use_intermediate_edge_embeddings': False, 'use_node_embedding': True, 'residual_kwargs': None}}
"ec_threshold":             0.2
"edge_indim":               44
"feed_edge_weights":        True
"h_dim":                    192
"h_outdim":                 12
"hidden_dim":               192
"mask_orphan_nodes":        False
"node_indim":               22
"use_ec_embeddings_for_hc": True

In [11]:
# Old Ray Tune checkpoint of TC run 04b2e3ce (trusted, self-produced file).
ckpt = torch.load("/home/kl5675/ray_results/613788bb-continued/TCNFromGCTrainable_04b2e3ce_2_val_batch_size=1,adam_amsgrad=False,adam_beta1=0.9000,adam_beta2=0.9990,adam_eps=0.0000,adam_weight__2023-06-14_23-29-18/checkpoint_000018/checkpoint.pt", map_location="cpu")
# The EC sub-module already holds the weights from the converted EC checkpoint
# (`ec_from_chkpt` when building `model`), so skip the EC weights stored here.
# Match on the trailing dot so that sibling keys like `_gtcn.ec_*` are not dropped too.
ec_prefix = "_gtcn.ec."
state_dct = {k: v for k, v in ckpt["model_state_dict"].items() if not k.startswith(ec_prefix)}
load_result = model.load_state_dict(state_dct, strict=False)
# strict=False would silently ignore any mismatch; make sure the only one is the
# intentionally skipped EC sub-module.
assert not load_result.unexpected_keys, load_result.unexpected_keys
non_ec_missing = [k for k in load_result.missing_keys if not k.startswith(ec_prefix)]
assert not non_ec_missing, non_ec_missing
load_result

_IncompatibleKeys(missing_keys=['_gtcn.ec.ec_node_encoder.layers.0.weight', '_gtcn.ec.ec_node_encoder.layers.2.weight', '_gtcn.ec.ec_edge_encoder.layers.0.weight', '_gtcn.ec.ec_edge_encoder.layers.2.weight', '_gtcn.ec.ec_resin.network.layers.0.relational_model.layers.0.weight', '_gtcn.ec.ec_resin.network.layers.0.relational_model.layers.0.bias', '_gtcn.ec.ec_resin.network.layers.0.relational_model.layers.2.weight', '_gtcn.ec.ec_resin.network.layers.0.relational_model.layers.2.bias', '_gtcn.ec.ec_resin.network.layers.0.relational_model.layers.4.weight', '_gtcn.ec.ec_resin.network.layers.0.relational_model.layers.4.bias', '_gtcn.ec.ec_resin.network.layers.0.object_model.layers.0.weight', '_gtcn.ec.ec_resin.network.layers.0.object_model.layers.0.bias', '_gtcn.ec.ec_resin.network.layers.0.object_model.layers.2.weight', '_gtcn.ec.ec_resin.network.layers.0.object_model.layers.2.bias', '_gtcn.ec.ec_resin.network.layers.0.object_model.layers.4.weight', '_gtcn.ec.ec_resin.network.layers.0.objec

In [12]:


# Lightning TC module. The converted GC checkpoint is attached as preprocessing,
# with its embeddings used as additional features.
lmodel = TCModule(
    model=model,
    preproc=MLGraphConstruction.from_chkpt(
        use_embedding_features=True,
        ml_chkpt_path=model_exchange_path / "gc" / "gc-7dce6aff.ckpt",
    ),
    # Loss weights. NOTE(review): presumably taken from the original run's
    # config — verify.
    lw_background=0.0041,
    lw_repulsive=0.743380428762342,
)

[36m[16:41:14] DEBUG: Getting class MLModule from module gnn_tracking.training.ml[0m
[36m[16:41:14] DEBUG: Loading checkpoint /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/model_exchange/gc/gc-7dce6aff.ckpt[0m
[36m[16:41:14] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction[0m
[36m[16:41:14] DEBUG: Getting class GraphConstructionHingeEmbeddingLoss from module gnn_tracking.metrics.losses[0m
[36m[16:41:14] DEBUG: Checkpoint loaded. Model ready to go.[0m


In [16]:


# Full hyperparameter tree of the TC module (losses, model, nested EC, preproc).
pprint.pprint(dict(lmodel.hparams))

{'background_loss': {'class_path': 'gnn_tracking.metrics.losses.BackgroundLoss',
                     'init_args': {'sb': 0.1}},
 'cluster_scanner': None,
 'lw_background': 0.0041,
 'lw_repulsive': 0.743380428762342,
 'model': {'class_path': 'gnn_tracking.models.track_condensation_networks.PreTrainedECGraphTCN',
           'init_args': {'L_hc': 3,
                         'alpha_hc': 0.5,
                         'e_dim': 192,
                         'ec': {'class_path': 'gnn_tracking.models.edge_classifier.ECForGraphTCN',
                                'init_args': {'L_ec': 3,
                                              'alpha': 0.35,
                                              'edge_indim': 44,
                                              'hidden_dim': 128,
                                              'interaction_edge_dim': 128,
                                              'interaction_node_dim': 128,
                                              'node_indim': 22,
         

In [17]:
# Conversion-only trainer: zero optimizer steps and no sanity-check validation pass.
trainer = Trainer(accelerator="cpu", max_steps=0, num_sanity_val_steps=0)

  rank_zero_warn(
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


In [18]:
# No training (max_steps=0); attaches the TC module to the trainer before saving.
trainer.fit(lmodel, dm)

  rank_zero_warn(
[32m[16:42:57] INFO: DataLoader will load 900 graphs (out of 900 available).[0m
[36m[16:42:57] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/data21999_s0.pt[0m
[32m[16:42:57] INFO: DataLoader will load 5 graphs (out of 1000 available).[0m
[36m[16:42:57] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29004_s0.pt[0m

  | Name            | Type                 | Params
---------------------------------------------------------
0 | model           | PreTrainedECGraphTCN | 1.8 M 
1 | preproc         | MLGraphConstruction  | 1.3 M 
2 | potential_loss  | PotentialLoss        | 0     
3 | background_loss | BackgroundLoss       | 0     
----------

In [19]:
# Converted TC checkpoint (model incl. EC, plus GC preprocessing and losses).
trainer.save_checkpoint(model_exchange_path / "tc" / "tc-04b2e3ce.ckpt")

### Try restoring

In [8]:
# Round trip: rebuild the whole TC module (model, EC, GC preprocessing, losses)
# from the converted checkpoint; the logs show each class being resolved.
lmodel = TCModule.load_from_checkpoint(
    checkpoint_path=model_exchange_path / "tc" / "tc-04b2e3ce.ckpt",
)

[36m[17:03:17] DEBUG: Getting class PreTrainedECGraphTCN from module gnn_tracking.models.track_condensation_networks[0m
[36m[17:03:17] DEBUG: Getting class ECForGraphTCN from module gnn_tracking.models.edge_classifier[0m
  rank_zero_warn(
  rank_zero_warn(
[36m[17:03:17] DEBUG: Getting class MLGraphConstruction from module gnn_tracking.models.graph_construction[0m
[36m[17:03:17] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction[0m
[36m[17:03:17] DEBUG: Getting class PotentialLoss from module gnn_tracking.metrics.losses[0m
[36m[17:03:17] DEBUG: Getting class BackgroundLoss from module gnn_tracking.metrics.losses[0m


In [9]:
# Should match the hyperparameters printed before saving.
pprint.pprint(dict(lmodel.hparams))

{'background_loss': {'class_path': 'gnn_tracking.metrics.losses.BackgroundLoss',
                     'init_args': {'sb': 0.1}},
 'cluster_scanner': None,
 'lw_background': 0.0041,
 'lw_repulsive': 0.743380428762342,
 'model': {'class_path': 'gnn_tracking.models.track_condensation_networks.PreTrainedECGraphTCN',
           'init_args': {'L_hc': 3,
                         'alpha_hc': 0.5,
                         'e_dim': 192,
                         'ec': {'class_path': 'gnn_tracking.models.edge_classifier.ECForGraphTCN',
                                'init_args': {'L_ec': 3,
                                              'alpha': 0.35,
                                              'edge_indim': 44,
                                              'hidden_dim': 128,
                                              'interaction_edge_dim': 128,
                                              'interaction_node_dim': 128,
                                              'node_indim': 22,
         

In [13]:
# One real optimizer step to check that the restored module can actually train.
trainer = Trainer(accelerator="cpu", max_steps=1, num_sanity_val_steps=0)

  rank_zero_warn(
GPU available: True (cuda), used: False
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_warn(


In [None]:
# Run the single training step on the restored TC module.
trainer.fit(lmodel, dm)

  rank_zero_warn(
[32m[17:05:12] INFO: DataLoader will load 900 graphs (out of 900 available).[0m
[36m[17:05:12] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_1/data21999_s0.pt[0m
[32m[17:05:12] INFO: DataLoader will load 1 graphs (out of 1000 available).[0m
[36m[17:05:12] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v5/part_9/data29000_s0.pt[0m

  | Name            | Type                 | Params
---------------------------------------------------------
0 | model           | PreTrainedECGraphTCN | 1.8 M 
1 | preproc         | MLGraphConstruction  | 1.3 M 
2 | potential_loss  | PotentialLoss        | 0     
3 | background_loss | BackgroundLoss       | 0     
----------

Training: 0it [00:00, ?it/s]

In [1]:
# NOTE(review): leftover kernel/debug check, unrelated to the conversion — safe to delete.
print('test')

test
