# Tutorial: context-aware learning of multiple modalities with mmcontext

## Contents of Tutorial

This tutorial demonstrates how to use the mmcontext package to preprocess single-cell data. We'll walk through the steps of:

1. Loading the Dataset
2. Generating Embeddings
3. Normalizing Embeddings
4. Aligning Embeddings
5. Constructing the Dataset

### 1. Loading the Dataset

The input data must be an {class}`anndata.AnnData` object. First, we load the example dataset provided in `data/test_data/small_cellxgene_data.h5ad`. It is derived from cellxgene and contains cells from various tissues and cell types across different studies. The included `scvi` embedding (stored in `adata.obsm["scvi"]`) is provided by cellxgene and was computed with the scVI variational autoencoder trained on the cellxgene corpus.

```python
# Import necessary libraries
import anndata

data_path = "../../data/test_data/small_cellxgene_data.h5ad"
# Load the dataset
adata = anndata.read_h5ad(data_path)

# Display basic information about the dataset
print(adata)
```

```
root - INFO - Loading the example dataset, which is taken from cellxgene...
AnnData object with n_obs × n_vars = 5600 × 1000
    obs: 'soma_joinid', 'donor_id', 'disease', 'sex', 'dataset_id', 'cell_type', 'assay', 'tissue', 'cell_type_ontology_term_id', 'assay_ontology_term_id', 'tissue_ontology_term_id', 'suspension_type', 'is_primary_data'
    var: 'soma_joinid', 'feature_id', 'feature_name', 'feature_length', 'nnz', 'n_measured_obs'
    obsm: 'metadata_tissue_assay_cell_type', 'scvi'
```


### 2. Generating Embeddings

We will generate context embeddings from the categorical metadata fields `cell_type` and `tissue` using the {class}`mmcontext.pp.CategoryEmbedder` class.
The method embeds each of a cell's categories with a text model. The `embeddings_file_path` points to a file containing a dictionary of precomputed embeddings for a range of cell types and tissues from the cellxgene corpus, so the method can work without an API call. The API is used only for categories that are missing from this dictionary, and only when there are enough of them. If just a few are unknown, they are filled with a zero embedding instead. The `unknown_threshold` parameter controls how many unknown categories are needed before the API is used. In that case an API key is required, which must be set as the environment variable `OPENAI_API_KEY`.
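To make the lookup-and-combine idea concrete, here is a minimal NumPy sketch. It is not mmcontext's implementation. The function `embed_categories` is hypothetical, and so is the nested `lookup` dictionary, which stands in for the precomputed embedding file. The sketch also assumes that the API becomes necessary once the number of distinct unknown labels exceeds `unknown_threshold`. Instead of calling an API at that point, it simply raises an error.

```python
import numpy as np


def embed_categories(values_per_category, lookup, dim, unknown_threshold=20):
    """Hypothetical sketch: embed each cell's category labels and average them.

    values_per_category: dict mapping a metadata field (e.g. "cell_type")
        to a list with one label per cell.
    lookup: dict mapping field -> {label: vector}, standing in for the
        precomputed embedding dictionary.
    dim: dimension of the text-model embeddings.
    """
    per_field = []
    for field, labels in values_per_category.items():
        known = lookup.get(field, {})
        unknown = {label for label in labels if label not in known}
        if len(unknown) > unknown_threshold:
            # Assumption: this is where the real package would query the API.
            raise RuntimeError(
                f"{len(unknown)} unknown '{field}' labels exceed the threshold"
            )
        # Few unknowns: fall back to a zero vector for each of them.
        matrix = np.stack([known.get(label, np.zeros(dim)) for label in labels])
        per_field.append(matrix)
    # "average" combination: element-wise mean over the metadata fields.
    return np.mean(per_field, axis=0)
```

With `combination_method="average"`, each cell's context embedding in this sketch is the element-wise mean of its per-category vectors. The combined embedding therefore keeps the dimension of the text model, which is 1536 for `text-embedding-3-small`.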

As data embeddings, we will use the precomputed scVI embeddings stored in `adata.obsm["scvi"]`.

```python
# Import the CategoryEmbedder and Embedder classes
from mmcontext.pp import CategoryEmbedder, Embedder

# Specify the categories to embed
categories = ["cell_type", "tissue"]

# Initialize the CategoryEmbedder
category_embedder = CategoryEmbedder(
    metadata_categories=categories,
    model="text-embedding-3-small",
    combination_method="average",
    embeddings_file_path="../../data/emb_dicts/category_embeddings_text-embedding-3-small_metadata_embeddings.pkl.gz",
)
# Initialize the Embedder with only a context embedder;
# the data embeddings are passed in externally below
embedder = Embedder(context_embedder=category_embedder)

# Create embeddings, using the precomputed scVI embeddings as data embeddings
embedder.create_embeddings(adata, data_embeddings=adata.obsm["scvi"])

# Confirm the shapes of the context and data embeddings
print("Context Embeddings Shape:", adata.obsm["c_emb"].shape)
print("Data Embeddings Shape:", adata.obsm["d_emb"].shape)
```

```
mmcontext.pp.context_embedder - INFO - Loaded embeddings from file.
mmcontext.pp.context_embedder - INFO - Embeddings dictionary contains the following categories: dict_keys(['cell_type', 'tissue', 'assay']) with a total of 947 elements.
mmcontext.pp.embedder - INFO - Using external data embeddings provided.
mmcontext.pp.embedder - INFO - Creating context embeddings...
mmcontext.pp.context_embedder - INFO - Embeddings for 'cell_type' stored in adata.obsm['cell_type_emb']
mmcontext.pp.context_embedder - INFO - Embeddings for 'tissue' stored in adata.obsm['tissue_emb']
mmcontext.pp.context_embedder - INFO - Combined context embeddings stored in adata.obsm['c_emb']
Context Embeddings Shape: (5600, 1536)
Data Embeddings Shape: (5600, 50)
```


### 3. Normalizing Embeddings

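Now that the embeddings are created and stored in the `adata` object, we can normalize them. We will use the {class}`mmcontext.pp.MinMaxNormalizer` here. It normalizes both the data and the context embeddings and stores the results in `adata.obsm["d_emb_norm"]` and `adata.obsm["c_emb_norm"]`.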
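Min-max normalization rescales values as x' = (x - min) / (max - min), which maps them into [0, 1]. The sketch below assumes the rescaling is applied independently to each embedding dimension (column). Whether `MinMaxNormalizer` scales per column or over the whole matrix is not stated here, so treat the helper `min_max_normalize` as a hypothetical illustration.

```python
import numpy as np


def min_max_normalize(emb, eps=1e-12):
    """Hypothetical sketch: scale each column of `emb` to the range [0, 1]."""
    col_min = emb.min(axis=0)
    col_max = emb.max(axis=0)
    # eps avoids division by zero for constant columns (they map to 0)
    return (emb - col_min) / np.maximum(col_max - col_min, eps)
```

Because this scaling is column-wise, the shape of each embedding matrix is unchanged.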

```python
# Import the MinMaxNormalizer class
from mmcontext.pp import MinMaxNormalizer

# Initialize the MinMaxNormalizer
normalizer = MinMaxNormalizer()

# Normalize the data and context embeddings
normalizer.normalize(adata)

# Confirm that normalized embeddings are stored
print("Normalized Data Embeddings Shape:", adata.obsm["d_emb_norm"].shape)
print("Normalized Context Embeddings Shape:", adata.obsm["c_emb_norm"].shape)
```

```
mmcontext.pp.embedding_normalizer - INFO - Normalizing embeddings using min-max normalization...
Normalized Data Embeddings Shape: (5600, 50)
Normalized Context Embeddings Shape: (5600, 1536)
```


### 4. Aligning Embeddings

After normalization, we use a {class}`mmcontext.pp.DimAligner` to give the data and context embeddings the same dimension, which the model requires. Specifically, we use the {class}`mmcontext.pp.PCAReducer`. Embeddings larger than the target latent dimension are reduced via PCA, and embeddings smaller than it are padded with zeros.
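The reduce-or-pad rule can be sketched with NumPy as follows. `align_dim` is a hypothetical helper. It computes PCA through an SVD of the centered matrix, whereas `PCAReducer` may fit and apply PCA differently (for example with scikit-learn).

```python
import numpy as np


def align_dim(emb, latent_dim):
    """Hypothetical sketch: PCA-reduce or zero-pad `emb` to `latent_dim` columns."""
    n_samples, n_features = emb.shape
    if n_features > latent_dim:
        if n_samples < latent_dim:
            raise ValueError("PCA needs at least latent_dim samples")
        # Center each feature, then project onto the top principal axes
        centered = emb - emb.mean(axis=0)
        _, _, vt = np.linalg.svd(centered, full_matrices=False)
        # Rows of vt are principal axes sorted by decreasing singular value
        return centered @ vt[:latent_dim].T
    if n_features < latent_dim:
        # Append zero columns until the target dimension is reached
        padding = np.zeros((n_samples, latent_dim - n_features), dtype=emb.dtype)
        return np.hstack([emb, padding])
    # Already the right size: return unchanged
    return emb
```

With `latent_dim = 64`, this rule would reduce the 1536-dimensional context embeddings via PCA and pad the 50-dimensional scVI embeddings with 14 zero columns.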

```python
# Import the PCAReducer class
from mmcontext.pp import PCAReducer

# Initialize the PCAReducer with the desired latent dimension
latent_dim = 64
aligner = PCAReducer(latent_dim=latent_dim)

# Align the embeddings
aligner.align(adata)

# Confirm that aligned embeddings are stored
print("Aligned Data Embeddings Shape:", adata.obsm["d_emb_aligned"].shape)
print("Aligned Context Embeddings Shape:", adata.obsm["c_emb_aligned"].shape)
```

```
Aligned Data Embeddings Shape: (5600, 64)
Aligned Context Embeddings Shape: (5600, 64)
```


### 5. Constructing the Dataset

Finally, we will construct a dataset using the aligned embeddings, suitable for training models with PyTorch.
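PyTorch's `DataLoader` works with map-style datasets, that is, objects that implement `__len__` and `__getitem__`. The following minimal sketch shows what such a dataset might hold: each sample's aligned data and context embeddings, keyed by sample ID. The class `PairedEmbeddingDataset` and its dictionary keys are hypothetical. The object returned by `DataSetConstructor.construct_dataset()` may be structured differently.

```python
import numpy as np


class PairedEmbeddingDataset:
    """Hypothetical map-style dataset: index -> aligned data/context embeddings."""

    def __init__(self, sample_ids, data_emb, context_emb):
        if not (len(sample_ids) == len(data_emb) == len(context_emb)):
            raise ValueError("all inputs need exactly one row per sample")
        self.sample_ids = list(sample_ids)
        # float32 is the usual dtype for model inputs in PyTorch
        self.data_emb = np.asarray(data_emb, dtype=np.float32)
        self.context_emb = np.asarray(context_emb, dtype=np.float32)

    def __len__(self):
        return len(self.sample_ids)

    def __getitem__(self, idx):
        return {
            "sample_id": self.sample_ids[idx],
            "data_embedding": self.data_emb[idx],
            "context_embedding": self.context_emb[idx],
        }
```

An instance could then be wrapped with `torch.utils.data.DataLoader(ds, batch_size=32)`. The default collate function converts the NumPy arrays in each dictionary into batched tensors.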

```python
# Import the DataSetConstructor class
from mmcontext.pp import DataSetConstructor

# Initialize the DataSetConstructor
dataset_constructor = DataSetConstructor(sample_id_key="soma_joinid")

# Add the AnnData object to the dataset
dataset_constructor.add_anndata(adata)

# Construct the dataset
dataset = dataset_constructor.construct_dataset()

# Display information about the dataset
print("Total samples in dataset:", len(dataset))
```

```
Total samples in dataset: 5600
```
