## Profiling and Benchmarking XModalix
XModalix is noticeably slower than Varix. We suspect the cause lies in the MultiModalDataset and the custom sampler (CoverageEnsuringSampler).
Before optimizing the code, we profile and benchmark the components of the XModalix pipeline. We think the following modules could be relevant:
- MultiModalDataset
  - which is composed of NumericDataset and/or ImageDataset instances
- CoverageEnsuringSampler
- XModalixTrainer
  - which calls multiple GeneralTrainers
    - each of which is a child class of BaseTrainer
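Before reaching for the full profiler, it can help to time each component in isolation. The sketch below is a minimal, framework-agnostic version using only the standard library. `ToyDataset`, `toy_batches` and `toy_collate` are hypothetical stand-ins for MultiModalDataset, CoverageEnsuringSampler and the multimodal collate function, and `time_component` / `profile_pipeline` are hypothetical helpers, not part of autoencodix.

```python
import timeit
from typing import Callable, Dict, List


def time_component(fn: Callable[[], object], number: int = 5, repeat: int = 3) -> float:
    """Return the best (minimum) average wall-clock seconds per call of fn."""
    # timeit.repeat runs fn `number` times per trial and returns one total per trial;
    # the minimum over trials is the least noisy estimate.
    totals = timeit.repeat(fn, number=number, repeat=repeat)
    return min(totals) / number


class ToyDataset:
    """Hypothetical stand-in for a dataset with per-sample __getitem__."""

    def __init__(self, n: int, width: int) -> None:
        self.rows = [[float(i + j) for j in range(width)] for i in range(n)]

    def __len__(self) -> int:
        return len(self.rows)

    def __getitem__(self, idx: int) -> List[float]:
        return self.rows[idx]


def toy_batches(n: int, batch_size: int) -> List[List[int]]:
    # Sequential index batches; a real sampler would shuffle and enforce coverage.
    return [list(range(i, min(i + batch_size, n))) for i in range(0, n, batch_size)]


def toy_collate(samples: List[List[float]]) -> List[List[float]]:
    # Transpose samples into feature-major columns, mimicking a stacking step.
    return [list(col) for col in zip(*samples)]


def profile_pipeline(ds: ToyDataset, batch_size: int) -> Dict[str, float]:
    batches = toy_batches(len(ds), batch_size)
    return {
        "dataset_getitem": time_component(lambda: [ds[i] for i in range(len(ds))]),
        "sampler": time_component(lambda: toy_batches(len(ds), batch_size)),
        # Collation needs the samples, so this stage includes the __getitem__ calls too.
        "getitem_plus_collate": time_component(
            lambda: [toy_collate([ds[i] for i in b]) for b in batches]
        ),
    }


timings = profile_pipeline(ToyDataset(n=1000, width=50), batch_size=128)
for stage, seconds in timings.items():
    print(f"{stage:>20}: {seconds * 1e3:.3f} ms")
```

With the real objects, the lambdas would instead wrap indexing into the dataset, iterating the sampler and calling the collate function. The torch.utils.benchmark module imported in the notebook offers a similar Timer-based API.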


We should also profile four different cases:
- single cell vs standard tabular
- paired vs unpaired
- image vs standard tabular
- image vs single cell
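Paired vs unpaired is independent of which modalities are combined, so one way to organize these runs is as a small grid of modality pairs crossed with the pairing flag. This is a hypothetical sketch: `BenchmarkCase`, `build_cases` and the modality labels are illustrative names; only `requires_paired` mirrors a field actually set in the XModalixConfig used in this notebook.

```python
from dataclasses import dataclass
from itertools import product
from typing import List, Tuple

# Hypothetical labels for the modality pairs named in the case list.
MODALITY_PAIRS: List[Tuple[str, str]] = [
    ("single_cell", "tabular"),
    ("image", "tabular"),
    ("image", "single_cell"),
]


@dataclass(frozen=True)
class BenchmarkCase:
    from_modality: str
    to_modality: str
    requires_paired: bool

    @property
    def name(self) -> str:
        pairing = "paired" if self.requires_paired else "unpaired"
        return f"{self.from_modality}->{self.to_modality} ({pairing})"


def build_cases() -> List[BenchmarkCase]:
    # Cross every modality pair with both pairing settings.
    return [
        BenchmarkCase(src, dst, paired)
        for (src, dst), paired in product(MODALITY_PAIRS, (True, False))
    ]


for case in build_cases():
    print(case.name)
```

Each case would then correspond to one XModalixConfig with the matching data_info entries and requires_paired value.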

In [1]:
import os

p = os.getcwd()
d = "autoencodix_package"
if d not in p:
    raise FileNotFoundError(f"'{d}' not found in path: {p}")
os.chdir(os.sep.join(p.split(os.sep)[: p.split(os.sep).index(d) + 1]))
print(f"Changed to: {os.getcwd()}")


Changed to: /Users/maximilianjoas/development/autoencodix_package


In [2]:
import torch
import numpy as np
import pandas as pd
from torch import nn
from torch.profiler import profile, ProfilerActivity, record_function
import autoencodix as acx
from autoencodix.trainers import _xmodal_trainer, _general_trainer
from autoencodix.base import BaseTrainer, BaseDataset
from autoencodix.data import NumericDataset, MultiModalDataset, ImageDataset
from autoencodix.data._multimodal_dataset import CoverageEnsuringSampler
from autoencodix.utils.example_data import EXAMPLE_MULTI_SC, EXAMPLE_MULTI_BULK
import torch.utils.benchmark as benchmark



In [3]:
rna_file = os.path.join("data/XModalix-Tut-data/combined_rnaseq_formatted.parquet")
img_root = os.path.join("data/XModalix-Tut-data/images/tcga_fake")

#### Run XModalix to get access to all attributes we want to benchmark

In [4]:
from autoencodix.configs.xmodalix_config import XModalixConfig
from autoencodix.configs.default_config import DataConfig, DataInfo, DataCase
from autoencodix.modeling._imgfast_architecture import ImageVAEArchitecture

clin_file = os.path.join("./data/XModalix-Tut-data/combined_clin_formatted.parquet")
rna_file = os.path.join("data/XModalix-Tut-data/combined_rnaseq_formatted.parquet")
img_root = os.path.join("data/XModalix-Tut-data/images/tcga_fake")

xmodalix_config = XModalixConfig(
    checkpoint_interval=100,
    class_param="CANCER_TYPE",
    epochs=3,
    beta=0.1,
    gamma=10,
    delta_class=100,
    delta_pair=300,
    latent_dim=6,
    k_filter=1000,
    batch_size=512,
    learning_rate=0.0005,
    requires_paired=False,
    device="cpu",
    loss_reduction="sum",
    data_case=DataCase.IMG_TO_BULK,
    data_config=DataConfig(
        data_info={
            # "img": DataInfo(
            #     file_path=img_root,
            #     img_height_resize=32,
            #     img_width_resize=32,
            #     data_type="IMG",
            #     scaling="STANDARD",
            #     translate_direction="to",
            #     pretrain_epochs=0,
            # ),
            "rna": DataInfo(
                file_path=rna_file,
                data_type="NUMERIC",
                scaling="STANDARD",
                pretrain_epochs=0,
                translate_direction="from",
            ),
            "rna2": DataInfo(
                file_path=rna_file,
                data_type="NUMERIC",
                scaling="STANDARD",
                pretrain_epochs=0,
                translate_direction="to",
            ),
            "anno": DataInfo(file_path=clin_file, data_type="ANNOTATION", sep="\t"),
        },
        annotation_columns=["CANCER_TYPE_ACRONYM"],
    ),
)

xmodalix = acx.XModalix(config=xmodalix_config, model_type=ImageVAEArchitecture)
result = xmodalix.run()


reading parquet: data/XModalix-Tut-data/combined_rnaseq_formatted.parquet
reading parquet: data/XModalix-Tut-data/combined_rnaseq_formatted.parquet
reading parquet: ./data/XModalix-Tut-data/combined_clin_formatted.parquet
anno key: rna
anno key: rna2
key: train, type: <class 'dict'>
key: valid, type: <class 'dict'>
key: test, type: <class 'dict'>
Check if we need to pretrain: multi_bulk.rna
pretrain epochs : 0
No pretraining for multi_bulk.rna
Check if we need to pretrain: multi_bulk.rna2
pretrain epochs : 0
No pretraining for multi_bulk.rna2
--- Epoch 1/3 ---


  warn("CUDA is not available, disabling CUDA profiling")
STAGE:2025-12-05 14:15:32 50262:556503 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
[W CPUAllocator.cpp:249] Memory block of unknown size was allocated before the profiling started, profiler results will not include the deallocation event
STAGE:2025-12-05 14:15:32 50262:556503 ActivityProfilerController.cpp:320] Completed Stage: Collection
STAGE:2025-12-05 14:15:32 50262:556503 ActivityProfilerController.cpp:324] Completed Stage: Post Processing



PROFILER SUMMARY - Top Operations by CPU Time
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                          ProfilerStep*        22.46%      18.707ms       100.00%      83.276ms      27.759ms       7.01 Mb      -1.95 Mb             3  
                                            total_batch         0.26%     214.000us        58.65%      48.843ms      24.422ms       5.05 Mb           0 b             2  
                                     modalities_forward         4.68%       3.894ms        22.27%      

#### Getting the attributes we want to profile

In [5]:
from autoencodix.data._multimodal_dataset import create_multimodal_collate_fn, CoverageEnsuringSampler

model = result.model
forward_fn = xmodalix._trainer._modalities_forward
loader = xmodalix._trainer._trainloader
# The MultiModalDataset behind the training loader
dataset: MultiModalDataset = loader.dataset
# Rebuild the sampler and collate function on that dataset so they can be benchmarked in isolation
sampler = CoverageEnsuringSampler(multimodal_dataset=dataset, batch_size=xmodalix_config.batch_size)
collate_fn = create_multimodal_collate_fn(multimodal_dataset=dataset)
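Since the working hypothesis is that the dataset and sampler, not the model, dominate runtime, a cheap first check is to split each training step into time spent waiting for the next batch and time spent in the step itself. The sketch below assumes only that the loader is iterable (as `loader` is) and that `step_fn` is some hypothetical callable doing the model work; `split_load_vs_compute` is a hypothetical helper, not an autoencodix API.

```python
import time
from typing import Any, Callable, Dict, Iterable


def split_load_vs_compute(
    loader: Iterable[Any], step_fn: Callable[[Any], Any], max_batches: int = 50
) -> Dict[str, float]:
    """Accumulate seconds spent fetching batches vs. processing them."""
    load_s = compute_s = 0.0
    n = 0
    it = iter(loader)
    while n < max_batches:
        t0 = time.perf_counter()
        try:
            batch = next(it)  # dataset __getitem__, sampler and collate run here
        except StopIteration:
            break
        t1 = time.perf_counter()
        step_fn(batch)  # forward/backward work would run here
        t2 = time.perf_counter()
        load_s += t1 - t0
        compute_s += t2 - t1
        n += 1
    total = load_s + compute_s
    return {
        "batches": n,
        "load_s": load_s,
        "compute_s": compute_s,
        "load_fraction": load_s / total if total > 0 else 0.0,
    }


# Toy usage: a generator stands in for the DataLoader, sum() for the model step.
toy_loader = ([float(i)] * 1000 for i in range(20))
stats = split_load_vs_compute(toy_loader, step_fn=sum)
print(stats)
```

On a GPU, kernel launches are asynchronous, so `step_fn` would need to call `torch.cuda.synchronize()` before returning for the split to be meaningful. With `num_workers > 0`, prefetching can also hide part of the loading time.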


In [6]:
activities = [ProfilerActivity.CPU]
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
    activities += [ProfilerActivity.CUDA]


sort_by_keyword = device + "_time_total"


with profile(
    activities=activities,
    record_shapes=True,
    profile_memory=True,
    with_stack=False,
    experimental_config=torch._C._profiler._ExperimentalConfig(verbose=True),
) as prof:
    with record_function("model_inference"):
        xmodalix.fit()
        # forward_fn(next(iter(loader)))

print(prof.key_averages().table(sort_by=sort_by_keyword, row_limit=10))

STAGE:2025-12-05 14:15:34 50262:556503 ActivityProfilerController.cpp:314] Completed Stage: Warm Up
  warn("CUDA is not available, disabling CUDA profiling")
STAGE:2025-12-05 14:15:34 50262:556503 ActivityProfilerController.cpp:320] Completed Stage: Collection


Check if we need to pretrain: multi_bulk.rna
pretrain epochs : 0
No pretraining for multi_bulk.rna
Check if we need to pretrain: multi_bulk.rna2
pretrain epochs : 0
No pretraining for multi_bulk.rna2
--- Epoch 1/3 ---


STAGE:2025-12-05 14:15:34 50262:556503 ActivityProfilerController.cpp:324] Completed Stage: Post Processing


RuntimeError: Can't disable Kineto profiler when it's not running
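The RuntimeError most likely comes from nesting profilers: the output of the `xmodalix.run()` cell shows that training already runs its own torch.profiler session (the `ProfilerStep*` rows and the PROFILER SUMMARY), so once `fit()` stops that inner session, the outer `profile(...)` context has no running Kineto profiler left to disable. This is an inference from the logs, not something verified in the autoencodix source. Because the suspected bottlenecks are Python-level (dataset indexing, sampler, collate), one alternative that avoids a second torch.profiler is the standard-library cProfile. The helper `profile_callable` is hypothetical, and whether it coexists cleanly with the trainer's internal profiler is an assumption to check.

```python
import cProfile
import io
import pstats
from typing import Any, Callable, Tuple


def profile_callable(fn: Callable[[], Any], top: int = 15) -> Tuple[Any, str]:
    """Run fn under cProfile; return its result and the top entries by cumulative time."""
    profiler = cProfile.Profile()
    profiler.enable()
    try:
        result = fn()
    finally:
        profiler.disable()
    buffer = io.StringIO()
    stats = pstats.Stats(profiler, stream=buffer)
    stats.sort_stats(pstats.SortKey.CUMULATIVE).print_stats(top)
    return result, buffer.getvalue()


def slow_collate(n: int) -> int:
    # Toy workload standing in for per-sample Python overhead.
    return sum(len(str(i)) for i in range(n))


value, report = profile_callable(lambda: slow_collate(100_000), top=5)
print(report)
```

In the notebook, the equivalent call would be `profile_callable(xmodalix.fit)`.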

In [None]:
# group_by_stack_n only groups meaningfully if the profile was recorded with with_stack=True
print(prof.key_averages(group_by_stack_n=5).table(sort_by=sort_by_keyword, row_limit=2))

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg       CPU Mem  Self CPU Mem    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                        model_inference        15.49%        2.889s        86.51%       16.135s       16.135s         288 b    -238.10 Mb             1  
                                       aten::batch_norm         0.00%     693.000us        29.43%        5.488s       7.687ms           0 b           0 b           714  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ---------