## Train a new mmcontext model (demonstrated on proteomics data)


In [24]:
%load_ext autoreload
%autoreload 2

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [25]:
from mmcontext.utils import setup_logging

setup_logging()

<RootLogger root (INFO)>

In [26]:
import anndata as ad

from mmcontext.file_utils import download_file_from_share_link

# a protein dataset from figshare (https://plus.figshare.com/articles/dataset/scPerturb_Single-Cell_Perturbation_Data_RNA_and_protein_h5ad_files/24160713?utm_source=chatgpt.com&file=42428325)
data_link = "https://plus.figshare.com/ndownloader/files/42428325"
local_path = "Frangiehlzar2021_protein.h5ad"
# download the data
download_file_from_share_link(share_link=data_link, save_path=local_path)
# load the data
adata = ad.read_h5ad(local_path)

2025-11-17 10:06:20,161 - mmcontext.file_utils - INFO - File is a valid .h5ad file.


Of course one can create a more sophisicated description, this is just an example

In [None]:
# Create a description string for each cell by looping over the rows in adata.obs
def make_description(row):
    """Make a quick description of a cell based on its metadata"""
    return (
        f"The first perturbation is {row['perturbation']} "
        f"and the second perturbation is {row['perturbation_2']}."
        f" The tissue is {row['tissue_type']} and it has cancer yes or no: {row['cancer']}."
        f" The disease is {row['disease']}."
        f" The celltype is {row['celltype']}."
    )


# Add a "description" column to adata.obs using the function above
adata.obs["description"] = adata.obs.apply(make_description, axis=1)
# Also add a sample index column for later
adata.obs["sample_idx"] = adata.obs.index

In [29]:
import numpy as np

# let's split by train and val, randomly 80% train
adata.obs["split"] = np.random.rand(len(adata)) < 0.8
adata_train = adata[adata.obs["split"]].copy()
adata_val = adata[~adata.obs["split"]].copy()

In [6]:
adata_val.shape

(43894, 24)

In [7]:
adata_train.shape

(174437, 24)

In [8]:
import scanpy as sc

# normalise and log transform the data
sc.pp.normalize_total(adata_train, inplace=True)
sc.pp.log1p(adata_train)

sc.pp.normalize_total(adata_val, inplace=True)
sc.pp.log1p(adata_val)

  return fn(*args_all, **kw)


In [9]:
# since these datasets only contain 24 proteins, we will just use their expression as embeddings
# we will use the protein names as the embedding keys
adata_train.obsm["X_prot"] = adata_train.X
adata_val.obsm["X_prot"] = adata_val.X
processed_paths = {"train": "Frangiehlzar2021_protein_pp_train.h5ad", "val": "Frangiehlzar2021_protein_pp_val.h5ad"}
adata_train.write_h5ad(processed_paths["train"])
adata_val.write_h5ad(processed_paths["val"])

... storing 'description' as categorical
... storing 'description' as categorical


In [10]:
from adata_hf_datasets import AnnDataSetConstructor
from datasets import DatasetDict

ds_dict = DatasetDict()
# multiplets format for training datasets (with descriptions)
# A sentence key refers to the column in adata.obs that is used to represent the sample.
# For numeric data, we use the sample index, and later register the created embedding linked to their indices in the tokenizer
for split_name, adata_split in {"train": adata_train, "val": adata_val}.items():
    constructor = AnnDataSetConstructor(dataset_format="multiplets", resolve_negatives=True)
    constructor.add_anndata(
        adata_split,
        caption_key="description",
        sentence_keys=["sample_idx"],
        adata_link=processed_paths[split_name],
        batch_key="library_preparation_protocol",  # In this case all are from the same batch, but providing a batch key can improve batch integration by negative sampling
    )
    ds = constructor.get_dataset()
    ds_dict[split_name] = ds

2025-11-14 15:45:09,802 - adata_hf_datasets.dataset.ds_constructor - INFO - AnnData ingested (174437 rows).
2025-11-14 15:45:09,805 - adata_hf_datasets.dataset.ds_constructor - INFO - Building dataset from generator...
  obj.co_lnotab,  # for < python 3.10 [not counted in args]


Generating train split: 0 examples [00:00, ? examples/s]

2025-11-14 15:45:11,849 - adata_hf_datasets.dataset.ds_constructor - INFO - Constructed dataset with 174437 records in 'multiplets' format.
2025-11-14 15:45:12,031 - adata_hf_datasets.dataset.ds_constructor - INFO - AnnData ingested (43894 rows).
2025-11-14 15:45:12,032 - adata_hf_datasets.dataset.ds_constructor - INFO - Building dataset from generator...


