## Profiling and Benchmarking XModalix
XModalix is rather slow compared to Varix; we suspect this is caused by the MultiModalDataset and the custom sampler.
Before optimizing the code, we profile and benchmark the components of the XModalix pipeline. The following modules are likely relevant:
- MultiModalDataset
  - composed of NumericDataset and/or ImageDataset instances
- CoverageEnsuringSampler
- XModalixTrainer
  - which calls multiple GeneralTrainers
    - each a subclass of BaseTrainer
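The torch.profiler tables in this notebook are organized around operator and `record_function` rows. To see how much time goes into plain Python code such as a dataset's `__getitem__` or a sampler's `__iter__`, Python's built-in `cProfile` is a complementary option. The sketch below uses hypothetical `ToyDataset` and `ToySampler` stand-ins (not the autoencodix classes) so that it runs on its own; in the notebook one would instead profile a loop over the XModalix train loader.

```python
import cProfile
import io
import pstats
import random


class ToyDataset:
    """Hypothetical stand-in for a map-style dataset with per-item Python overhead."""

    def __init__(self, n_items: int):
        self.data = [[float(i)] * 64 for i in range(n_items)]

    def __len__(self) -> int:
        return len(self.data)

    def __getitem__(self, idx: int):
        # Copying the row simulates per-sample work done in Python.
        return list(self.data[idx])


class ToySampler:
    """Hypothetical stand-in for a custom batch sampler that shuffles indices."""

    def __init__(self, n_items: int, batch_size: int, seed: int = 0):
        self.n_items = n_items
        self.batch_size = batch_size
        self.rng = random.Random(seed)

    def __iter__(self):
        indices = list(range(self.n_items))
        self.rng.shuffle(indices)
        for start in range(0, self.n_items, self.batch_size):
            yield indices[start:start + self.batch_size]


def load_one_epoch(dataset, sampler) -> int:
    """Fetch every batch the sampler yields, roughly as a DataLoader would."""
    n_batches = 0
    for batch_indices in sampler:
        _ = [dataset[i] for i in batch_indices]
        n_batches += 1
    return n_batches


def profile_epoch(n_items: int = 2000, batch_size: int = 512):
    """Run one toy epoch under cProfile and return (n_batches, stats, text report)."""
    dataset = ToyDataset(n_items)
    sampler = ToySampler(n_items, batch_size)
    profiler = cProfile.Profile()
    profiler.enable()
    n_batches = load_one_epoch(dataset, sampler)
    profiler.disable()
    stream = io.StringIO()
    stats = pstats.Stats(profiler, stream=stream).sort_stats("cumulative")
    stats.print_stats(10)  # top 10 entries by cumulative time
    return n_batches, stats, stream.getvalue()


if __name__ == "__main__":
    n_batches, _, report = profile_epoch()
    print(f"batches: {n_batches}")
    print(report)
```

Sorting by cumulative time surfaces the outermost expensive calls first; switching to `"tottime"` instead highlights functions whose own body is slow.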


We should also profile four different comparisons:
- single cell vs. standard tabular
- paired vs. unpaired
- image vs. standard tabular
- image vs. single cell
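Comparing these configurations needs the same timing procedure for each one. A minimal harness, assuming each configuration can be wrapped in a zero-argument callable, runs a warm-up call and then reports the median of a few repeats. The notebook already imports `torch.utils.benchmark`, whose `Timer` additionally handles details such as CUDA synchronization and thread settings; this stdlib version only illustrates the structure, and the case names and workloads below are hypothetical placeholders.

```python
import statistics
import time
from typing import Callable, Dict


def benchmark_cases(
    cases: Dict[str, Callable[[], object]],
    repeats: int = 3,
    warmup: int = 1,
) -> Dict[str, float]:
    """Return the median wall-clock time in seconds for each named case."""
    results: Dict[str, float] = {}
    for name, run_case in cases.items():
        # Warm-up runs absorb one-off costs such as file caching or lazy initialization.
        for _ in range(warmup):
            run_case()
        timings = []
        for _ in range(repeats):
            start = time.perf_counter()
            run_case()
            timings.append(time.perf_counter() - start)
        results[name] = statistics.median(timings)
    return results


if __name__ == "__main__":
    # Hypothetical placeholders: each lambda would build and fit one pipeline configuration.
    cases = {
        "tabular": lambda: sum(range(100_000)),
        "single_cell": lambda: sum(range(100_000)),
        "image": lambda: sum(range(100_000)),
    }
    for name, seconds in benchmark_cases(cases).items():
        print(f"{name:15s} {seconds * 1e3:8.2f} ms")
```

The median is used rather than the mean so that a single slow run (for example, one interrupted by disk I/O) does not dominate the comparison.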

In [1]:
import os

# Walk up to the repository root so that relative data paths resolve.
p = os.getcwd()
d = "autoencodix_package"
parts = p.split(os.sep)
if d not in parts:
    raise FileNotFoundError(f"'{d}' not found in path: {p}")
os.chdir(os.sep.join(parts[: parts.index(d) + 1]))
print(f"Changed to: {os.getcwd()}")


Changed to: /Users/maximilianjoas/development/autoencodix_package


In [2]:
import torch
import numpy as np
import pandas as pd
from torch import nn
from torch.profiler import profile, ProfilerActivity, record_function
import autoencodix as acx
from autoencodix.trainers import _xmodal_trainer, _general_trainer
from autoencodix.base import BaseTrainer, BaseDataset
from autoencodix.data import NumericDataset, MultiModalDataset, ImageDataset
from autoencodix.data._multimodal_dataset import CoverageEnsuringSampler
from autoencodix.utils.example_data import EXAMPLE_MULTI_SC, EXAMPLE_MULTI_BULK
import torch.utils.benchmark as benchmark



In [3]:
rna_file = os.path.join("data/XModalix-Tut-data/combined_rnaseq_formatted.parquet")
img_root = os.path.join("data/XModalix-Tut-data/images/tcga_fake")

#### Run XModalix to get access to all attributes we want to benchmark

In [None]:
from autoencodix.configs.xmodalix_config import XModalixConfig
from autoencodix.configs.default_config import DataConfig, DataInfo, DataCase

clin_file = os.path.join("./data/XModalix-Tut-data/combined_clin_formatted.parquet")
rna_file = os.path.join("data/XModalix-Tut-data/combined_rnaseq_formatted.parquet")
img_root = os.path.join("data/XModalix-Tut-data/images/tcga_fake")

xmodalix_config = XModalixConfig(
    checkpoint_interval=100,
    class_param="CANCER_TYPE",
    epochs=1,
    beta=0.1,
    gamma=10,
    delta_class=100,
    delta_pair=300,
    latent_dim=6,
    k_filter=1000,
    batch_size=512,
    learning_rate=0.0005,
    requires_paired=False,
    device="mps",
    loss_reduction="sum",
    data_case=DataCase.IMG_TO_BULK,
    data_config=DataConfig(
        data_info={
            "img": DataInfo(
                file_path=img_root,
                img_height_resize=32,
                img_width_resize=32,
                data_type="IMG",
                scaling="STANDARD",
                translate_direction="to",
                pretrain_epochs=0,
            ),
            "rna": DataInfo(
                file_path=rna_file,
                data_type="NUMERIC",
                scaling="STANDARD",
                pretrain_epochs=0,
                translate_direction="from",
            ),
            "anno": DataInfo(file_path=clin_file, data_type="ANNOTATION", sep="\t"),
        },
        annotation_columns=["CANCER_TYPE_ACRONYM"],
    ),
)

xmodalix = acx.XModalix(config=xmodalix_config)
result = xmodalix.run()


reading parquet: data/XModalix-Tut-data/combined_rnaseq_formatted.parquet
reading parquet: ./data/XModalix-Tut-data/combined_clin_formatted.parquet
Given image size is possible, rescaling images to: 32x32
Successfully loaded 3230 images for img
anno key: paired
anno key: img
Converting 2261 images to torch.float32 tensors...
Converting 646 images to torch.float32 tensors...
Converting 323 images to torch.float32 tensors...
key: train, type: <class 'dict'>
key: valid, type: <class 'dict'>
key: test, type: <class 'dict'>
Check if we need to pretrain: multi_bulk.rna
pretrain epochs : 0
No pretraining for multi_bulk.rna
Check if we need to pretrain: img.img
pretrain epochs : 0
No pretraining for img.img
--- Epoch 1/1 ---
split: train, n_samples: 2048.0
Epoch 1/1 - Train Loss: 3305.5685
Sub-losses - adver_loss: 14.4374, aggregated_sub_losses: 2571.7772, paired_loss: 424.5181, class_loss: 294.8359, multi_bulk.rna.recon_loss: 1177.7159, multi_bulk.rna.var_loss: 0.0000, multi_bulk.rna.loss: 11

#### Extract the attributes we want to profile

In [6]:
from autoencodix.data._multimodal_dataset import create_multimodal_collate_fn, CoverageEnsuringSampler
model = result.model
forward_fn = xmodalix._trainer._modalities_forward
loader = xmodalix._trainer._trainloader
dataset: MultiModalDataset = CoverageEnsuringSampler(multimodal_dataset=loader.dataset, batch_size=xmodalix_config.batch_size)
sampler = xmodalix._trainer
collate_fn = create_multimodal_collate_fn(multimodal_dataset=dataset)
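Since the working hypothesis is that the dataset and sampler are the bottleneck, a direct check is to iterate the loader without running the model and compare the per-batch fetch time with the time spent in `forward_fn`. The helper below is a sketch under that assumption; `slow_batches` is a hypothetical stand-in for a DataLoader so the example runs on its own. In the notebook one could call `time_batches(loader, max_batches=5)`. When timing `forward_fn` on an asynchronous device, call `torch.cuda.synchronize()` or `torch.mps.synchronize()` before stopping the clock, otherwise the measured time understates the real cost.

```python
import time
from typing import Iterable, List, Optional


def time_batches(batches: Iterable, max_batches: Optional[int] = None) -> List[float]:
    """Measure the wall-clock time needed to fetch each batch from an iterable."""
    timings: List[float] = []
    iterator = iter(batches)
    while max_batches is None or len(timings) < max_batches:
        start = time.perf_counter()
        try:
            next(iterator)
        except StopIteration:
            break  # the final, empty fetch is not recorded
        timings.append(time.perf_counter() - start)
    return timings


def slow_batches(n_batches: int, delay_s: float):
    """Hypothetical stand-in for a DataLoader whose batches take delay_s to build."""
    for i in range(n_batches):
        time.sleep(delay_s)
        yield i


if __name__ == "__main__":
    per_batch = time_batches(slow_batches(n_batches=5, delay_s=0.01))
    mean_ms = 1e3 * sum(per_batch) / len(per_batch)
    print(f"{len(per_batch)} batches, mean {mean_ms:.1f} ms per batch")
```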


In [7]:
activities = [ProfilerActivity.CPU]
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
    activities += [ProfilerActivity.CUDA]


sort_by_keyword = device + "_time_total"

with profile(activities=activities, record_shapes=True, profile_memory=True, with_stack=False,
    experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True)) as prof:
    with record_function("model_inference"):
        
        xmodalix.fit()
        #forward_fn(next(iter(loader)))

print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=10))

STAGE:2025-12-04 16:02:42 85278:25291824 ActivityProfilerController.cpp:314] Completed Stage: Warm Up


Check if we need to pretrain: multi_bulk.rna
pretrain epochs : 0
No pretraining for multi_bulk.rna
Check if we need to pretrain: img.img
pretrain epochs : 0
No pretraining for img.img
--- Epoch 1/1 ---
split: train, n_samples: 2048.0
Epoch 1/1 - Train Loss: 3270.3477
Sub-losses - adver_loss: 15.2660, aggregated_sub_losses: 2497.1506, paired_loss: 449.6098, class_loss: 308.3213, multi_bulk.rna.recon_loss: 1185.0723, multi_bulk.rna.var_loss: 0.0000, multi_bulk.rna.loss: 1185.0723, img.img.recon_loss: 1312.0783, img.img.var_loss: 0.0000, img.img.loss: 1312.0783, clf_loss: 1.5220
split: valid, n_samples: 323
Epoch 1/1 - Valid Loss: 2676.2175
Sub-losses - adver_loss: 14.6965, aggregated_sub_losses: 2042.6459, paired_loss: 366.8947, class_loss: 251.9804, multi_bulk.rna.recon_loss: 1018.4878, multi_bulk.rna.var_loss: 0.0000, multi_bulk.rna.loss: 1018.4878, img.img.recon_loss: 1024.1581, img.img.var_loss: 0.0000, img.img.loss: 1024.1581, clf_loss: 1.4365
Storing checkpoint for epoch 0...


STAGE:2025-12-04 16:02:42 85278:25291824 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2025-12-04 16:02:42 85278:25291824 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        10.02%      84.782ms        96.31%     815.085ms     815.085ms         288 b     -44.00 Mb             1  
enumerate(DataLoader)#_SingleProcessDataLoaderIter._...        15.12%     127.966ms        24.10%     203.973ms      29.139ms      18.31 Mb     -35.34 Kb             7  
                                             aten::mean        15.84%     134.022ms        15.84%     134.022ms       3.622ms           0 b           
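The table can be read as a ratio: the `enumerate(DataLoader)` row takes 203.973 ms of the 815.085 ms spent inside the wrapped `fit()` call, i.e. roughly a quarter of the time (the table's own percentages, 24.10% / 96.31%, give the same ratio). The trivial helper below makes that arithmetic explicit; the values are copied from the table, and the function name is illustrative.

```python
def share_of_total(row_ms: float, total_ms: float) -> float:
    """Fraction of the top-level time attributed to one profiler row."""
    if total_ms <= 0:
        raise ValueError("total_ms must be positive")
    return row_ms / total_ms


if __name__ == "__main__":
    # CPU-total values copied from the key_averages() table printed above.
    model_inference_ms = 815.085
    dataloader_ms = 203.973
    dataloader_calls = 7
    print(f"DataLoader share: {share_of_total(dataloader_ms, model_inference_ms):.1%}")
    print(f"Per-call average: {dataloader_ms / dataloader_calls:.3f} ms")
```

Because the row comes from `_SingleProcessDataLoaderIter`, batches are built in the main process (`num_workers=0`), so this loading time is not overlapped with the forward and backward passes.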

In [8]:
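# group_by_stack_n only groups rows by call site if the profile was recorded with
# with_stack=True; with with_stack=False (as above) the rows match the ungrouped table.
print(prof.key_averages(group_by_stack_n=5).table(sort_by=sort_by_keyword, row_limit=2))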

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        10.02%      84.782ms        96.31%     815.085ms     815.085ms         288 b     -44.00 Mb             1  
enumerate(DataLoader)#_SingleProcessDataLoaderIter._...        15.12%     127.966ms        24.10%     203.973ms      29.139ms      18.31 Mb     -35.34 Kb             7  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------