# How To Use XModalix

X-Modalix is our implementation of a cross-modal autoencoder.  
This tutorial follows the structure of `Getting Started - Vanillix`, but is less extensive because our pipeline works similarly across different architectures. Here, we focus only on X-Modalix specifics.

**AUTOENCODIX** supports far more functionality than shown here, so we'll also point to advanced tutorials where relevant.  

**IMPORTANT**

> This tutorial only shows the specifics of the XModalix pipeline. If you're unfamiliar with general concepts,  
> we recommend following the `Getting Started - Vanillix` tutorial first.

## XModalix Theory

The core idea of a cross-modal autoencoder is to align the latent spaces of different data modalities.  
Once the latent distributions are aligned, we can translate between modalities by feeding data from one modality into its encoder and then passing the latent space to the decoder of the target modality.
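
As a rough illustration of this encode-then-decode path, the following sketch uses random linear maps as stand-ins for a trained RNA encoder and image decoder. All names and dimensions are hypothetical, and this is not how AUTOENCODIX implements translation internally; it only shows the data flow from one modality to the other.

```python
import random

random.seed(0)

RNA_DIM, IMG_DIM, LATENT_DIM = 5, 4, 2  # hypothetical toy dimensions


def random_matrix(rows, cols):
    """Create a rows x cols matrix with standard-normal entries."""
    return [[random.gauss(0.0, 1.0) for _ in range(cols)] for _ in range(rows)]


def linear(batch, weights):
    """Multiply a batch of row vectors by a weight matrix of shape (in, out)."""
    return [
        [sum(x * w for x, w in zip(row, col)) for col in zip(*weights)]
        for row in batch
    ]


# Stand-ins for trained networks: encoder of the source modality,
# decoder of the target modality.
rna_encoder_w = random_matrix(RNA_DIM, LATENT_DIM)
img_decoder_w = random_matrix(LATENT_DIM, IMG_DIM)


def translate_rna_to_img(rna_batch):
    latent = linear(rna_batch, rna_encoder_w)  # source encoder -> shared latent space
    return linear(latent, img_decoder_w)       # target decoder -> target modality


rna_batch = random_matrix(3, RNA_DIM)
img_batch = translate_rna_to_img(rna_batch)
print(len(img_batch), len(img_batch[0]))  # prints "3 4"
```

The translation only makes sense if both modalities share the same latent distribution, which is what the additional loss terms are meant to enforce.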

The underlying models are separate VAEs, one for each data modality, which are trained in parallel. Their latent spaces are closely aligned so that modalities are hard to discriminate, but sample variation is preserved.  
To achieve this alignment, the VAE loss function is extended with the following terms:

- *Adversarial* loss term: In parallel with the VAEs, a third neural network (latent space classifier) is trained to discriminate the embeddings of different modalities. Analogous to a GAN, the inverse of this classifier loss is added to the VAE loss to enforce latent space alignment.  

- *Paired* loss term: Alignment is further enforced by minimizing the distance between samples with paired measurements across modalities. While unpaired samples are allowed, we recommend that most samples are paired (present in all modalities).  

- *Class-based* loss term: This semi-supervised term minimizes distances among samples of the same class. Class membership is defined in the annotation data (e.g., cancer type, cell type). In each iteration, the class mean in the latent space is computed, and distances from individual samples to the class mean are minimized. See our `Preprint [TODO]` for more details.  
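
To make the paired and class-based terms more concrete, here is a minimal sketch in plain Python. It assumes Euclidean distances and a simple weighted sum; the function names are hypothetical and the actual AUTOENCODIX losses may use different distance measures or reductions. The adversarial term is passed in as a number that already carries the sign with which it enters the VAE objective.

```python
from collections import defaultdict
from math import dist
from statistics import fmean


def paired_loss(z_a, z_b):
    """Mean distance between latent vectors of the same sample in two modalities."""
    return fmean([dist(a, b) for a, b in zip(z_a, z_b)])


def class_loss(z, labels):
    """Mean distance of each latent vector to the latent mean of its class."""
    groups = defaultdict(list)
    for vec, label in zip(z, labels):
        groups[label].append(vec)
    distances = []
    for members in groups.values():
        center = [fmean(dim) for dim in zip(*members)]  # class mean per latent dimension
        distances.extend(dist(vec, center) for vec in members)
    return fmean(distances)


def total_loss(vae_loss, adversarial, paired, class_based, gamma, delta_pair, delta_class):
    """Weighted sum of the VAE loss and the three alignment terms (assumed form)."""
    return vae_loss + gamma * adversarial + delta_pair * paired + delta_class * class_based


# Illustrative numbers only.
pair = paired_loss([[0.0, 0.0], [1.0, 1.0]], [[3.0, 4.0], [1.0, 1.0]])
cls = class_loss([[0.0, 0.0], [2.0, 0.0], [5.0, 5.0]], ["a", "a", "b"])
total = total_loss(1.0, 0.5, pair, cls, gamma=2.0, delta_pair=3.0, delta_class=2.0)
print(pair, cls, total)
```

The weights `gamma`, `delta_pair`, and `delta_class` correspond to the config parameters of the same name introduced in Section 1.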

<img src="https://raw.githubusercontent.com/jan-forest/autoencodix/729f3e8172a722e7c3681298719d9d4b59e313d9/images/xmodalix_scheme.svg" alt="xmodalix_scheme" width="1200"/>

## What You'll Learn

As a showcase for data modality translation with X-Modalix, we will use cancer gene expression from TCGA in combination with handwritten digits from the [MNIST dataset](https://keras.io/api/datasets/mnist/).  
Our goal is to translate gene expression signatures of five selected cancer subtypes into images of digits, with each cancer subtype class mapped to a digit between 0 and 4.  

In practice, these images could be histopathological images or any other data modality.  
For background on cross-modal VAEs as proposed by [Yang & Uhler](https://arxiv.org/abs/1902.03515), see the theory section above.  

While following this showcase, you will learn how to:

1. **Initialize** the pipeline and run it. <br><br>
2. Understand the X-Modalix-specific **pipeline steps** (paired vs. unpaired data). <br><br>
3. Access X-Modalix-specific **results** (sub-results for modality autoencoders). <br><br>
4. **Visualize** and **evaluate** the outputs. <br><br>
5. Apply **custom parameters**. <br><br>
6. **Save, load, and reuse** a trained pipeline. <br><br>

Let's get started! 🚀


## 1) Initialize and Run XModalix

In this example, we read our input data from files.  
The file locations are defined in the config via a section called `DataConfig` (see details in the Python code below).  

We highlight a few custom config parameters for XModalix. For a deeper dive into the config object, see [1]:

- `pretrain_epochs`: Before training the full X-Modalix, the sub-VAEs of each modality can be pretrained. This parameter can be set globally for all modalities or individually per modality and specifies the number of pretraining epochs.  
- `gamma`: Hyperparameter to weight the adversarial loss term.  
- `delta_pair`: Hyperparameter to weight the paired loss term.  
- `delta_class`: Hyperparameter to weight the class-based loss term.  

### Extra 1: Getting Tutorial Data

You can run the following bash commands to download the data and set up the correct folder structure.  
**Assumption:** you are in the root of `autoencodix_package`.

```bash
mkdir -p data
cd data
wget "https://cloud.scadsai.uni-leipzig.de/index.php/s/bq64MaQyZGZfN64/download/XModalix-Tut-data.zip"
unzip XModalix-Tut-data.zip -d XModalix-Tut-data
```
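
After downloading, a quick check can confirm that the files used later in this tutorial are in place. This is a small optional helper, not part of AUTOENCODIX; the function name is hypothetical and the relative paths are taken from the code cells below.

```python
from pathlib import Path

# Paths used in this tutorial, relative to the unzipped data folder.
EXPECTED_PATHS = [
    "combined_rnaseq_formatted.parquet",
    "combined_clin_formatted.parquet",
    "tcga_mappings.txt",
    "images/tcga_fake",
]


def missing_tutorial_files(data_root="data/XModalix-Tut-data"):
    """Return the expected files or folders that do not exist under data_root."""
    root = Path(data_root)
    return [rel for rel in EXPECTED_PATHS if not (root / rel).exists()]


missing = missing_tutorial_files()
print("All tutorial files found." if not missing else f"Missing: {missing}")
```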

### Extra 2: Get correct path

We assume you are in the root of the package. The following code ensures that the correct paths are used.

[1] `Tutorials/DeepDives/ConfigTutorial.ipynb`

In [3]:
import os

p = os.getcwd()
d = "autoencodix_package"
if d not in p:
    raise FileNotFoundError(f"'{d}' not found in path: {p}")
os.chdir(os.sep.join(p.split(os.sep)[: p.split(os.sep).index(d) + 1]))
print(f"Changed to: {os.getcwd()}")


Changed to: /Users/maximilianjoas/development/autoencodix_package


### Extra 3: Data Assumptions

X-Modalix supports two special cases:  
(a) image data, and  
(b) unpaired data.  

This introduces specific requirements for the input data. In the config object, the `DataConfig` must be populated with the relevant information. Example code:

```python
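import os

from autoencodix.configs.default_config import DataConfig, DataInfo

# Input files of the tutorial data set, relative to the package root
imganno_file = os.path.join("data", "XModalix-Tut-data", "tcga_mappings.txt")
rna_file = os.path.join("data", "XModalix-Tut-data", "combined_rnaseq_formatted.parquet")
img_root = os.path.join("data", "XModalix-Tut-data", "images", "tcga_fake")
clin_file = os.path.join("data", "XModalix-Tut-data", "combined_clin_formatted.parquet")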

dc = DataConfig(
    data_info={
        "img": DataInfo(
            file_path=img_root,
            data_type="IMG",
            scaling="MINMAX",
            translate_direction="to",
            pretrain_epochs=3,
            # extra_anno_file=imganno_file,
        ),
        "rna": DataInfo(
            file_path=rna_file,
            data_type="NUMERIC",
            scaling="MINMAX",
            translate_direction="from",
        ),
        "anno": DataInfo(
            file_path=clin_file,
            data_type="ANNOTATION",
            sep="\t"
        ),
    }
)
```
**Image Data Requirements**

- Specify the root folder where the images are located. The preprocessor will automatically read all images and convert them into tensors.  
- Supported image extensions (case-insensitive): `".jpg"`, `".jpeg"`, `".png"`, `".tif"`, `".tiff"`.  
- A mapping between each image filename (without directories) and a `sample_id` is required, along with any additional metadata.  

The recommended way to provide this mapping is a **global annotation file** that contains all `sample_id`s and metadata columns, including one column with the image paths (file names inside the image root folder).  
- Use the config parameter `img_path_col` to define the name of the column containing image paths (default: `"img_paths"`).  
- See your `clin_file` for an example.
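
Before running the pipeline, it can help to verify that every entry in the image path column has a supported extension and exists in the image root folder. The following is a hedged sketch of such a check; `check_image_mapping` is a hypothetical helper and not part of the AUTOENCODIX API.

```python
from pathlib import Path

# Extensions listed above as supported (compared case-insensitively).
SUPPORTED_EXTENSIONS = {".jpg", ".jpeg", ".png", ".tif", ".tiff"}


def check_image_mapping(img_paths, img_root):
    """Return (unsupported, missing) image file names for an annotation column."""
    root = Path(img_root)
    unsupported = [
        name for name in img_paths
        if Path(name).suffix.lower() not in SUPPORTED_EXTENSIONS
    ]
    missing = [
        name for name in img_paths
        if name not in unsupported and not (root / name).is_file()
    ]
    return unsupported, missing
```

With the tutorial data loaded as shown below, a call could look like `check_image_mapping(clin_df["img_paths"].to_list(), img_root)`; both returned lists should be empty.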

**Translation Direction**

- X-Modalix trains a shared latent alignment, enabling translation between all modalities.  
- To keep results consistent, users should define a `translate_direction` in the config.  
- If you want to change the direction or translate a different pair, you can specify this in the `predict` step (see [Section 2](#2-xmodalix-specific-steps)).


In [4]:
import os
import pandas as pd
import autoencodix as acx

from autoencodix.configs.xmodalix_config import XModalixConfig
from autoencodix.configs.default_config import DataConfig, DataInfo, DataCase

clin_file = os.path.join("data/XModalix-Tut-data/combined_clin_formatted.parquet")
clin_df = pd.read_parquet(clin_file)
print("First five image paths:")
print(clin_df["img_paths"].to_list()[0:5])
print("\n")
clin_df.head()


First five image paths:
['2331_label_0.png', '2332_label_0.png', '2333_label_0.png', '2334_label_0.png', '2335_label_0.png']




| SAMPLE_ID | ONCOTREE_CODE | CANCER_TYPE | CANCER_TYPE_DETAILED | TUMOR_TYPE | GRADE | TISSUE_PROSPECTIVE_COLLECTION_INDICATOR | TISSUE_RETROSPECTIVE_COLLECTION_INDICATOR | TISSUE_SOURCE_SITE_CODE | TUMOR_TISSUE_SITE | ANEUPLOIDY_SCORE | ... | OS_STATUS | OS_MONTHS | DSS_STATUS | DSS_MONTHS | DFS_STATUS | DFS_MONTHS | PFS_STATUS | PFS_MONTHS | GENETIC_ANCESTRY_LABEL | img_paths |
| --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- | --- |
| TCGA-3C-AAAU-01 | ILC | Breast Cancer | Breast Invasive Lobular Carcinoma | Infiltrating Lobular Carcinoma | unknown | No | Yes | 3C | Breast | 19.0 | ... | 0:LIVING | 133.050597 | 0:ALIVE OR DEAD TUMOR FREE | 133.050597 | 1:Recurred/Progressed | 59.440444 | 1:PROGRESSION | 59.440444 | EUR | 2331_label_0.png |
| TCGA-3C-AALI-01 | IDC | Breast Cancer | Breast Invasive Ductal Carcinoma | Infiltrating Ductal Carcinoma | unknown | No | Yes | 3C | Breast | 22.0 | ... | 0:LIVING | 131.66979 | 0:ALIVE OR DEAD TUMOR FREE | 131.66979 | 0:DiseaseFree | 131.66979 | 0:CENSORED | 131.66979 | AFR | 2332_label_0.png |
| TCGA-3C-AALJ-01 | IDC | Breast Cancer | Breast Invasive Ductal Carcinoma | Infiltrating Ductal Carcinoma | unknown | No | Yes | 3C | Breast | 13.0 | ... | 0:LIVING | 48.459743 | 0:ALIVE OR DEAD TUMOR FREE | 48.459743 | 0:DiseaseFree | 48.459743 | 0:CENSORED | 48.459743 | AFR_ADMIX | 2333_label_0.png |
| TCGA-3C-AALK-01 | IDC | Breast Cancer | Breast Invasive Ductal Carcinoma | Infiltrating Ductal Carcinoma | unknown | No | Yes | 3C | Breast | 4.0 | ... | 0:LIVING | 47.604958 | 0:ALIVE OR DEAD TUMOR FREE | 47.604958 | unknown |  | 0:CENSORED | 47.604958 | AFR | 2334_label_0.png |
| TCGA-4H-AAAK-01 | ILC | Breast Cancer | Breast Invasive Lobular Carcinoma | Infiltrating Lobular Carcinoma | unknown | Yes | No | 4H | Breast | 7.0 | ... | 0:LIVING | 11.440971 | 0:ALIVE OR DEAD TUMOR FREE | 11.440971 | 0:DiseaseFree | 11.440971 | 0:CENSORED | 11.440971 | EUR | 2335_label_0.png |


Another way is to provide an image-specific annotation file via the `extra_anno_file` parameter of the `DataInfo` object in the `DataConfig`.
Here you need to map the `sample_id` (index) to the image path and can add additional metadata columns, as we did in our `imganno_file`.   
> We recommend the first option of a global annotation file.


In [5]:
imganno_file = os.path.join("data/XModalix-Tut-data/tcga_mappings.txt")
img_anno_df = pd.read_csv(imganno_file, sep="\t", index_col=0)
img_anno_df.head()

| sample_ids | img_paths | extra_class_labels | CANCER_TYPE_ACRONYM |
| --- | --- | --- | --- |
| TCGA-05-4244-01 | 0_label_1.png | Non-Small Cell Lung Cancer | LUAD |
| TCGA-05-4249-01 | 1_label_1.png | Non-Small Cell Lung Cancer | LUAD |
| TCGA-05-4250-01 | 2_label_1.png | Non-Small Cell Lung Cancer | LUAD |
| TCGA-05-4382-01 | 3_label_1.png | Non-Small Cell Lung Cancer | LUAD |
| TCGA-05-4384-01 | 4_label_1.png | Non-Small Cell Lung Cancer | LUAD |
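
If you need to create such an image-specific annotation file for your own data, a tab-separated file with the sample ID in the first column can be written with the standard library. This is a minimal sketch under the assumption that the layout of `tcga_mappings.txt` shown above is what the preprocessor expects; the helper name and output file name are hypothetical, and the optional `extra_class_labels` column is omitted here.

```python
import csv
import os
import tempfile


def write_image_annotation(rows, out_path):
    """Write (sample_id, image file name, class label) rows as a tab-separated file."""
    with open(out_path, "w", newline="", encoding="utf-8") as handle:
        writer = csv.writer(handle, delimiter="\t")
        # Header: index column first, then the image path column and metadata.
        writer.writerow(["sample_ids", "img_paths", "CANCER_TYPE_ACRONYM"])
        writer.writerows(rows)


# Two rows taken from the example file above.
rows = [
    ("TCGA-05-4244-01", "0_label_1.png", "LUAD"),
    ("TCGA-05-4249-01", "1_label_1.png", "LUAD"),
]
out_path = os.path.join(tempfile.gettempdir(), "my_image_mappings.txt")
write_image_annotation(rows, out_path)
print(f"Wrote {len(rows)} rows to {out_path}")
```

The resulting file can be read back with `pd.read_csv(out_path, sep="\t", index_col=0)`, just like `imganno_file` above.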


## 1) Initialize and Run XModalix

In this section, we demonstrate how to set up the configuration to account for the data requirements and highlight the custom XModalix parameters.  

For a more in-depth guide on the configuration parameters and input data, see:  

[1] `Tutorials/DeepDives/ConfigTutorial.ipynb`  
[2] `Tutorials/DeepDives/InputDataTutorials.ipynb`


In [6]:
rna_file = os.path.join("data/XModalix-Tut-data/combined_rnaseq_formatted.parquet")
img_root = os.path.join("data/XModalix-Tut-data/images/tcga_fake")

xmodalix_config = XModalixConfig(
    checkpoint_interval=5,
    class_param="CANCER_TYPE_ACRONYM",
    epochs=30,
    beta=0.01,
    gamma=2,
    delta_class=2.0,
    delta_pair=3.0,
    latent_dim=8,
    requires_paired=False,
    pretrain_epochs=20,
    data_case=DataCase.IMG_TO_BULK,
    data_config=DataConfig(
        data_info={
            "img": DataInfo(
                file_path=img_root,
                data_type="IMG",
                scaling="MINMAX",
                translate_direction="to",
                # extra_anno_file=imganno_file,
            ),
            "rna": DataInfo(
                file_path=rna_file,
                data_type="NUMERIC",
                scaling="MINMAX",
                translate_direction="from",
            ),
            "anno": DataInfo(file_path=clin_file, data_type="ANNOTATION", sep="\t"),
        },
    ),
)

xmodalix = acx.XModalix(config=xmodalix_config)
result = xmodalix.run()

reading parquet: data/XModalix-Tut-data/combined_rnaseq_formatted.parquet
reading parquet: data/XModalix-Tut-data/combined_clin_formatted.parquet
Given image size is possible, rescaling images to: 64x64
Successfully loaded 3230 images for img
anno key: rna
anno key: img
Converting 2261 images to torch.float32 tensors...
Converting 646 images to torch.float32 tensors...
Converting 323 images to torch.float32 tensors...
key: train, type: <class 'dict'>
key: valid, type: <class 'dict'>
key: test, type: <class 'dict'>
Check if we need to pretrain: multi_bulk.rna
pretrain epochs : 20
Starting Pretraining for: multi_bulk.rna with <class 'autoencodix.trainers._general_trainer.GeneralTrainer'>
Epoch 1 - Train Loss: 4.4723
Sub-losses: recon_loss: 4.4723, var_loss: 0.0000, anneal_factor: 0.0000, effective_beta_factor: 0.0000
Epoch 1 - Valid Loss: 1.4290
Sub-losses: recon_loss: 1.4290, var_loss: 0.0000, anneal_factor: 0.0000, effective_beta_factor: 0.0000
Epoch 2 - Train Loss: 1.6831
Sub-losses: 

In [None]:
xmodalix.predict()

## 2) XModalix Specific Steps

XModalix does not introduce additional pipeline steps. However, its preprocessing and training differ slightly from Varix and Vanillix. XModalix processes each data modality individually and therefore supports unpaired data input.  
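
To get a feeling for how many of your samples are paired before training, you can compare the sample IDs of the modalities. The sketch below is independent of AUTOENCODIX; `split_paired` is a hypothetical helper, and the IDs are a few sample IDs from the annotation table shown earlier, assigned to modalities purely for illustration.

```python
def split_paired(ids_by_modality):
    """Return the sample IDs present in every modality and the leftovers per modality."""
    id_sets = [set(ids) for ids in ids_by_modality.values()]
    paired = set.intersection(*id_sets)
    unpaired = {name: set(ids) - paired for name, ids in ids_by_modality.items()}
    return paired, unpaired


# Illustrative assignment of sample IDs to two modalities.
ids_by_modality = {
    "rna": ["TCGA-3C-AAAU-01", "TCGA-3C-AALI-01", "TCGA-3C-AALJ-01"],
    "img": ["TCGA-3C-AALI-01", "TCGA-3C-AALJ-01", "TCGA-3C-AALK-01"],
}

paired, unpaired = split_paired(ids_by_modality)
all_ids = set().union(*ids_by_modality.values())
unpaired_sorted = {name: sorted(ids) for name, ids in unpaired.items()}

print("paired:", sorted(paired))
print("unpaired:", unpaired_sorted)
print(f"paired fraction: {len(paired) / len(all_ids):.2f}")
```

A low paired fraction means the paired loss term has little data to work with, which is why we recommend that most samples are present in all modalities.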

It is also possible to re-run the `predict` step and change the translation pairs or direction by providing the `from_key` and `to_key` arguments with the string keys that define your data modalities (e.g., `"rna"` and `"img"` in our case). See the code below for details.

Furthermore, the visualization and evaluation steps produce outputs that differ somewhat from the other pipelines, which we will focus on in [Section 4](#4-visualize-and-evaluate-xmodalix).


In [None]:
# Initially, we translated from RNA to image. Now we can use the trained model
# to re-run the translation with switched pairs.
# ATTENTION: This will overwrite the existing result object, so we create a backup first.
from copy import deepcopy

backup_result = deepcopy(result)
result_flipped = xmodalix.predict(from_key="img", to_key="rna")


## 3) Access XModalix Results

Accessing the results works slightly differently for `XModalix` because we have:

- Multiple latent spaces (one for each data modality).

- Not just reconstructions, but also `translations` and `reference translations`.  
  *Reference translations* are reconstructions within the target modality of the translation.  
  For example, if we translate from `rna` to `img`, the reference translation is the reconstruction from `img` to `img`.  
  Both can be accessed via the `reconstructions` attribute of the `result` object.  
  First, use the usual `TrainingDynamics` API via `.get()`, which returns a dict containing each data modality's translation and reference translation.

- Multiple losses.  
  These are accessed via the `sub_losses` attribute of the `result` object, which is a `Dict` of `TrainingDynamics`.  
  First, select the loss type you're interested in, then work with the standard `TrainingDynamics` API.

See the code below for examples and [3] for more details:

[3] `Tutorials/DeepDives/PipelineOutputTutorial.ipynb`


In [None]:
print("Find available loss types")
print(backup_result.sub_losses.keys())
print("Get adver_loss for training")
backup_result.sub_losses.get("adver_loss").get(split="train")

In [None]:
print("Get reconstruction keys")
# First define the split and epoch you're interested in,
# usually the test split (it has no epochs, so by default this is always epoch=-1)
recons = backup_result.reconstructions.get(split="test", epoch=-1)
print(recons.keys())
print("Getting Translation")
trans = recons.get("translation")
print(f"shape of translation: {trans.shape}")

**We can actually plot a sample image here**

In [None]:
import matplotlib.pyplot as plt

sample = trans[0, 0, :, :]
plt.imshow(sample)
plt.show()

Accessing the latent spaces works analogously.

## 4) Visualize and Evaluate XModalix

XModalix offers a few additional visualizations, namely:

- An image grid (when using images) that compares original with translated data.  
- A visualization of each data modality's latent space, which can be colored based on a target variable.

Additionally, the `evaluate` step works differently for XModalix. While for other pipelines we evaluate embedding quality with downstream machine learning tasks, here we focus on how well biological information is preserved in the translated data modality. We typically plot the mean squared error (MSE) between the original data and the translated data, clustered by a metadata parameter (e.g., cancer type).
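
The idea of comparing original and translated data per metadata class can be sketched as follows. This is a simplified stand-in and not the implementation behind `evaluate`; the function name, the toy values, and the use of `BRCA` and `LUAD` as example labels are assumptions for illustration.

```python
from collections import defaultdict
from statistics import fmean


def mse_per_class(original, translated, labels):
    """Average the per-sample MSE between original and translated data within each class."""
    per_class = defaultdict(list)
    for orig, trans, label in zip(original, translated, labels):
        sample_mse = fmean([(o - t) ** 2 for o, t in zip(orig, trans)])
        per_class[label].append(sample_mse)
    return {label: fmean(values) for label, values in per_class.items()}


# Toy data: three flattened samples with two features each.
original = [[0.0, 0.0], [1.0, 1.0], [2.0, 2.0]]
translated = [[0.0, 2.0], [1.0, 1.0], [4.0, 2.0]]
labels = ["BRCA", "BRCA", "LUAD"]

print(mse_per_class(original, translated, labels))
```

Classes with a clearly higher error than others indicate that their biological signal is translated less faithfully.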


In [None]:
fig = xmodalix._visualizer.show_2D_translation(
    result=backup_result,
    translated_modality="img.img",
    split="test",
    param="CANCER_TYPE_ACRONYM",
    reducer="UMAP",
)

### ATTENTION
The `evaluate` step is still being fixed, so the cell below may not run as shown.

In [None]:
xmodalix.evaluate()
param = "early"
fig, df_comparison = xmodalix._evaluator.pure_vae_comparison(
    xmodalix_result=xmodalix.result,
    # pure_vae_result=imagix.result,
    to_key="img.IMG",
    param=param,
)

xmodalix._visualizer.plots["PureVAEComparison"][param] = fig
fig.figure


## 5) Apply Custom Parameters
Due to the more complex training, we have more config parameters to customize. Here we go over the different `loss terms` and `pretraining`.
```python
    gamma: float = Field(
        default=10.0,
        ge=0,
        description="Gamma weighting factor for Adversial Loss Term i.e. for XModal Classfier training",
    )
    delta_pair: float = Field(
        default=5.0,
        ge=0,
        description="Delta weighting factor for paired loss term in XModale Training",
    )
    delta_class: float = Field(
        default=5.0,
        ge=0,
        description="Delta weighting factor for class loss term in XModale Training",
    )
```
The correct parametrization of these terms depends heavily on your data. We provide sensible default values in our `XModalixConfig` class, which are a good starting point. You can investigate the loss plots to see how much each loss contributes. For example, if your goal is to strongly preserve class information in the latent embedding, you would set `delta_class` higher. If you care more about generalization, you might set `gamma` higher.
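
One simple way to judge how much each term contributes is to compare the weighted sub-losses as fractions of their sum. The sketch below assumes you have extracted one raw (unweighted) value per loss term, for example from `result.sub_losses`; whether the logged values are already weighted is something to verify for your version. The keys `paired_loss` and `class_loss` and all numbers are placeholders.

```python
def weighted_contributions(raw_losses, weights):
    """Return each loss term's share of the weighted total (terms without a weight count once)."""
    weighted = {
        name: weights.get(name, 1.0) * value for name, value in raw_losses.items()
    }
    total = sum(weighted.values())
    if total == 0:
        raise ValueError("Total weighted loss is zero; contributions are undefined.")
    return {name: value / total for name, value in weighted.items()}


# Placeholder values for a single epoch.
raw_losses = {"recon_loss": 1.5, "adver_loss": 0.25, "paired_loss": 0.5, "class_loss": 0.25}
weights = {"adver_loss": 2.0, "paired_loss": 3.0, "class_loss": 2.0}

for name, share in weighted_contributions(raw_losses, weights).items():
    print(f"{name}: {share:.1%}")
```

If one term dominates the total, lowering its weight (or raising the others) is a reasonable first adjustment.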

Another important factor is pretraining. Before training a modality within the full XModalix architecture, we can pretrain that modality in its subnetwork without the XModalix constraints (adversarial, paired, class loss). This can be controlled via the `pretrain_epochs` parameter, which can be set either globally for all data modalities or individually per modality. Both options are shown below.


In [None]:
import autoencodix as acx
from autoencodix.configs.xmodalix_config import XModalixConfig
from autoencodix.configs.default_config import DataConfig, DataInfo, DataCase


xmodalix_config = XModalixConfig(
    class_param="CANCER_TYPE_ACRONYM",
    gamma=2,
    delta_class=10.0,  # increase this if you care about class info in embedding
    delta_pair=3.0,
    requires_paired=False,
    pretrain_epochs=2,  # can be overridden on data level, see below
    data_case=DataCase.IMG_TO_BULK,
    data_config=DataConfig(
        data_info={
            "img": DataInfo(
                file_path="YOUR_PATH",
                data_type="IMG",
                translate_direction="to",
                pretrain_epochs=2,  # this overrides the global pretrain_epochs
            ),
            "rna": DataInfo(
                file_path="YOUR_PATH",
                data_type="NUMERIC",
                translate_direction="from",
            ),
            "anno": DataInfo(file_path="YOUR_PATH", data_type="ANNOTATION", sep="\t"),
        },
    ),
)
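
The override rule described above can be summarized in a few lines. This sketch only illustrates the rule (a per-modality value takes precedence over the global one); `resolve_pretrain_epochs` is a hypothetical helper and not how AUTOENCODIX resolves the setting internally.

```python
def resolve_pretrain_epochs(global_epochs, per_modality):
    """Use the modality-specific value when given, otherwise fall back to the global one."""
    return {
        name: global_epochs if own is None else own
        for name, own in per_modality.items()
    }


# "img" sets its own value, "rna" does not and inherits the global setting.
effective = resolve_pretrain_epochs(global_epochs=20, per_modality={"img": 3, "rna": None})
print(effective)  # prints {'img': 3, 'rna': 20}
```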

## 6) Save, Load and Reuse XModalix

This works for `XModalix` just as for any other model, using the `save` and `load` methods.

In [None]:
import os
import glob

outpath = os.path.join("xmodalix")
xmodalix.save(file_path=outpath, save_all=False)

# outpath has no directory component, so the saved files are searched in the current working directory
folder = os.path.dirname(outpath)
pkl_files = glob.glob(os.path.join(folder, "*.pkl"))
model_files = glob.glob(os.path.join(folder, "*.pth"))

print("PKL files:", pkl_files)
print("Model files:", model_files)

# the load functionality automatically rebuilds the pipeline object from the three saved files
xmodalix_loaded = acx.XModalix.load(outpath)
# now you can reuse the loaded pipeline for prediction, here on the datasets of the earlier result
testdata = backup_result.datasets
r = xmodalix_loaded.predict(data=testdata)

In [None]:
# device on which the loaded pipeline runs
xmodalix_loaded._trainer._fabric.device

In [None]:
from lightning_fabric.wrappers import _FabricModule

# the loaded result holds one model per data modality
models = xmodalix_loaded.result.model
keys = models.keys()
print(keys)
img_model = models["img.img"]
rna_model = models["multi_bulk.rna"]
fabric = xmodalix_loaded._trainer._fabric
print(f"type of img_model: {type(img_model)}")
print(f"type of rna_model: {type(rna_model)}")
print(isinstance(rna_model, _FabricModule))

# wrap the image model with Fabric and move both models to the Fabric device
img_model_setup = fabric.setup_module(img_model)
rna_model.to(fabric.device)
img_model.to(fabric.device)

In [None]:
xmodalix_loaded.result.model

In [None]:
xmodalix_loaded.show_result()