Generating train split: 0 examples [00:00, ? examples/s]

2025-11-14 15:45:12,620 - adata_hf_datasets.dataset.ds_constructor - INFO - Constructed dataset with 43894 records in 'multiplets' format.


In [11]:
ds

Dataset({
    features: ['sample_idx', 'cell_sentence_1', 'positive', 'negative_1_idx', 'adata_link', 'negative_1'],
    num_rows: 43894
})

## Configure the Model
The custom sentence transformers model will allow training with the sentence transformers Trainer. The numeric data has to be registered with the model, such that the representations in cell_sentence_1 (cell token eg. sample indices) are linked to the respective numeric vector, which serves as the initial repesentation of that sample. 
This is achieved by building a lookup table (cell token -> id) and a frozen embedding layer (id --> numeric vector). 

In [12]:
from sentence_transformers import SentenceTransformer

from mmcontext.mmcontextencoder import MMContextEncoder

enc = MMContextEncoder(
    text_encoder_name="sentence-transformers/all-MiniLM-L6-v2",
    adapter_hidden_dim=128,
    adapter_output_dim=64,
    freeze_text_encoder=True,
    unfreeze_last_n_layers=2,
    output_token_embeddings=False,
    train_lookup=False,
    joint_adapter_hidden_dim=None,
    text_model_kwargs=None,
    use_text_adapter=True,
)
model = SentenceTransformer(modules=[enc])

2025-11-14 15:45:13,529 - mmcontext.mmcontextencoder - INFO - Unfreezing last 2 layers of BERT-like model
2025-11-14 15:45:13,530 - mmcontext.mmcontextencoder - INFO - Successfully unfroze 2 layers with 3548928 trainable parameters
2025-11-14 15:45:13,560 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device_name: mps


In [13]:
token_df, _ = model[0].get_initial_embeddings_from_adata_link(
    ds_dict,
    layer_key="X_prot",
    download_dir="data_cache",
    axis="obs",  # since we get embeddings from adata.obsm. We could also use "varm" and for example use an embedding for each protein
)
model[0].register_initial_embeddings(token_df, data_origin="prot")

2025-11-14 15:45:13,776 - mmcontext.file_utils - INFO - Found 2 unique share links across ['train', 'val'] splits


Processing:   0%|          | 0/2 [00:00<?, ?file/s]

2025-11-14 15:45:13,779 - mmcontext.file_utils - INFO - Using local path: /Users/mengerj/repos/mmcontext/tutorials/Frangiehlzar2021_protein_pp_train.h5ad
2025-11-14 15:45:13,779 - mmcontext.file_utils - INFO - Using local path: /Users/mengerj/repos/mmcontext/tutorials/Frangiehlzar2021_protein_pp_val.h5ad
2025-11-14 15:45:13,837 - mmcontext.file_utils - INFO - Reading /Users/mengerj/repos/mmcontext/tutorials/Frangiehlzar2021_protein_pp_train.h5ad
2025-11-14 15:45:15,130 - mmcontext.file_utils - INFO - Built DataFrame with 174437 rows × 24-dim embeddings
2025-11-14 15:45:15,150 - mmcontext.file_utils - INFO - Reading /Users/mengerj/repos/mmcontext/tutorials/Frangiehlzar2021_protein_pp_val.h5ad
2025-11-14 15:45:15,480 - mmcontext.file_utils - INFO - Built DataFrame with 43894 rows × 24-dim embeddings
2025-11-14 15:45:15,483 - mmcontext.mmcontextencoder - INFO - Combined embedding DataFrame shape: (218331, 3)
2025-11-14 15:45:15,703 - mmcontext.omicsencoder - INFO - Loaded embedding matrix

Use the returned DataFrame to register the embeddings with `register_initial_embeddings()`.


In [14]:
# the model expects a certain prefix on the cell tokens.
model[0].processor.prefix

'sample_idx:'

In [15]:
# you could add this manually or use the method below
model[0].prefix_ds(ds_dict, columns_to_prefix=["cell_sentence_1"])

  obj.co_lnotab,  # for < python 3.10 [not counted in args]


Prefixing columns: ['cell_sentence_1']:   0%|          | 0/174437 [00:00<?, ? examples/s]

Prefixing columns: ['cell_sentence_1']:   0%|          | 0/43894 [00:00<?, ? examples/s]

