
# How To Use Maskix

Maskix is our implementation of a masked autoencoder.  
This tutorial follows the structure of our `Getting Started - Vanillix` tutorial but is much shorter. Our pipelines work the same way across architectures, so here we focus only on Maskix specifics.

**AUTOENCODIX** supports far more functionality than shown here, so we’ll also point to advanced tutorials where relevant.  

**IMPORTANT**

> This tutorial only shows the specifics of the Maskix pipeline. If you're unfamiliar with general concepts,  
> we recommend following the `Getting Started - Vanillix` tutorial first.

## What You'll Learn

You’ll learn how to:

1. Understand the **theory** behind the Maskix architecture. <br><br>
2. **Initialize** and run the pipeline. <br><br>
3. Understand Maskix-specific **config parameters**. <br><br>
4. **Access & visualize** the results effectively. <br><br>
5. Pass **custom masking functions**. <br><br>
6. Use Maskix for **data imputation**. <br><br>
7. **Save, load, and reuse** a trained pipeline. <br><br>


## 1) Theory Primer

Maskix adapts the scMAE (single-cell Masked Autoencoder) framework from Fang et al. (2024) for single-cell RNA-seq analysis. The model learns useful representations by corrupting the input expression matrix and training the network both to reconstruct the original data and to identify which entries were perturbed. This encourages the model to capture gene–gene (or feature–feature) relationships in high-dimensional, noisy datasets. Although originally designed for single-cell data, the same approach applies to other domains.

Each training iteration follows this corruption process:

1. Sample from a Bernoulli distribution (with probability `maskix_swap_prob`) to determine which entries should be perturbed, producing a binary mask with the same shape as the input.
2. For each gene (or feature), generate a random permutation of sample indices.
3. For all positions marked as masked, replace the original value with the value from the permuted index for that feature.
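The three corruption steps can be sketched in plain Python, using a list of rows (samples × features) instead of a tensor. This is an illustrative sketch, not the AUTOENCODIX implementation; the function name `corrupt_sketch` is hypothetical.

```python
import random


def corrupt_sketch(X, swap_prob, rng=None):
    """Hypothetical sketch of scMAE-style corruption on a samples x features list of rows.

    Returns the corrupted matrix and the binary mask (True = entry was perturbed).
    """
    rng = rng if rng is not None else random.Random()
    n_samples, n_features = len(X), len(X[0])

    # Step 1: Bernoulli mask with the same shape as X
    mask = [[rng.random() < swap_prob for _ in range(n_features)] for _ in range(n_samples)]

    # Step 2: an independent permutation of sample indices for every feature (column)
    perms = []
    for _ in range(n_features):
        perm = list(range(n_samples))
        rng.shuffle(perm)
        perms.append(perm)

    # Step 3: masked entries take the value of the permuted sample in the same column
    corrupted = [
        [X[perms[j][i]][j] if mask[i][j] else X[i][j] for j in range(n_features)]
        for i in range(n_samples)
    ]
    return corrupted, mask
```

Because values are only moved within a column, every feature keeps its marginal distribution; the model cannot spot corrupted entries from out-of-range values alone and has to rely on feature–feature relationships.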

The corrupted input is encoded into a low-dimensional latent representation. In parallel, a mask predictor takes the latent space as input and estimates which positions were masked. The predicted mask is then concatenated with the latent representation and passed into a decoder that attempts to reconstruct the original values.

Training uses two loss components. The mask predictor is optimized with binary cross-entropy against the true mask. Reconstruction uses a weighted mean-squared error where masked entries receive higher weight, controlled by the hyperparameter `delta_mask_corrupted`. The total loss is the weighted sum of these components, balanced by the hyperparameter `delta_mask_predictor`.
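A plain-Python sketch of how the two loss components could be combined. The exact weighting is an assumption made for illustration: masked entries get weight `delta_mask_corrupted`, unmasked entries get `1 - delta_mask_corrupted`, and the total is a convex combination controlled by `delta_mask_predictor`. The function name and the default values in its signature are hypothetical placeholders, not the library defaults; check the AUTOENCODIX source for the actual formula.

```python
import math


def maskix_loss_sketch(x, x_hat, mask, mask_prob,
                       delta_mask_corrupted=0.75, delta_mask_predictor=0.5, eps=1e-7):
    """Hypothetical Maskix-style loss on flat lists.

    x, x_hat: original and reconstructed values; mask: 0/1 true mask;
    mask_prob: predicted probability that each entry was masked.
    Returns (total, reconstruction_loss, mask_loss).
    """
    if not (len(x) == len(x_hat) == len(mask) == len(mask_prob)):
        raise ValueError("all inputs must have the same length")
    n = len(x)

    # Weighted MSE (assumed scheme): corrupted entries weigh delta, the rest 1 - delta
    recon = sum(
        (delta_mask_corrupted if m else 1.0 - delta_mask_corrupted) * (xi - xh) ** 2
        for xi, xh, m in zip(x, x_hat, mask)
    ) / n

    # Binary cross-entropy between the true mask and the predicted probabilities
    bce = 0.0
    for m, p in zip(mask, mask_prob):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        bce -= m * math.log(p) + (1 - m) * math.log(1.0 - p)
    bce /= n

    # Total loss balanced by delta_mask_predictor (assumed convex combination)
    total = (1.0 - delta_mask_predictor) * recon + delta_mask_predictor * bce
    return total, recon, bce
```

With `delta_mask_corrupted > 0.5`, the same reconstruction error costs more on a corrupted entry than on an untouched one, which pushes the model to infer swapped values from the rest of the profile.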


## Requirements 1: Be in the correct directory (execute below)


In [None]:
import os

p = os.getcwd()
d = "autoencodix_package"
if d not in p:
    raise FileNotFoundError(f"'{d}' not found in path: {p}")
os.chdir(os.sep.join(p.split(os.sep)[: p.split(os.sep).index(d) + 1]))
print(f"Changed to: {os.getcwd()}")

## Requirements 2: Obtain tutorial data or use own data
We use a dataset from [Fang et al. (2024)](https://doi.org/10.1093/bioinformatics/btae020), which can be downloaded [here](https://cloud.scadsai.uni-leipzig.de/index.php/s/BP8YzDef4nSfgwj/download/GSE84133_human_combined_final.h5ad). Move the file to `data/raw`. Alternatively, you can use your own single-cell dataset as an `h5ad` file and update the variable `sc_path` accordingly (Maskix also works with other data types, but this example expects single-cell data).


## 2) Initialize and Run Maskix
As with every other pipeline, we need to:
- import the relevant classes
- define a config
- initialize the pipeline
- call `run()`

In [None]:
import autoencodix as acx
from autoencodix.configs import MaskixConfig
from autoencodix.configs.default_config import DataInfo, DataConfig, DataCase

sc_path = os.path.join("data/raw", "GSE84133_human_combined_final.h5ad")
config = MaskixConfig(
    epochs=30,
    checkpoint_interval=10,
    k_filter=1000,
    batch_size=64,
    data_config=DataConfig(
        annotation_columns=["multi_sc:assigned_cluster"],
        data_info={
            "multi_sc": DataInfo(
                file_path=sc_path, is_single_cell=True, data_type="NUMERIC"
            )
        },
    ),
    data_case=DataCase.MULTI_SINGLE_CELL,
)
maskix = acx.Maskix(config=config)

In [None]:
maskix_result = maskix.run()


## 3) Understanding Maskix-specific Config Parameters

As described in the [theory section](#1-theory-primer), we implemented the architecture from [Fang et al. (2024)](https://doi.org/10.1093/bioinformatics/btae020), with weighted loss terms and a specific encoder–decoder design. To make this adaptable within our framework, we expose the following configuration parameters:

* **maskix_hidden_dim**: Hidden dimension used in the Maskix encoder and decoder, matching the scMAE reference architecture by default.
* **maskix_swap_prob**: Bernoulli probability controlling how often feature values are swapped during input corruption.
* **delta_mask_predictor**: Weighting factor for the mask prediction loss in the total training objective.
* **delta_mask_corrupted**: Weighting factor that increases the reconstruction penalty on corrupted entries.
* **maskix_architecture**: Selects between the default scMAE architecture or a custom architecture configured through `n_layers` and `enc_factor`.

You can find the default values by running the following:


In [None]:
MaskixConfig.print_schema(
    filter_params=[
        "maskix_hidden_dim",
        "maskix_swap_prob",
        "delta_mask_predictor",
        "delta_mask_corrupted",
        "maskix_architecture",
    ]
)


## 4) Access & Visualize Results Effectively

In addition to the results that the Vanillix pipeline provides, we can access:
- `total`, `reconstruction`, and `masked` losses  

A note on the different loss types:  
For our masked autoencoder, the total loss consists of a reconstruction loss and a mask predictor loss.
To investigate these losses, the `result` object has the attribute `sub_losses`.  
This is a `LossRegistry` with the name of the loss as the key, and the value is a `TrainingDynamics` object, which can be accessed in the same way as for the Vanillix results.

For more details, check `Tutorials/DeepDives/PipelineOutputTutorial.ipynb`.


In [None]:
sub_losses = maskix_result.sub_losses
print("Sub Losses:")
print(f"keys: {sub_losses.keys()}")
print("\n")
recon_dyn = sub_losses.get(key="recon_loss")
print("Value of reconstruction loss in epoch 4 for train split")
print(recon_dyn.get(split="train", epoch=4))

As with our other pipelines, we can visualize the loss with `.show_result()`.
For more information on visualization, see `Tutorials/DeepDives/VisualizeTutorial.ipynb`.

In [None]:
maskix.show_result()

## 5) Adding a Custom Masking Function to `MaskixTrainer`

`MaskixTrainer` supports replacing its default corruption mechanism with a user-defined masking function. This enables experimentation with alternative masking strategies while ensuring compatibility with the trainer’s data flow.

#### How to Add a Custom Masking Function

Provide your masking function at initialization:

```python
# We assume that config is defined and other imports are done (see above)
import torch
import autoencodix as acx


def my_masking_fn(x: torch.Tensor, strength: float = 0.2) -> torch.Tensor:
    noise = torch.randn_like(x) * strength
    return x + noise  # must return ONLY a single tensor with the same shape as the input


masking_fn_kwargs = {"strength": 0.1}
maskix = acx.Maskix(
    config=config, masking_fn=my_masking_fn, masking_fn_kwargs=masking_fn_kwargs
)
```
#### Requirements for a Custom Masking Function

A custom masking function must satisfy the following constraints:

1. **It must accept a `torch.Tensor` as the first positional argument.**  
   The trainer passes the input mini-batch `X` directly into the function.

2. **It must return exactly one value: a `torch.Tensor`.**  
   The trainer does not consume or propagate additional outputs.  
   Returning tuples or multiple values is not allowed.

3. **The returned tensor must have the same shape as the input tensor.**  
   Any shape mismatch will raise a validation error.

4. **The function must operate on the device of the input tensor.**  
   Do not assume the tensor resides on the CPU; create any new tensors on `x.device`, for example with `torch.randn_like(x)` or `torch.zeros_like(x)`.

5. **Any additional parameters must be passed via `masking_fn_kwargs`.**  
   These keyword arguments provide a clean separation between trainer configuration and masking logic.

##### Example
This is our default masking method:
```python
def _maskix_hook(self, X: torch.Tensor) -> torch.Tensor:
    # Expand probabilities for Bernoulli sampling to match the input shape
    probs = self._mask_probas.expand(X.shape)

    # Create the boolean mask (True = swap, False = keep)
    should_swap = torch.bernoulli(probs).bool()

    # COLUMN-WISE SHUFFLING
    # We generate a random float matrix and argsort it along dim=0.
    # This gives us independent random indices for every column.
    rand_indices = torch.rand(X.shape, device=X.device).argsort(dim=0)

    # Use gather to reorder X based on these random indices
    shuffled_X = torch.gather(X, 0, rand_indices)
    # Take shuffled values where should_swap is True, original values elsewhere
    corrupted_X = torch.where(should_swap, shuffled_X, X)

    return corrupted_X
```



##### Code Example

In [None]:
import torch
import autoencodix as acx
from autoencodix.configs import MaskixConfig
from autoencodix.configs.default_config import DataInfo, DataConfig, DataCase


def my_masking_fn(x: torch.Tensor, strength: float = 0.2) -> torch.Tensor:
    # randn_like creates noise with the same shape, dtype, and device as `x`.
    # If you create tensors with other functions, pass device=x.device
    # (and a matching dtype) explicitly.
    noise = torch.randn_like(x) * strength
    return x + noise


kwargs = {"strength": 0.5}


sc_path = os.path.join("data/raw", "GSE84133_human_combined_final.h5ad")
config = MaskixConfig(
    epochs=5,
    checkpoint_interval=2,
    batch_size=64,
    data_config=DataConfig(
        annotation_columns=["multi_sc:assigned_cluster"],
        data_info={
            "multi_sc": DataInfo(
                file_path=sc_path, is_single_cell=True, data_type="NUMERIC"
            )
        },
    ),
    data_case=DataCase.MULTI_SINGLE_CELL,
)
maskix = acx.Maskix(config=config, masking_fn=my_masking_fn, masking_fn_kwargs=kwargs)

In [None]:
result = maskix.run()

## 6) Use Maskix to Impute Data

You can also pass corrupted data, or data with missing values, to `Maskix` and use it to impute them. For this, we recommend a custom masking function that mimics the corruption in your data. For example, to impute missing values stored as zeros, the masking function could randomly set values to zero.

You can then use `Maskix` to clean your data and use the cleaned data for downstream analysis, for example another AUTOENCODIX pipeline. In our mock example we will do the following:

- Create corrupted data with missing values
- Remove the "missing" samples from the training data
- Train Maskix on the remaining clean data, with a custom masking function that mimics missing data
- Feed the corrupted data into the trained Maskix and obtain imputed data
- Train Varix with:
    - the data with missing values (filled with zeros)
    - the imputed data
- Compare the results

#### Create Corrupted Data
We use our single-cell example from before and let Maskix preprocess the data. This makes the artificial corruption more meaningful, because the corruption then affects informative features/samples.


In [None]:
import torch
import autoencodix as acx
from autoencodix.configs import MaskixConfig
from autoencodix.configs.default_config import DataInfo, DataConfig, DataCase


sc_path = os.path.join("data/raw", "GSE84133_human_combined_final.h5ad")
config = MaskixConfig(
    epochs=30,
    checkpoint_interval=2,
    k_filter=1000,
    batch_size=64,
    data_config=DataConfig(
        annotation_columns=["multi_sc:assigned_cluster"],
        data_info={
            "multi_sc": DataInfo(
                file_path=sc_path, is_single_cell=True, data_type="NUMERIC"
            )
        },
    ),
    data_case=DataCase.MULTI_SINGLE_CELL,
)
maskix_orig = acx.Maskix(config=config)
maskix_orig.preprocess()
data = maskix_orig.result.datasets

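Now we randomly select samples and set them to zero. `drop_samples` returns two copies of each split: one where the selected samples are removed and one where they are set to zero.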

In [None]:

import torch
import copy
from autoencodix.data._numeric_dataset import NumericDataset


def drop_samples(ds: NumericDataset, drop_prob: float = 0.3):
    """Randomly select samples and return two corrupted copies of `ds`:
    one where they are removed from .data, .metadata (pd.DataFrame) and .sample_ids,
    and one where their rows in .data are set to zero."""
    data = ds.data
    n_samples = data.shape[0]
    keep_mask = torch.bernoulli((1 - drop_prob) * torch.ones(n_samples)).bool()

    # Copy 1: selected samples removed entirely
    ds_with_missing = copy.deepcopy(ds)
    ds_with_missing.data = data[keep_mask]
    ds_with_missing.metadata = ds.metadata.iloc[keep_mask.cpu().numpy()].reset_index(drop=True)
    ds_with_missing.sample_ids = [sid for i, sid in enumerate(ds.sample_ids) if keep_mask[i]]

    # Copy 2: selected samples replaced by zeros (metadata and sample_ids unchanged)
    zeroed_data = data.clone()
    zeroed_data[~keep_mask] = 0
    ds_with_zero = copy.deepcopy(ds)
    ds_with_zero.data = zeroed_data
    return ds_with_missing, ds_with_zero


ds_with_missing = copy.deepcopy(data)
ds_with_zero = copy.deepcopy(data)

for split in ["train", "test", "valid"]:
    original_split = getattr(data, split)
    missing_split, zero_split = drop_samples(original_split)
    print(
        f"Original {split} data shape: {original_split.data.shape}, "
        f"shape after dropping samples: {missing_split.data.shape}"
    )
    setattr(ds_with_missing, split, missing_split)
    setattr(ds_with_zero, split, zero_split)




Now we train Maskix on the remaining clean samples (`ds_with_missing`), passing a custom masking function that simulates missing values by randomly setting entries to zero.

In [None]:
import torch
import autoencodix as acx
from autoencodix.configs import MaskixConfig
from autoencodix.configs.default_config import DataInfo, DataConfig, DataCase


def my_imputer(x: torch.Tensor) -> torch.Tensor:
    """Simulate missing values by randomly setting ~30% of the entries to zero."""
    # The mask is created directly on the input's device (True = set to zero)
    rand_mask = torch.bernoulli(0.3 * torch.ones(x.shape, device=x.device)).bool()
    return torch.where(rand_mask, torch.zeros_like(x), x)


sc_path = os.path.join("data/raw", "GSE84133_human_combined_final.h5ad")
config = MaskixConfig(
    epochs=30,
    checkpoint_interval=2,
    k_filter=3000,
    skip_preprocessing=True,
    batch_size=64,
    data_config=DataConfig(
        annotation_columns=["multi_sc:assigned_cluster"],
        data_info={
            "multi_sc": DataInfo(
                 is_single_cell=True, data_type="NUMERIC"
            )
        },
    ),
    data_case=DataCase.MULTI_SINGLE_CELL,
)
maskix = acx.Maskix(config=config, masking_fn=my_imputer, data=ds_with_missing)
result = maskix.run()

Now we use the trained Maskix model to impute the missing data: we feed in the zero-filled data and obtain reconstructions without missing values.


In [None]:
mo_train = maskix.impute(ds_with_zero.train.data)
recons_train = mo_train.reconstruction

mo_test = maskix.impute(ds_with_zero.test.data)
recons_test = mo_test.reconstruction

mo_valid = maskix.impute(ds_with_zero.valid.data)
recons_valid = mo_valid.reconstruction


ds_imputed = copy.deepcopy(ds_with_zero)
ds_imputed.train.data = recons_train
ds_imputed.test.data = recons_test
ds_imputed.valid.data = recons_valid

We can now compare two errors against the original data: the reconstruction error of a Maskix model trained on the complete data (ground truth without missing values), and the error of the imputed data.
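The comparison below computes the MSE over the whole matrix, which also averages over samples that were never corrupted. Because `drop_samples` corrupts whole samples, it can be informative to compute the error on the dropped rows only. Here is a plain-Python sketch; `row_mse` is a hypothetical helper, and using it on the tutorial data assumes you modify `drop_samples` to also return the indices of the dropped rows (it currently does not).

```python
def row_mse(a, b, rows=None):
    """Hypothetical helper: MSE between two matrices (lists of rows),
    optionally restricted to the given row indices. Returns NaN if no rows are selected."""
    indices = range(len(a)) if rows is None else rows
    total, count = 0.0, 0
    for i in indices:
        for x, y in zip(a[i], b[i]):
            total += (x - y) ** 2
            count += 1
    return total / count if count else float("nan")
```

With tensors, the same idea is a boolean row index, e.g. `mse_loss(recons_train.cpu()[dropped], original_train[dropped])` for a hypothetical boolean tensor `dropped` marking the zeroed samples.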

In [None]:
%%capture
# ground truth without missing values
result_orig = maskix_orig.run()

In [None]:
from torch.nn.functional import mse_loss
original_train = result_orig.datasets.train.data
original_recon = result_orig.reconstructions.get(split="train", epoch=-1)
original_recon_tensor = torch.from_numpy(original_recon)
loss = mse_loss(original_recon_tensor, original_train)
print(f"MSE Loss original reconstruction: {loss}")
loss_imputed = mse_loss(recons_train.to("cpu"), original_train)
print(f"MSE Loss imputed reconstruction: {loss_imputed}")


Finally, you can use the imputed data as input for other autoencoders like `Varix` or `Vanillix`, or even for another `Maskix`.
We train two Varix models, one on the zero-filled data with missing values and one on the imputed data, and compare the results.

In [None]:
original_train.shape

In [None]:
recons_train.shape

In [None]:
import autoencodix as acx
from autoencodix.configs import VarixConfig
from autoencodix.configs.default_config import DataInfo, DataConfig, DataCase
from autoencodix.data._datasetcontainer import DatasetContainer
import copy

config = VarixConfig(
    epochs=50,
    checkpoint_interval=10,
    batch_size=64,
    skip_preprocessing=True,
    data_config=DataConfig(
        annotation_columns=["multi_sc:assigned_cluster"],
        data_info={"multi_sc": DataInfo(is_single_cell=True, data_type="NUMERIC")},
    ),
    data_case=DataCase.MULTI_SINGLE_CELL,
)

varix_imputed = acx.Varix(config=config, data=ds_imputed)
varix_missing = acx.Varix(config=config, data=ds_with_zero)

In [None]:
result_imputed = varix_imputed.run()
result_missing = varix_missing.run()

In [None]:
varix_imputed.show_result()

In [None]:
varix_missing.show_result()

In [None]:
varix_imputed.evaluate()

In [None]:
varix_missing.evaluate()

## 7) Save, Load and Re-Use Maskix

There are no `Maskix`-specific steps here. See `Tutorials/PipelineTutorials/Vanillix.ipynb` or `Tutorials/DeepDives/MemoryEfficientSaving.ipynb` for details. Below is a basic save/load use case:

In [None]:

import os
import glob
# use a filename without extension, we handle this internally
outpath = os.path.join("tutorial_res", "maskix")
maskix.save(file_path=outpath, save_all=False)

folder = os.path.dirname(outpath)
pkl_files = glob.glob(os.path.join(folder, "*.pkl"))
model_files = glob.glob(os.path.join(folder, "*.pth"))

print("PKL files:", pkl_files)
print("Model files:", model_files)

# load() automatically rebuilds the pipeline object from the three saved files
maskix_loaded = acx.Maskix.load(outpath)
maskix_loaded.predict(data=maskix_result.datasets)
maskix_loaded.visualize()
maskix_loaded.show_result()