In [1]:
%load_ext autoreload
%autoreload 2

In [2]:
import sys

print(sys.executable)

/home/guevel/projects/annotix_all/DiffMS/.venv/bin/python


In [3]:
from omegaconf import DictConfig, OmegaConf
from hydra import compose,initialize


# Initialize once at the beginning of your notebook
initialize(config_path="configs", version_base="1.3")

# Compose the config (same as Hydra would do on CLI)
cfg: DictConfig = compose(config_name="config")

# Now you can use it like a normal dict
print(OmegaConf.to_yaml(cfg))


general:
  name: dev
  parent_dir: .
  wandb: online
  wandb_name: mass_spec_exp
  gpus: 1
  decoder: models/checkpoints/decoder.ckpt
  encoder: models/checkpoints/encoder_msg.ckpt
  resume: null
  test_only: null
  load_weights: models/checkpoints/diffms_msg.ckpt
  encoder_finetune_strategy: null
  decoder_finetune_strategy: null
  check_val_every_n_epochs: 1
  sample_every_val: 1000
  val_samples_to_generate: 100
  test_samples_to_generate: 100
  log_every_steps: 50
  evaluate_all_checkpoints: false
  checkpoint_strategy: last
model:
  transition: marginal
  model: graph_tf
  diffusion_steps: 500
  diffusion_noise_schedule: cosine
  n_layers: 5
  extra_features: all
  hidden_mlp_dims:
    X: 256
    E: 128
    'y': 2048
  hidden_dims:
    dx: 256
    de: 64
    dy: 1024
    n_head: 8
    dim_ffX: 256
    dim_ffE: 128
    dim_ffy: 1024
  encoder_hidden_dim: 512
  encoder_magma_modulo: 2048
  lambda_train:
  - 0
  - 1
  - 0
train:
  n_epochs: 75
  batch_size: 96
  eval_batch_size: 128


In [4]:
import logging

from diffms.analysis.visualization import MolecularVisualization
from diffms.datasets import spec2mol_dataset
from diffms.diffusion.extra_features import DummyExtraFeatures, ExtraFeatures
from diffms.metrics.molecular_metrics_discrete import TrainMolecularMetricsDiscrete
from diffms.diffusion.extra_features_molecular import ExtraMolecularFeatures

dataset_config = cfg["dataset"]

if dataset_config["name"] not in ("canopus", "msg"):
    raise NotImplementedError("Unknown dataset {}".format(cfg["dataset"]))

print('Creating datamodule')
datamodule = spec2mol_dataset.Spec2MolDataModule(cfg) # TODO: Add hyper for n_bits

print('Getting dataste infos')
dataset_infos = spec2mol_dataset.Spec2MolDatasetInfos(datamodule, cfg)

print('Making domain features')
domain_features = ExtraMolecularFeatures(dataset_infos=dataset_infos)

print('Making extra features')
if cfg.model.extra_features is not None:
    extra_features = ExtraFeatures(cfg.model.extra_features, dataset_info=dataset_infos)
else:
    extra_features = DummyExtraFeatures()

print('Compute input/output dims')
dataset_infos.compute_input_output_dims(datamodule=datamodule, extra_features=extra_features, domain_features=domain_features)

logging.info("Dataset infos:", dataset_infos.output_dims)

print('Training metrics.')
train_metrics = TrainMolecularMetricsDiscrete(dataset_infos)

# We do not evaluate novelty during training
print('Get viz tools.')
visualization_tools = MolecularVisualization(cfg.dataset.remove_h, dataset_infos=dataset_infos)

model_kwargs = {
    'dataset_infos': dataset_infos,
    'train_metrics': train_metrics,
    'visualization_tools': visualization_tools,
    'extra_features': extra_features, 
    'domain_features': domain_features,
}

Creating datamodule


231104it [00:01, 185871.46it/s]
231104it [00:34, 6692.45it/s]
231104it [00:00, 3632018.76it/s]