DatasetDict({
    train: Dataset({
        features: ['sample_idx', 'cell_sentence_1', 'positive', 'negative_1_idx', 'adata_link', 'negative_1'],
        num_rows: 174437
    })
    val: Dataset({
        features: ['sample_idx', 'cell_sentence_1', 'positive', 'negative_1_idx', 'adata_link', 'negative_1'],
        num_rows: 43894
    })
})

In [16]:
# lastly we have to drop some columns from the dataset and
# rename the main column to "anchor".
# you might think that this is a bit cumbersome, which it is.
# But this setup allowed for fleixble training,
# using either cell or feature level tokens, using text based cell sentences
# or numeric embeddings and resolving negatives
# to whatever column was chosen for training.
# That means that for a certain training run, the same dataset can be reused, and
# only modified differently. But in the end, it is a bit of work to set up.
ds_final = ds_dict.rename_column("cell_sentence_1", "anchor")
ds_final = ds_final.remove_columns(["sample_idx", "adata_link", "negative_1_idx"])
ds_final

DatasetDict({
    train: Dataset({
        features: ['anchor', 'positive', 'negative_1'],
        num_rows: 174437
    })
    val: Dataset({
        features: ['anchor', 'positive', 'negative_1'],
        num_rows: 43894
    })
})

In [20]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments

args = SentenceTransformerTrainingArguments(
    num_train_epochs=1,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    learning_rate=1e-5,
    warmup_ratio=0.1,
    fp16=False,
    bf16=False,
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=1,
    max_grad_norm=1.0,
    logging_steps=10,
    run_name="protein_test",
)

In [21]:
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.losses import MultipleNegativesRankingLoss

loss = MultipleNegativesRankingLoss(model)
evaluator = TripletEvaluator(
    anchors=ds_final["val"]["anchor"],
    positives=ds_final["val"]["positive"],
    negatives=ds_final["val"]["negative_1"],
    name="val",
)

In [22]:
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=ds_final["train"],
    eval_dataset=ds_final["val"],
    loss=loss,
    evaluator=evaluator,
)

Computing widget examples:   0%|          | 0/1 [00:00<?, ?example/s]

In [23]:
trainer.train()

huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[34m[1mwandb[0m: Currently logged in as: [33mmengerj[0m ([33mmengerj-universit-tsklinikum-freiburg[0m) to [32mhttps://api.wandb.ai[0m. Use [1m`wandb login --relogin`[0m to force relogin
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment v



Step,Training Loss,Validation Loss,Val Cosine Accuracy
100,6.7159,5.399902,0.77874
200,4.6499,4.97713,0.857315
300,4.345,4.811664,0.857725
400,4.2647,4.926792,0.785369
500,4.1591,5.027446,0.737003
600,4.1414,4.972731,0.742106
700,4.1395,4.974823,0.736866
800,4.0755,4.937728,0.729097
900,4.0738,5.088612,0.674147
1000,4.0443,5.126331,0.662733


2025-11-14 15:46:23,048 - sentence_transformers.evaluation.TripletEvaluator - INFO - TripletEvaluator: Evaluating the model on the val dataset in epoch 0.07336757153338225 after 100 steps:
2025-11-14 15:47:32,215 - sentence_transformers.evaluation.TripletEvaluator - INFO - Accuracy Cosine Similarity:	77.87%
2025-11-14 15:47:32,217 - sentence_transformers.trainer - INFO - Saving model checkpoint to trainer_output/checkpoint-100
2025-11-14 15:47:32,217 - sentence_transformers.SentenceTransformer - INFO - Save model to trainer_output/checkpoint-100
2025-11-14 15:48:17,890 - sentence_transformers.evaluation.TripletEvaluator - INFO - TripletEvaluator: Evaluating the model on the val dataset in epoch 0.1467351430667645 after 200 steps:
2025-11-14 15:49:31,975 - sentence_transformers.evaluation.TripletEvaluator - INFO - Accuracy Cosine Similarity:	85.73%
2025-11-14 15:49:31,978 - sentence_transformers.trainer - INFO - Saving model checkpoint to trainer_output/checkpoint-200
2025-11-14 15:49:3

TrainOutput(global_step=1363, training_loss=4.516945848024163, metrics={'train_runtime': 1529.0258, 'train_samples_per_second': 114.084, 'train_steps_per_second': 0.891, 'total_flos': 0.0, 'train_loss': 4.516945848024163, 'epoch': 1.0})