# Guess new loss weights

After changing the definitions of the loss functions, we need to find new weights to balance attractive vs repulsive.
As a first starting point, let's evaluate the same data on a trained GC and get a rough ratio.

In [1]:
from pathlib import Path

from gnn_tracking.models.graph_construction import MLGraphConstructionFromChkpt
from gnn_tracking.metrics.losses.metric_learning import GraphConstructionHingeEmbeddingLoss
from gnn_tracking.utils.loading import TrackingDataModule
from gnn_tracking.training.ml import MLModule

In [15]:
from pathlib import Path
lightning_home = Path("/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/model_exchange/gc")
assert lightning_home.is_dir()
chkpt_name = "quiet-origami-prawn_compatible.ckpt"
ml_chkpt_path= lightning_home / chkpt_name
assert ml_chkpt_path.is_file()

In [16]:
ml_module =  MLModule.load_from_checkpoint(
    ml_chkpt_path,
)

[36m[11:54:12] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction[0m
[36m[11:54:12] DEBUG: Getting class GraphConstructionHingeEmbeddingLoss from module gnn_tracking.metrics.losses.metric_learning[0m
[36m[11:54:12] DEBUG: Getting class GraphConstructionKNNScanner from module gnn_tracking.graph_construction.k_scanner[0m


In [19]:
dm = TrackingDataModule(
    identifier="point_clouds_v8",
    train=dict(
        dirs=[
            "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_1/"
        ],
        sample_size=1000,
        # If you run into memory issues, reduce this
    ),
    val=dict(
        dirs=[
            "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_9/"
        ],
        stop=5
    ),
)

In [20]:
dm.setup("fit")
data = dm.datasets["train"][0]

[32m[11:54:29] INFO: DataLoader will load 900 graphs (out of 900 available).[0m
[36m[11:54:29] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_1/data21999_s0.pt[0m
[32m[11:54:29] INFO: DataLoader will load 5 graphs (out of 1000 available).[0m
[36m[11:54:29] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_9/data29004_s0.pt[0m


In [21]:
import pytorch_lightning as pl

trainer = pl.Trainer()
trainer.validate(ml_module, dm)

/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3 ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
[32m[11:54:30] INFO: DataLoader will load 5 graphs (out of 1000 available).[0m
[36m[11:54:30] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_9/data29004_s0.pt[0m
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]
/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.

Validation DataLoader 0: 100%|███████████| 5/5 [00:31<00:00,  0.16it/s]

NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.


Validation DataLoader 0: 100%|███████████| 5/5 [00:31<00:00,  0.16it/s]


[{'attractive': 0.27874574065208435,
  'repulsive': 0.42850080132484436,
  'attractive_weighted': 0.27874574065208435,
  'repulsive_weighted': 0.02571004629135132,
  'total': 0.30445581674575806,
  'n_edges_frac_segment50_80': 106762.140625,
  'k_at_segment50_80': 1.9612022638320923,
  'frac75_at_segment50_80': 0.5798446536064148,
  'frac100_at_segment50_80': 0.4761907160282135,
  'efficiency_at_segment50_80': 0.2204710990190506,
  'purity_at_segment50_80': 0.6520455479621887,
  'n_edges_frac_segment50_85': 121764.2890625,
  'k_at_segment50_85': 2.244067668914795,
  'frac75_at_segment50_85': 0.6489571332931519,
  'frac100_at_segment50_85': 0.5457428097724915,
  'efficiency_at_segment50_85': 0.24987787008285522,
  'purity_at_segment50_85': 0.6401358246803284,
  'n_edges_frac_segment50_88': 134387.296875,
  'k_at_segment50_88': 2.4836151599884033,
  'frac75_at_segment50_88': 0.6975961327552795,
  'frac100_at_segment50_88': 0.597761332988739,
  'efficiency_at_segment50_88': 0.274264425039

In [24]:
attractive = 0.16546812653541565
repulsive = 0.41141024231910706
attractive / repulsive

0.40219739207919775

## Now loading legacy 

In [2]:
from pathlib import Path

from gnn_tracking.models.graph_construction import MLGraphConstructionFromChkpt
from gnn_tracking.metrics.losses import GraphConstructionHingeEmbeddingLoss
from gnn_tracking.utils.loading import TrackingDataModule
from gnn_tracking.training.ml import MLModule

In [3]:
# We already converted this checkpoint to the new version
ml_chkpt_path = Path("/home/kl5675/Documents/23/git_sync/hyperparameter_optimization2/scripts/full_detector/lightning_logs/merciful-reindeer-of-coffee/checkpoints/epoch=79-step=72000.compat.ckpt")
assert ml_chkpt_path.is_file()

In [4]:
ml_module =  MLModule.load_from_checkpoint(
    ml_chkpt_path,
)

[36m[17:16:27] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction[0m
[36m[17:16:27] DEBUG: Getting class GraphConstructionHingeEmbeddingLoss from module gnn_tracking.metrics.losses[0m
[36m[17:16:27] DEBUG: Getting class GraphConstructionKNNScanner from module gnn_tracking.graph_construction.k_scanner[0m


In [5]:
dm = TrackingDataModule(
    identifier="point_clouds_v10",
    train=dict(
        dirs=[
            "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_1/"
        ],
        sample_size=1000,
        # If you run into memory issues, reduce this
    ),
    val=dict(
        dirs=[
            "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_9/"
        ],
        stop=5
    ),
)

In [6]:
dm.setup("fit")
data = dm.datasets["train"][0]

[32m[17:16:30] INFO: DataLoader will load 900 graphs (out of 900 available).[0m
[36m[17:16:30] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_1/data21999_s0.pt[0m
[32m[17:16:30] INFO: DataLoader will load 5 graphs (out of 1000 available).[0m
[36m[17:16:30] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_9/data29004_s0.pt[0m


In [7]:
import pytorch_lightning as pl

trainer = pl.Trainer()
trainer.validate(ml_module, dm)

/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3 ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3 ...
You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_flo

Validation DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:44<00:00,  0.11it/s]

NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.
NaN or Inf found in input tensor.


Validation DataLoader 0: 100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 5/5 [00:44<00:00,  0.11it/s]


[{'attractive': 0.14644861221313477,
  'repulsive': 29.345333099365234,
  'attractive_weighted': 0.14644861221313477,
  'repulsive_weighted': 0.17607201635837555,
  'total': 0.3225206434726715,
  'n_edges_frac_segment50_80': 396683.8125,
  'k_at_segment50_80': 7.0,
  'frac75_at_segment50_80': 0.6447610259056091,
  'frac100_at_segment50_80': 0.5249454379081726,
  'efficiency_at_segment50_80': 0.4928075075149536,
  'purity_at_segment50_80': 0.3488660752773285,
  'n_edges_frac_segment50_85': 404595.9375,
  'k_at_segment50_85': 7.141140460968018,
  'frac75_at_segment50_85': 0.648024320602417,
  'frac100_at_segment50_85': 0.5289484858512878,
  'efficiency_at_segment50_85': 0.4998131990432739,
  'purity_at_segment50_85': 0.3463244140148163,
  'n_edges_frac_segment50_88': 553604.0,
  'k_at_segment50_88': 9.80606460571289,
  'frac75_at_segment50_88': 0.6954953670501709,
  'frac100_at_segment50_88': 0.5877106785774231,
  'efficiency_at_segment50_88': 0.6181542277336121,
  'purity_at_segment50_8

In [9]:
attractive = 0.14644861221313477
repulsive = 29.345333099365234
repulsive_weighted = 0.17607201635837555
attractive/repulsive_weighted

0.8317540472476583