Getting dataste infos
Making domain features
Making extra features
Compute input/output dims

































































































































Training metrics.
Get viz tools.


In [5]:
import os
from diffms import ROOT
from diffms.diffusion_model_spec2mol import Spec2MolDenoisingDiffusion
from diffms.spec2mol_main import load_weights

model = Spec2MolDenoisingDiffusion(cfg=cfg, **model_kwargs)

weight_path = ROOT / cfg.general.load_weights
if not os.path.isfile:
    raise ValueError(f"The path indicated does not exist {weight_path}")

logging.info(f"Loading weights from {weight_path}")
model = load_weights(model, weight_path)

In [7]:
import torch

dataloader = datamodule.test_dataloader()

for b in dataloader:
    b = {
        k: v.to("cuda") if isinstance(v, torch.Tensor) else v
        for k, v in b.items()
    }
    break




































































































































































































































In [8]:
!nvidia-smi

Mon Sep  8 11:06:21 2025       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A5000    On   | 00000000:3B:00.0 Off |                  Off |
| 30%   32C    P2    57W / 230W |    815MiB / 24564MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [9]:
model.to("cuda")

Spec2MolDenoisingDiffusion(
  (train_loss): TrainLossDiscrete(
    (node_loss): CrossEntropyMetric()
    (edge_loss): CrossEntropyMetric()
    (y_loss): CrossEntropyMetric()
  )
  (val_nll): NLL()
  (val_X_kl): SumExceptBatchKL()
  (val_E_kl): SumExceptBatchKL()
  (val_X_logp): SumExceptBatchMetric()
  (val_E_logp): SumExceptBatchMetric()
  (val_k_acc): K_ACC_Collection(
    (metrics): ModuleDict(
      (acc_at_1): K_ACC()
      (acc_at_2): K_ACC()
      (acc_at_3): K_ACC()
      (acc_at_4): K_ACC()
      (acc_at_5): K_ACC()
      (acc_at_6): K_ACC()
      (acc_at_7): K_ACC()
      (acc_at_8): K_ACC()
      (acc_at_9): K_ACC()
      (acc_at_10): K_ACC()
      (acc_at_11): K_ACC()
      (acc_at_12): K_ACC()
      (acc_at_13): K_ACC()
      (acc_at_14): K_ACC()
      (acc_at_15): K_ACC()
      (acc_at_16): K_ACC()
      (acc_at_17): K_ACC()
      (acc_at_18): K_ACC()
      (acc_at_19): K_ACC()
      (acc_at_20): K_ACC()
      (acc_at_21): K_ACC()
      (acc_at_22): K_ACC()
      (acc_at_

In [10]:
!nvidia-smi

Mon Sep  8 11:06:56 2025       
+-----------------------------------------------------------------------------+
| NVIDIA-SMI 525.125.06   Driver Version: 525.125.06   CUDA Version: 12.0     |
|-------------------------------+----------------------+----------------------+
| GPU  Name        Persistence-M| Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp  Perf  Pwr:Usage/Cap|         Memory-Usage | GPU-Util  Compute M. |
|                               |                      |               MIG M. |
|   0  NVIDIA RTX A5000    On   | 00000000:3B:00.0 Off |                  Off |
| 30%   35C    P2    57W / 230W |   1157MiB / 24564MiB |      0%      Default |
|                               |                      |                  N/A |
+-------------------------------+----------------------+----------------------+
                                                                               
+-----------------------------------------------------------------------------+
| Proces

In [11]:
output, aux = model.encoder(b)

In [29]:
print(aux["pred_frag_fps"].numel()/1e6)
print(aux["h0"].numel()/1e6)

15.990784
0.065536


In [34]:
sum([v.numel() for v in  model.state_dict().values()]) / 1e6

84.912486

In [13]:
merge = getattr(cfg.dataset, 'merge', 'none')

data = b["graph"]
if model.merge == 'mist_fp':
    data.y = aux["int_preds"][-1]
if model.merge == 'merge-encoder_output-linear':
    encoder_output = aux['h0']
    data.y = model.merge_function(encoder_output)
elif model.merge == 'merge-encoder_output-mlp':
    encoder_output = aux['h0']
    data.y = model.merge_function(encoder_output)
elif model.merge == 'downproject_4096':
    data.y = model.merge_function(output)

In [14]:
from diffms import utils

dense_data, node_mask = utils.to_dense(data.x, data.edge_index, data.edge_attr, data.batch)
dense_data = dense_data.mask(node_mask)

In [None]:
for attr in ["X", "E", "y"]:
    value = getattr(dense_data, attr)
    if isinstance(value, torch.Tensor):
        setattr(dense_data, attr, value.to("cuda"))

node_mask = node_mask.to("cuda")

noisy_data = model.apply_noise(dense_data.X, dense_data.E, data.y, node_mask)
extra_data = model.compute_extra_data(noisy_data)

torch.cuda.empty_cache()

In [50]:
import torch
import gc

def get_cuda_variables(namespace=None):
    """
    List all variables in the given namespace (default: globals())
    that occupy space on CUDA devices.
    """
    if namespace is None:
        namespace = globals()
    
    cuda_vars = []
    total_mem = 0
    
    for name, var in namespace.items():
        try:
            if isinstance(var, torch.Tensor) and var.is_cuda:
                mem = var.element_size() * var.nelement()
                cuda_vars.append((name, type(var).__name__, mem / (1024**2)))
                total_mem += mem
            elif isinstance(var, torch.nn.Module):
                for p in var.parameters():
                    if p.is_cuda:
                        mem = p.element_size() * p.nelement()
                        cuda_vars.append((name, type(var).__name__, mem / (1024**2)))
                        total_mem += mem
                        break
        except Exception:
            pass
    
    print(f"Total CUDA memory by variables: {total_mem / (1024**2):.2f} MB")
    return cuda_vars

# Example usage in a notebook:
cuda_vars = get_cuda_variables()
for name, vtype, mem in cuda_vars:
    print(f"{name:20} | {vtype:15} | {mem:.2f} MB")


Total CUDA memory by variables: 2.59 MB
model                | Spec2MolDenoisingDiffusion | 0.02 MB
_9                   | Spec2MolDenoisingDiffusion | 0.02 MB
output               | Tensor          | 2.00 MB
node_mask            | Tensor          | 0.01 MB
_18                  | Tensor          | 0.28 MB
_44                  | Tensor          | 0.28 MB


In [54]:
pred = model.forward(noisy_data, extra_data, node_mask)
pred.X = dense_data.X
pred.Y = data.y

OutOfMemoryError: CUDA out of memory. Tried to allocate 632.00 MiB. GPU 0 has a total capacity of 23.68 GiB of which 64.12 MiB is free. Including non-PyTorch memory, this process has 23.61 GiB memory in use. Of the allocated memory 21.86 GiB is allocated by PyTorch, and 652.52 MiB is reserved by PyTorch but unallocated. If reserved but unallocated memory is large try setting PYTORCH_CUDA_ALLOC_CONF=expandable_segments:True to avoid fragmentation.  See documentation for Memory Management  (https://pytorch.org/docs/stable/notes/cuda.html#environment-variables)

In [None]:
import torch
from rdkit import Chem

true_E = torch.reshape(dense_data.E, (-1, dense_data.E.size(-1)))  # (bs * n * n, de)
masked_pred_E = torch.reshape(pred.E, (-1, pred.E.size(-1)))   # (bs * n * n, de)
mask_E = (true_E != 0.).any(dim=-1)

flat_true_E = true_E[mask_E, :]
flat_pred_E = masked_pred_E[mask_E, :]

true_mols = [Chem.inchi.MolFromInchi(data.get_example(idx).inchi) for idx in range(len(data))] # Is this correct?
predicted_mols = [list() for _ in range(len(data))]