In [1]:
from gnn_tracking.metrics.losses.oc import CondensationLossTiger, CondensationLossRG
from gnn_tracking.training.callbacks import PrintValidationMetrics
from gnn_tracking.training.tc import TCModule
from gnn_tracking.models.graph_construction import MLGraphConstructionFromChkpt
from gnn_tracking.models.track_condensation_networks import PreTrainedECGraphTCN
import torch

## Performance testing


In [2]:
torch.cuda.empty_cache()
print(torch.cuda.memory_summary(device=None, abbreviated=False))

|                  PyTorch CUDA memory summary, device ID 0                 |
|---------------------------------------------------------------------------|
|            CUDA OOMs: 0            |        cudaMalloc retries: 0         |
|        Metric         | Cur Usage  | Peak Usage | Tot Alloc  | Tot Freed  |
|---------------------------------------------------------------------------|
| Allocated memory      |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------------------|
| Active memory         |      0 B   |      0 B   |      0 B   |      0 B   |
|       from large pool |      0 B   |      0 B   |      0 B   |      0 B   |
|       from small pool |      0 B   |      0 B   |      0 B   |      0 B   |
|---------------------------------------------------------------

In [22]:
from pathlib import Path
ml_chkpt_path=Path("/home/kl5675/Documents/23/git_sync/hyperparameter_optimization2/scripts/full_detector/lightning_logs/merciful-reindeer-of-coffee/checkpoints/epoch=79-step=72000.compat.ckpt")
ckpt = torch.load(ml_chkpt_path)

In [10]:
ckpt

