## Profiling and Benchmarking XModalix
XModalix is rather slow compared to Varix. We assume this is because of the MultiModalDataset and the custom sampler.
Before improving the code, we profile and benchmark the components of the XModalix pipeline. We think the following modules could be relevant:
- MultiModalDataset
  - which is composed of NumericDataset or ImageDataset instances
- CoverageEnsuringSampler
- XModalixTrainer
  - which calls multiple GeneralTrainers
    - which is a child of BaseTrainer


We should also profile for four different cases:
  - single cell vs standard tabular
  - paired vs unpaired
  - image vs standard tabular
  - image vs single cell

In [7]:
import os

# Move the working directory up to the project root so that the relative
# data paths used in later cells resolve no matter where the kernel started.
p = os.getcwd()
d = "autoencodix_package"
# Compare against whole path components: a plain substring test would also
# accept a directory such as "autoencodix_package_old" and then fail in
# .index() with a ValueError instead of the intended FileNotFoundError.
parts = p.split(os.sep)
if d not in parts:
    raise FileNotFoundError(f"'{d}' not found in path: {p}")
os.chdir(os.sep.join(parts[: parts.index(d) + 1]))
print(f"Changed to: {os.getcwd()}")


Changed to: /Users/maximilianjoas/development/autoencodix_package


In [8]:
import torch
import numpy as np
import pandas as pd
from torch import nn
from torch.profiler import profile, ProfilerActivity, record_function
import autoencodix as acx
from autoencodix.trainers import _xmodal_trainer, _general_trainer
from autoencodix.base import BaseTrainer, BaseDataset
from autoencodix.data import NumericDataset, MultiModalDataset, ImageDataset
from autoencodix.data._multimodal_dataset import CoverageEnsuringSampler
from autoencodix.utils.example_data import EXAMPLE_MULTI_SC, EXAMPLE_MULTI_BULK
import torch.utils.benchmark as benchmark



In [9]:
# Locations of the tutorial inputs, relative to the current working directory.
# (A single-argument os.path.join returns its argument unchanged, so plain
# string literals are equivalent.)
rna_file = "data/XModalix-Tut-data/combined_rnaseq_formatted.parquet"
img_root = "data/XModalix-Tut-data/images/tcga_fake"

#### Run XModalix to get access to all attributes we want to benchmark

In [10]:
from autoencodix.configs.xmodalix_config import XModalixConfig
from autoencodix.configs.default_config import DataConfig, DataInfo, DataCase

# Input files for this run (paths are relative to the current working directory).
clin_file = "./data/XModalix-Tut-data/combined_clin_formatted.parquet"
rna_file = "data/XModalix-Tut-data/combined_rnaseq_formatted.parquet"
img_root = "data/XModalix-Tut-data/images/tcga_fake"

# Per-modality descriptions, declared up front so the main config stays short.
img_info = DataInfo(
    file_path=img_root,
    data_type="IMG",
    scaling="STANDARD",
    img_height_resize=32,
    img_width_resize=32,
    translate_direction="to",
    pretrain_epochs=0,
)
rna_info = DataInfo(
    file_path=rna_file,
    data_type="NUMERIC",
    scaling="STANDARD",
    translate_direction="from",
    pretrain_epochs=0,
)
anno_info = DataInfo(file_path=clin_file, data_type="ANNOTATION", sep="\t")

# The dict keeps the original insertion order: img, rna, anno.
data_config = DataConfig(
    data_info={"img": img_info, "rna": rna_info, "anno": anno_info},
    annotation_columns=["CANCER_TYPE_ACRONYM"],
)

# A single epoch is enough here: the run only has to produce the trained
# objects that the profiling cells inspect afterwards.
# Optional setting left disabled: float_precision="16-mixed"
xmodalix_config = XModalixConfig(
    data_case=DataCase.IMG_TO_BULK,
    data_config=data_config,
    class_param="CANCER_TYPE",
    requires_paired=False,
    epochs=1,
    batch_size=512,
    learning_rate=0.0005,
    latent_dim=6,
    k_filter=1000,
    beta=0.1,
    gamma=10,
    delta_class=100,
    delta_pair=300,
    loss_reduction="sum",
    checkpoint_interval=100,
)

xmodalix = acx.XModalix(config=xmodalix_config)
result = xmodalix.run()


reading parquet: data/XModalix-Tut-data/combined_rnaseq_formatted.parquet
reading parquet: ./data/XModalix-Tut-data/combined_clin_formatted.parquet
Given image size is possible, rescaling images to: 32x32
Successfully loaded 3230 images for img
anno key: rna
anno key: img
Converting 2261 images to torch.float32 tensors...
Converting 646 images to torch.float32 tensors...
Converting 323 images to torch.float32 tensors...
key: train, type: <class 'dict'>
key: valid, type: <class 'dict'>
key: test, type: <class 'dict'>
Dataset has UNPAIRED samples → using CoverageEnsuringSampler (2261 paired + 225 unpaired)
Dataset has UNPAIRED samples → using CoverageEnsuringSampler (323 paired + 32 unpaired)
Check if we need to pretrain: multi_bulk.rna
pretrain epochs : 0
No pretraining for multi_bulk.rna
Check if we need to pretrain: img.img
pretrain epochs : 0
No pretraining for img.img
--- Epoch 1/1 ---
split: train, n_samples: 5885.0
Epoch 1/1 - Train Loss: 2821.8011
Sub-losses - adver_loss: 14.5485

#### Getting the attributes I want to profile

In [11]:
from autoencodix.data._multimodal_dataset import create_multimodal_collate_fn, CoverageEnsuringSampler

# Handles on the pieces of the trained pipeline that are profiled separately.
model = result.model
forward_fn = xmodalix._trainer._modalities_forward
loader = xmodalix._trainer._trainloader

# `dataset` has to be the loader's dataset itself: both the sampler and the
# collate function take it as their `multimodal_dataset` argument. Binding the
# sampler to `dataset` (or the trainer to `sampler`) would hand the wrong
# object to create_multimodal_collate_fn.
dataset: MultiModalDataset = loader.dataset
sampler = CoverageEnsuringSampler(
    multimodal_dataset=dataset,
    batch_size=xmodalix_config.batch_size,
)
collate_fn = create_multimodal_collate_fn(multimodal_dataset=dataset)


In [12]:
# Always trace CPU activity; add CUDA tracing only when a CUDA device exists.
cuda_available = torch.cuda.is_available()
device = "cuda" if cuda_available else "cpu"
activities = [ProfilerActivity.CPU]
if cuda_available:
    activities.append(ProfilerActivity.CUDA)

# Column used to sort the summary table, e.g. "cpu_time_total".
sort_by_keyword = f"{device}_time_total"

# NOTE: _ExperimentalConfig is a private torch API and may change between releases.
profiler_config = torch._C._profiler._ExperimentalConfig(verbose=True)

with profile(
    activities=activities,
    record_shapes=True,
    profile_memory=True,
    with_stack=True,
    experimental_config=profiler_config,
) as prof:
    # Everything inside this region is reported under the "model_inference" label.
    with record_function("model_inference"):
        xmodalix.fit()
        # Alternative target (single forward pass): forward_fn(next(iter(loader)))

summary_table = prof.key_averages().table(sort_by=sort_by_keyword, row_limit=10)
print(summary_table)

Dataset has UNPAIRED samples → using CoverageEnsuringSampler (2261 paired + 225 unpaired)
Dataset has UNPAIRED samples → using CoverageEnsuringSampler (323 paired + 32 unpaired)


STAGE:2025-12-05 09:00:37 1697:14584 ActivityProfilerController.cpp:314] Completed Stage: Warm Up


Check if we need to pretrain: multi_bulk.rna
pretrain epochs : 0
No pretraining for multi_bulk.rna
Check if we need to pretrain: img.img
pretrain epochs : 0
No pretraining for img.img
--- Epoch 1/1 ---
split: train, n_samples: 5885.0
Epoch 1/1 - Train Loss: 2811.3560
Sub-losses - adver_loss: 14.2768, aggregated_sub_losses: 2213.6910, paired_loss: 308.8031, class_loss: 274.5851, multi_bulk.rna.recon_loss: 1378.4247, multi_bulk.rna.var_loss: 0.0000, multi_bulk.rna.loss: 1378.4247, img.img.recon_loss: 835.2663, img.img.var_loss: 0.0000, img.img.loss: 835.2663, clf_loss: 1.3772
split: valid, n_samples: 323
Epoch 1/1 - Valid Loss: 2774.8009
Sub-losses - adver_loss: 14.3045, aggregated_sub_losses: 2239.8355, paired_loss: 265.7470, class_loss: 254.9139, multi_bulk.rna.recon_loss: 1489.2041, multi_bulk.rna.var_loss: 0.0000, multi_bulk.rna.loss: 1489.2041, img.img.recon_loss: 750.6314, img.img.var_loss: 0.0000, img.img.loss: 750.6314, clf_loss: 1.3286
Storing checkpoint for epoch 0...