{'epoch': 136,
 'global_step': 1060791,
 'pytorch-lightning_version': '2.0.4',
 'state_dict': OrderedDict([('model._latent_normalization',
               tensor([0.3039], device='cuda:0')),
              ('model._encoder.weight',
               tensor([[ 4.4126e-01, -4.6154e-01, -7.6596e-02,  ...,  5.4795e+00,
                         4.1984e-01,  7.8441e-01],
                       [-6.8870e-01,  2.7098e+00,  1.3297e-01,  ..., -2.9176e+00,
                        -3.2144e-01,  3.9957e-03],
                       [-3.7865e-01, -2.4479e-02, -6.0885e-02,  ..., -5.5328e-01,
                        -5.2481e-01,  5.3502e-01],
                       ...,
                       [ 1.3068e-01,  2.2458e+00, -6.6728e-02,  ..., -8.7599e+00,
                        -3.5665e-02, -2.5705e-01],
                       [ 3.0357e-03,  4.7783e-01,  4.2391e-03,  ...,  2.2119e+00,
                         1.8651e+00, -1.1115e-02],
                       [ 2.0886e-02, -7.2944e+00,  1.6007e-01,  ...,  1.6530e

In [11]:
import copy
def make_ckpt_compatible(ckpt):
    margs = ckpt["hyper_parameters"]
    largs = ckpt["hyper_parameters"]["loss_fct"]["init_args"]
    for key, value in copy.deepcopy(margs).items():
        if key.startswith("lw_"):
            largs[key] = value
            del margs[key]
    ckpt["hyper_parameters"] = margs
    ckpt["hyper_parameters"]["loss_fct"]["init_args"] = largs
    ckpt["hyper_parameters"]["loss_fct"]["init_args"]["pt_thld"] = ckpt["hyper_parameters"]["loss_fct"]["init_args"].pop("attr_pt_thld")
    ckpt["hyper_parameters"]["loss_fct"]["class_path"] = ckpt["hyper_parameters"]["loss_fct"]["class_path"].replace("losses", "losses.metric_learning")
    ckpt["hyper_parameters"]["model"]["init_args"]["alpha"] = 1-ckpt["hyper_parameters"]["model"]["init_args"].pop("beta")
    return ckpt

ckpt = make_ckpt_compatible(ckpt)
compat_name = ml_chkpt_path.parent / (ml_chkpt_path.stem + "_compatible.ckpt")
torch.save(ckpt, compat_name)

In [6]:
ckpt["hyper_parameters"]

{'model': {'class_path': 'gnn_tracking.models.graph_construction.GraphConstructionFCNN',
  'init_args': {'in_dim': 14,
   'hidden_dim': 256,
   'out_dim': 24,
   'depth': 5,
   'alpha': 0.5}},
 'preproc': None,
 'loss_fct': {'class_path': 'gnn_tracking.metrics.losses.metric_learning.GraphConstructionHingeEmbeddingLoss',
  'init_args': {'r_emb': 1,
   'max_num_neighbors': 256,
   'p_attr': 2.0,
   'p_rep': 2.0,
   'lw_repulsive': 0.006,
   'pt_thld': 0.9}},
 'gc_scanner': {'class_path': 'gnn_tracking.graph_construction.k_scanner.GraphConstructionKNNScanner',
  'init_args': {'ks': [7, 8, 9, 10, 11, 12, 13, 14, 15],
   'targets': (0.8, 0.85, 0.88, 0.9, 0.93, 0.95, 0.97, 0.99),
   'max_radius': 1.0,
   'pt_thld': 0.9,
   'max_eta': 4.0,
   'subsample_pids': None,
   'max_edges': 5000000}}}

In [14]:
torch.load(compat_name)["hyper_parameters"]["model"]

{'class_path': 'gnn_tracking.models.graph_construction.GraphConstructionFCNN',
 'init_args': {'in_dim': 14,
  'hidden_dim': 256,
  'out_dim': 8,
  'depth': 6,
  'alpha': 0.4}}

In [15]:

model = PreTrainedECGraphTCN(
      ec=None,
      node_indim= 38,
      edge_indim= 76,
      h_dim= 192,
      e_dim= 192,
      hidden_dim= 192,
      h_outdim= 24,
      L_hc= 5,
      alpha_latent= 0.5,
      n_embedding_coords= 24,
)
preproc = MLGraphConstructionFromChkpt(
    ml_chkpt_path=compat_name,
    max_num_neighbors= 15,
    max_radius= 1.,
    use_embedding_features= True,
    build_edge_features= True,
)


/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/utilities/parsing.py:198: Attribute 'hc_in' is an instance of `nn.Module` and is already saved during checkpointing. It is recommended to ignore them using `self.save_hyperparameters(ignore=['hc_in'])`.
[36m[11:43:17] DEBUG: Getting class MLModule from module gnn_tracking.training.ml[0m
[36m[11:43:30] DEBUG: Loading checkpoint /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/model_exchange/gc/quiet-origami-prawn_compatible.ckpt[0m
[36m[11:43:31] DEBUG: Getting class GraphConstructionFCNN from module gnn_tracking.models.graph_construction[0m
[36m[11:43:31] DEBUG: Getting class GraphConstructionHingeEmbeddingLoss from module gnn_tracking.metrics.losses.metric_learning[0m
[36m[11:43:31] DEBUG: Getting class GraphConstructionKNNScanner from module gnn_tracking.graph_construction.k_scanner[0m
[36m[11:43:31] DEBUG: Checkpoint loaded. Model ready to go.[0m


In [16]:

condensation_loss = CondensationLossTiger(
      q_min= 0.01,
      max_n_rep=100_000,
)


In [17]:
condensation_loss = CondensationLossRG(
      max_num_neighbors=50,
)

In [18]:
oc = TCModule(
    model=model,
    preproc=preproc,
    loss_fct=condensation_loss,
    cluster_scanner=None,
)

In [19]:
from gnn_tracking.utils.loading import TrackingDataModule

dm = TrackingDataModule(
    identifier="point_clouds_v10",
    train=dict(
        dirs=[
            "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_1/"
        ],
        sample_size=1000,
        # If you run into memory issues, reduce this
    ),
    val=dict(
        dirs=[
            "/scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_9/"
        ],
        stop=5
    ),
)


In [20]:
import pytorch_lightning as pl

trainer = pl.Trainer(
    max_epochs=1,
    callbacks=[PrintValidationMetrics()],
    # fast_dev_run=True,
)


/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/lightning_fabric/plugins/environments/slurm.py:191: The `srun` command is available on your system but is not used. HINT: If your intention is to run Lightning on SLURM, prepend your python command with `srun` like so: srun python /scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3 ...
GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs


In [21]:
trainer.fit(oc, dm)



You are using a CUDA device ('NVIDIA A100-SXM4-80GB') that has Tensor Cores. To properly utilize them, you should set `torch.set_float32_matmul_precision('medium' | 'high')` which will trade-off precision for performance. For more details, read https://pytorch.org/docs/stable/generated/torch.set_float32_matmul_precision.html#torch.set_float32_matmul_precision
[32m[11:43:38] INFO: DataLoader will load 900 graphs (out of 900 available).[0m
[36m[11:43:38] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_1/data21000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_1/data21999_s0.pt[0m
[32m[11:43:38] INFO: DataLoader will load 5 graphs (out of 1000 available).[0m
[36m[11:43:38] DEBUG: First graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_9/data29000_s0.pt, last graph is /scratch/gpfs/IOJALVO/gnn-tracking/object_condensation/point_clouds_v8/part_9/data2

Sanity Checking: |                               | 0/? [00:00<?, ?it/s]

/scratch/gpfs/kl5675/micromamba/envs/gnn/lib/python3.10/site-packages/pytorch_lightning/trainer/connectors/data_connector.py:441: The 'val_dataloader' does not have many workers which may be a bottleneck. Consider increasing the value of the `num_workers` argument` to `num_workers=2` in the `DataLoader` to improve performance.


Sanity Checking DataLoader 0:   0%|              | 0/2 [00:00<?, ?it/s]

AssertionError: Expected feature dimension 38, got 22

## Consistency testing

In [9]:
from pathlib import Path
import sys

repo_path = Path("/home/kl5675/Documents/23/git_sync/gnn_tracking/tests")
assert repo_path.is_dir()
sys.path.append(str(repo_path))


from test_losses import generate_test_data

from torch import Tensor as T
from gnn_tracking.metrics.losses.oc import *
from gnn_tracking.metrics.losses.oc import _first_occurrences, _square_distances
from torch_cluster import radius_graph


In [10]:
# RG

td = generate_test_data(n_nodes=100, n_particles=10, rng=np.random.default_rng(seed=0))
beta=td.beta
x=td.x
particle_id=td.particle_id
reconstructable=td.reconstructable
pt=td.pt
eta=td.eta
q_min=0.01
radius_threshold=1
max_num_neighbors=256

mask = get_good_node_mask_tensors(
    pt=pt,
    particle_id=particle_id,
    reconstructable=reconstructable,
    eta=eta,
)


# For better readability, variables that are only relevant in one "block"
# are prefixed with an underscore
# _j means indexed as hits (... x n_hits)
# _k means indexed as objects (... x n_objects_of_interest)
# _e means indexed by edge
# where n_objects_of_interest = len(unique(particle_id[mask]))

# -- 1. Determine indices of condensation points (CPs) and q --
_sorted_indices_j = torch.argsort(beta, descending=True)
_pids_sorted = particle_id[_sorted_indices_j]
_alphas = _sorted_indices_j[_first_occurrences(_pids_sorted)]
# Index of condensation points in node array
# Only particles of interest have CPs, in particular no noise hits or low pt hits 
alphas_k = _alphas[mask[_alphas]]
assert alphas_k.size()[0] > 0, "No particles found, cannot evaluate loss"
# "Charge"
q_j = torch.arctanh(beta) ** 2 + q_min
assert not torch.isnan(q_j).any(), "q contains NaNs"

# -- 2. Edges for repulsion loss --
_radius_edges = radius_graph(
    x=x, r=radius_threshold, max_num_neighbors=max_num_neighbors, loop=False
)
# Now filter out everything that doesn't include a CP or connects two hits of the
# same particle
_to_cp_e = torch.isin(_radius_edges[0], alphas_k)
_is_repulsive_e = particle_id[_radius_edges[0]] != particle_id[_radius_edges[1]]
# Since noise/low pt does not have CPs, they don't repel from each other 
repulsion_edges_e = _radius_edges[:, _is_repulsive_e & _to_cp_e]

# -- 3. Edges for attractive loss --
# 1D array (n_nodes): 1 for CPs, 0 otherwise
is_cp_j = torch.zeros(len(particle_id), dtype=bool, device=x.device).scatter_(
    0, alphas_k, 1
)
# hit-indices of all non-CPs
_non_cp_indices = torch.nonzero(~is_cp_j & mask).squeeze()
print(len(_non_cp_indices))
# for each non-CP hit, the index of the corresponding CP
corresponding_alpha = alphas_k[
    torch.searchsorted(particle_id[alphas_k], particle_id[_non_cp_indices])
]
# Insert alpha indices into their respective positions to form attraction edges
attraction_edges_e = torch.stack((
    _non_cp_indices,
    corresponding_alpha
))
print(attraction_edges_e)

# -- 4. Calculate loss --
# Protect against sqrt not being differentiable around 0
attraction_distances_e = _square_distances(attraction_edges_e, x)
attraction_distances_e
attraction_edges_e.shape

15
tensor([[ 3, 15, 28, 29, 37, 48, 50, 62, 67, 74, 82, 84, 92, 93, 94],
        [20, 88, 88, 88, 20, 20, 88, 88, 88, 88, 20, 88, 20, 20, 88]])


torch.Size([2, 15])

In [11]:
_non_cp_indices

tensor([ 3, 15, 28, 29, 37, 48, 50, 62, 67, 74, 82, 84, 92, 93, 94])

In [12]:
torch.stack((
    torch.nonzero(~is_cp_j & mask).squeeze(),
    corresponding_alpha
))

tensor([[ 3, 15, 28, 29, 37, 48, 50, 62, 67, 74, 82, 84, 92, 93, 94],
        [20, 88, 88, 88, 20, 20, 88, 88, 88, 88, 20, 88, 20, 20, 88]])

In [13]:
td = generate_test_data(n_nodes=100, n_particles=10, rng=np.random.default_rng(seed=0))
beta=td.beta
x=td.x
object_id=td.particle_id
reconstructable=td.reconstructable
pt=td.pt
eta=td.eta
q_min=0.01
radius_threshold=1
max_num_neighbors=256
noise_threshold=0
max_n_rep=0

object_mask = get_good_node_mask_tensors(
    pt=pt,
    particle_id=particle_id,
    reconstructable=reconstructable,
    eta=eta,
)

# To protect against nan in divisions
eps = 1e-9

# x: n_nodes x n_outdim
unique_oids = torch.unique(object_id[object_mask])
assert len(unique_oids) > 0, "No particles found, cannot evaluate loss"
# n_nodes x n_pids
# The nodes in every column correspond to the hits of a single particle and
# should attract each other
# Note that a condensation point attracts itself, but since the distance
# will be 0, it doesn't matter
attractive_mask_jk = object_id.view(-1, 1) == unique_oids.view(1, -1)
print(torch.nonzero(attractive_mask_jk))

q = torch.arctanh(beta) ** 2 + q_min
assert not torch.isnan(q).any(), "q contains NaNs"
# Index of condensation points in node array
alphas_k = torch.argmax(q.view(-1, 1) * attractive_mask_jk, dim=0)

# n_objs x n_outdim
x_k = x[alphas_k]
# 1 x n_objs
q_k = q[alphas_k].view(1, -1)

dist_j_k = torch.cdist(x, x_k)

qw_j_k = q.view(-1, 1) * q_k

att_norm_k = (attractive_mask_jk.sum(dim=0) + eps) * len(unique_oids)
qw_att = (qw_j_k / att_norm_k)[attractive_mask_jk]

# Attractive potential/loss
v_att = (qw_att * torch.square(dist_j_k[attractive_mask_jk])).sum()

repulsive_mask = (~attractive_mask_jk) & (dist_j_k < 1)
n_rep_k = (~attractive_mask_jk).sum(dim=0)
n_rep = repulsive_mask.sum()
# Don't normalize to repulsive_mask, it includes the dist < 1 count,
# (less points within the radius 1 ball should translate to lower loss)
rep_norm = (n_rep_k + eps) * len(unique_oids)
if n_rep > max_n_rep > 0:
    sampling_freq = max_n_rep / n_rep
    sampling_mask = (
        torch.rand_like(repulsive_mask, dtype=torch.float16) < sampling_freq
    )
    repulsive_mask &= sampling_mask
    rep_norm *= sampling_freq
qw_rep = (qw_j_k / rep_norm)[repulsive_mask]
v_rep = (qw_rep * (1 - dist_j_k[repulsive_mask])).sum()

l_coward = torch.mean(1 - beta[alphas_k])
not_noise_j = object_id > noise_threshold
l_noise = torch.mean(beta[~not_noise_j])

dist_j_k[attractive_mask_jk]
attractive_mask_jk.sum()


tensor([[ 3,  0],
        [15,  1],
        [20,  0],
        [28,  1],
        [29,  1],
        [37,  0],
        [48,  0],
        [50,  1],
        [62,  1],
        [67,  1],
        [74,  1],
        [82,  0],
        [84,  1],
        [88,  1],
        [92,  0],
        [93,  0],
        [94,  1]])


tensor(17)