STAGE:2025-12-05 09:00:48 1697:14584 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2025-12-05 09:00:48 1697:14584 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        11.84%        1.349s        89.67%       10.220s       10.220s         288 b    -120.58 Mb             1  
                                       aten::batch_norm        -0.01%    -945.000us        31.53%        3.593s      15.098ms           0 b           0 b           238  
                           aten::_batch_norm_impl_index         0.03%       3.256ms        31.52%        3.593s      15.096ms           0 b           

In [13]:
# Same profile, grouped by the top 5 stack frames so that each row carries a
# source location; only the two most expensive rows are shown.
stack_grouped = prof.key_averages(group_by_stack_n=5)
print(stack_grouped.table(sort_by=sort_by_keyword, row_limit=2))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------------------------------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  Source Location                                                              
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------------------------------------------------------------------------  
                                        model_inference        11.84%        1.349s        89.67%       10.220s       10.220s         288 b    -120.58 Mb             1  ...in method _record_function_enter_new of PyCapsule object at 0x11b10dce0>  
            

In [14]:
        
from lightning_fabric import Fabric
# Let Fabric choose the accelerator automatically, restricted to one device.
fabric = Fabric(accelerator="auto", devices=1)

In [15]:
# Show which device type Fabric picked for accelerator="auto"
# (the recorded output of this cell is 'mps').
fabric.device.type

'mps'

In [16]:
import torch
# Scratch check: shape of a random tensor created from two integer sizes.
torch.randn(1,16).shape

torch.Size([1, 16])

In [None]:
# Print the input dimension of every modality dataset, followed by its key.
# Judging by the recorded output, get_input_dim() yields an int for the
# numeric modality and a torch.Size for the image modality.
# The former `torch.randn(d.get_input_dim()).shape` line was dropped: a bare
# expression inside a for loop is never displayed, so it only allocated a
# throwaway tensor.
for k, d in loader.dataset.datasets.items():
    print(d.get_input_dim())
    print(k)

1000
multi_bulk.rna
torch.Size([1, 32, 32])
img.img


In [22]:
# Pull just the first batch from the training loader and show it;
# nothing is printed if the loader yields no batches.
batch = next(iter(loader), None)
if batch is not None:
    print(batch)

{'multi_bulk.rna': {'data': tensor([[ 0.4283, -0.0834,  0.1322,  ..., -0.7598, -0.1700, -0.3983],
        [-0.4385, -0.2784, -0.4571,  ..., -0.3549, -0.2840,  2.1008],
        [-0.3912, -0.2749,  0.1042,  ..., -0.0360, -0.2184, -0.4030],
        ...,
        [-0.4310, -0.2784,  4.7178,  ...,  0.2996,  0.2801, -0.3302],
        [-0.4541, -0.2784, -0.7309,  ..., -0.4977,  0.1662,  0.7841],
        [-0.2006, -0.2784, -0.6513,  ..., -0.5386,  0.1336,  1.2173]],
       device='mps:0'), 'sample_ids': [np.str_('TCGA-86-6562-01'), np.str_('TCGA-D1-A165-01'), np.str_('TCGA-BH-A0GY-01'), np.str_('TCGA-G5-6641-01'), np.str_('TCGA-77-7140-01'), np.str_('TCGA-G4-6321-01'), np.str_('TCGA-BG-A222-01'), np.str_('TCGA-29-1696-01'), np.str_('TCGA-AG-3887-01'), np.str_('TCGA-25-1871-01'), np.str_('TCGA-BH-A0DT-01'), np.str_('TCGA-55-6979-01'), np.str_('TCGA-33-A5GW-01'), np.str_('TCGA-BG-A18B-01'), np.str_('TCGA-73-4668-01'), np.str_('TCGA-AA-3560-01'), np.str_('TCGA-AA-3524-01'), np.str_('TCGA-AH-6644-0