# MMContextEncoder — Training Walkthrough
This notebook demonstrates how to fine‑tune the multimodal `MMContextEncoder` in four flavours:
1. **Text‑only**
2. **Pre‑computed numeric embeddings**  
   ‑ 2 a. Feature‑level tokens  
   ‑ 2 b. Sample‑level tokens
3. **Random‑initialised baseline**

Epoch counts and batch sizes are chosen to keep the runtime short; scale them up for real work.
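
To get a feel for how epochs and batch size translate into training length, the sketch below counts optimizer steps for three of the settings used later in this notebook on its 3,200‑row train split. It assumes a single device, no gradient accumulation, and that the final partial batch is kept; `total_optimizer_steps` is a hypothetical helper, not part of `mmcontext` or Sentence‑Transformers.

```python
import math


def total_optimizer_steps(n_rows: int, batch_size: int, epochs: int) -> int:
    """Return ceil(n_rows / batch_size) steps per epoch, times the epoch count.

    Assumptions: one device, gradient accumulation of 1, last partial batch kept.
    """
    steps_per_epoch = math.ceil(n_rows / batch_size)
    return steps_per_epoch * epochs


n_train = 3200  # size of the toy train split shown in the setup cell
for batch_size, epochs in [(8, 1), (64, 1), (2560, 64)]:
    steps = total_optimizer_steps(n_train, batch_size, epochs)
    print(f"batch_size={batch_size}, epochs={epochs}: {steps} steps")
```

Under those assumptions the results (400, 50 and 128) line up with the `global_step` values logged for the feature‑level, sample‑level and random‑baseline runs.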

In [16]:
%load_ext autoreload
%autoreload 2
from mmcontext.utils import setup_logging

setup_logging()

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


<RootLogger root (INFO)>

## 0  Setup & toy data

In [22]:
import os

import datasets
import numpy as np
import pandas as pd
import torch
from datasets import DatasetDict
from sentence_transformers import SentenceTransformer

from mmcontext.models.mmcontextencoder import MMContextEncoder
from mmcontext.simulator import OmicsCaptionSimulator

# simulate tiny dataset
sim = OmicsCaptionSimulator(n_samples=2000, n_genes=12).simulate()
raw_ds = sim.get_hf_dataset()
raw_ds

Filter: 100%|██████████| 4000/4000 [00:00<00:00, 571119.83 examples/s]
Filter: 100%|██████████| 4000/4000 [00:00<00:00, 573403.60 examples/s]


DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'sample_idx'],
        num_rows: 3200
    })
    val: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'sample_idx'],
        num_rows: 800
    })
})

## Text‑only

In [21]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses

enc = MMContextEncoder("prajjwal1/bert-tiny", adapter_hidden_dim=None)
ds = enc.prepare_ds(raw_ds, prefix=False, cell_sentences_cols=["sentence1"], caption_col="sentence2")
st = SentenceTransformer(modules=[enc])
train, val = ds["train"], ds["val"]
out_dir = "./models/demo_text_only"

loss = losses.ContrastiveLoss(model=st)
args = SentenceTransformerTrainingArguments(
    output_dir=out_dir,
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
)
trainer = SentenceTransformerTrainer(
    model=st,
    args=args,
    train_dataset=train,
    eval_dataset=val,
    loss=loss,
)
trainer.train()

2025-05-20 10:14:58,432 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device_name: mps
                                                   
 37%|███▋      | 1484/4000 [01:22<01:25, 29.43it/s]
2025-05-20 10:15:17,296 - sentence_transformers.trainer - INFO - Saving model checkpoint to ./models/demo_text_only/checkpoint-500
2025-05-20 10:15:17,296 - sentence_transformers.SentenceTransformer - INFO - Save model to ./models/demo_text_only/checkpoint-500


{'loss': 0.0321, 'grad_norm': 0.11336958408355713, 'learning_rate': 0.0, 'epoch': 1.0}


                                                   
100%|██████████| 500/500 [00:19<00:00, 25.19it/s]

{'train_runtime': 19.8485, 'train_samples_per_second': 1612.209, 'train_steps_per_second': 25.191, 'train_loss': 0.03207863998413086, 'epoch': 1.0}





TrainOutput(global_step=500, training_loss=0.03207863998413086, metrics={'train_runtime': 19.8485, 'train_samples_per_second': 1612.209, 'train_steps_per_second': 25.191, 'total_flos': 0.0, 'train_loss': 0.03207863998413086, 'epoch': 1.0})

## Feature‑level tokens

In [23]:
sim = OmicsCaptionSimulator(n_samples=2000, n_genes=12, use_gene_level=True).simulate()
token_df = sim.get_dataframe()
raw_ds = sim.get_hf_dataset()

Filter: 100%|██████████| 4000/4000 [00:00<00:00, 572308.24 examples/s]
Filter: 100%|██████████| 4000/4000 [00:00<00:00, 578764.18 examples/s]


In [25]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses

enc = MMContextEncoder("prajjwal1/bert-tiny", adapter_hidden_dim=32, adapter_output_dim=64)
enc.register_initial_embeddings(token_df, data_origin="geneformer")
ds = enc.prepare_ds(raw_ds, cell_sentences_cols="sentence1", caption_col="sentence2")
train, val = ds["train"], ds["val"]
st = SentenceTransformer(modules=[enc])
out_dir = "./models/demo_feat_tokens"

loss = losses.ContrastiveLoss(model=st)
args = SentenceTransformerTrainingArguments(
    output_dir=out_dir,
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
)
trainer = SentenceTransformerTrainer(
    model=st,
    args=args,
    train_dataset=train,
    eval_dataset=val,
    loss=loss,
)
trainer.train()

2025-05-20 10:16:17,220 - mmcontext.models.omicsencoder - INFO - Loaded embedding matrix with shape (13, 16)
2025-05-20 10:16:17,221 - mmcontext.models.mmcontextencoder - INFO - Registered 13 new numeric samples (total 13). ≈0.000 GiB added. (Assuming float32 precision.)
Prefixing sentence1: 100%|██████████| 3200/3200 [00:00<00:00, 170667.10 examples/s]
Prefixing sentence1: 100%|██████████| 800/800 [00:00<00:00, 273980.83 examples/s]
2025-05-20 10:16:17,335 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device_name: mps
2025-05-20 10:16:29,403 - sentence_transformers.trainer - INFO - Saving model checkpoint to ./models/demo_feat_tokens/checkpoint-400
2025-05-20 10:16:29,403 - sentence_transformers.SentenceTransformer - INFO - Save model to ./models/demo_feat_tokens/checkpoint-400
                                                   
100%|██████████| 400/400 [00:12<00:00, 31.15it/s]

{'train_runtime': 12.8439, 'train_samples_per_second': 249.145, 'train_steps_per_second': 31.143, 'train_loss': 0.2536071968078613, 'epoch': 1.0}





TrainOutput(global_step=400, training_loss=0.2536071968078613, metrics={'train_runtime': 12.8439, 'train_samples_per_second': 249.145, 'train_steps_per_second': 31.143, 'total_flos': 0.0, 'train_loss': 0.2536071968078613, 'epoch': 1.0})

## Sample‑level tokens

In [26]:
sim = OmicsCaptionSimulator(n_samples=2000, n_genes=12, use_gene_level=False).simulate()
token_df = sim.get_dataframe()
raw_ds = sim.get_hf_dataset()

Filter: 100%|██████████| 4000/4000 [00:00<00:00, 593211.80 examples/s]
Filter: 100%|██████████| 4000/4000 [00:00<00:00, 591872.43 examples/s]


In [29]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses

enc = MMContextEncoder("prajjwal1/bert-tiny", adapter_hidden_dim=32, adapter_output_dim=64)
enc.register_initial_embeddings(token_df, data_origin="pca")
ds = enc.prepare_ds(raw_ds, cell_sentences_cols="sentence1", caption_col="sentence2")
train, val = ds["train"], ds["val"]
st = SentenceTransformer(modules=[enc])
out_dir = "./models/demo_sample_tokens"

loss = losses.ContrastiveLoss(model=st)
args = SentenceTransformerTrainingArguments(
    output_dir=out_dir,
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
)
trainer = SentenceTransformerTrainer(
    model=st,
    args=args,
    train_dataset=train,
    eval_dataset=val,
    loss=loss,
)
trainer.train()

2025-05-20 10:18:15,455 - mmcontext.models.omicsencoder - INFO - Loaded embedding matrix with shape (2001, 32)
2025-05-20 10:18:15,456 - mmcontext.models.mmcontextencoder - INFO - Registered 2001 new numeric samples (total 2001). ≈0.000 GiB added. (Assuming float32 precision.)
Prefixing sentence1: 100%|██████████| 3200/3200 [00:00<00:00, 161346.53 examples/s]
Prefixing sentence1: 100%|██████████| 800/800 [00:00<00:00, 258051.47 examples/s]
2025-05-20 10:18:15,595 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device_name: mps
 98%|█████████▊| 49/50 [00:01<00:00, 32.73it/s]
2025-05-20 10:18:17,402 - sentence_transformers.trainer - INFO - Saving model checkpoint to ./models/demo_sample_tokens/checkpoint-50
2025-05-20 10:18:17,402 - sentence_transformers.SentenceTransformer - INFO - Save model to ./models/demo_sample_tokens/checkpoint-50
100%|██████████| 50/50 [00:02<00:00, 19.37it/s]

{'train_runtime': 2.5823, 'train_samples_per_second': 1239.194, 'train_steps_per_second': 19.362, 'train_loss': 0.25574018478393556, 'epoch': 1.0}





TrainOutput(global_step=50, training_loss=0.25574018478393556, metrics={'train_runtime': 2.5823, 'train_samples_per_second': 1239.194, 'train_steps_per_second': 19.362, 'total_flos': 0.0, 'train_loss': 0.25574018478393556, 'epoch': 1.0})
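
The `≈0.000 GiB added` log lines are easy to sanity‑check. The sketch below assumes the estimate is simply rows × dimensions × 4 bytes (float32), which may not be exactly how `mmcontext` computes it; `embedding_gib` is a hypothetical helper.

```python
def embedding_gib(n_rows: int, n_dims: int, bytes_per_value: int = 4) -> float:
    """Size of a dense embedding matrix in GiB (float32 by default)."""
    return n_rows * n_dims * bytes_per_value / 2**30


# Shapes taken from the "Loaded embedding matrix with shape ..." log lines above.
for shape in [(13, 16), (2001, 32)]:
    print(shape, f"{embedding_gib(*shape):.6f} GiB")
```

Both matrices occupy well under a mebibyte, so they round to 0.000 GiB at three decimals.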

## Random baseline

In [32]:
enc = MMContextEncoder("prajjwal1/bert-tiny", adapter_hidden_dim=None, adapter_output_dim=128, freeze_text_encoder=True)
enc.random_initial_embeddings(list(token_df["token"]))
ds = enc.prepare_ds(raw_ds, cell_sentences_cols=["sentence1"], caption_col="sentence2")
train, val = ds["train"], ds["val"]
st = SentenceTransformer(modules=[enc])

2025-05-20 10:19:21,643 - mmcontext.models.omicsencoder - INFO - Loaded embedding matrix with shape (2001, 64)
2025-05-20 10:19:21,644 - mmcontext.models.mmcontextencoder - INFO - Registered 2001 new numeric samples (total 2001). ≈0.000 GiB added. (Assuming float32 precision.)
Prefixing sentence1: 100%|██████████| 3200/3200 [00:00<00:00, 161560.17 examples/s]
Prefixing sentence1: 100%|██████████| 800/800 [00:00<00:00, 258051.47 examples/s]
2025-05-20 10:19:21,779 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device_name: mps


In [33]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses

out_dir = "./models/demo_random"


loss = losses.ContrastiveLoss(model=st)
args = SentenceTransformerTrainingArguments(
    output_dir=out_dir,
    num_train_epochs=64,
    per_device_train_batch_size=2560,
    per_device_eval_batch_size=2560,
)
trainer = SentenceTransformerTrainer(
    model=st,
    args=args,
    train_dataset=train,
    eval_dataset=val,
    loss=loss,
)
trainer.train()

 98%|█████████▊| 126/128 [00:06<00:00, 23.33it/s]
2025-05-20 10:19:33,787 - sentence_transformers.trainer - INFO - Saving model checkpoint to ./models/demo_random/checkpoint-128
2025-05-20 10:19:33,788 - sentence_transformers.SentenceTransformer - INFO - Save model to ./models/demo_random/checkpoint-128
100%|██████████| 128/128 [00:07<00:00, 16.94it/s]

{'train_runtime': 7.5619, 'train_samples_per_second': 27083.051, 'train_steps_per_second': 16.927, 'train_loss': 0.24992215633392334, 'epoch': 64.0}





TrainOutput(global_step=128, training_loss=0.24992215633392334, metrics={'train_runtime': 7.5619, 'train_samples_per_second': 27083.051, 'train_steps_per_second': 16.927, 'total_flos': 0.0, 'train_loss': 0.24992215633392334, 'epoch': 64.0})

### Further notes
* Swap in other losses or multi‑positive datasets – see the [Sentence‑Transformers training docs](https://www.sbert.net/docs/training/overview.html).
* The current Sentence‑Transformers release only recognises a second input column if it is named **`pixel_values`**. If you don't use my SentenceTransformer fork, rename `cell_sentence_2` (or whichever dataset column you want to use) to `pixel_values`.
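
Before swapping losses it can help to know what `losses.ContrastiveLoss` optimises in the runs above. The following dependency‑free sketch computes the per‑pair value, assuming the library defaults of cosine distance and a margin of 0.5 (check the docs for your installed version). `contrastive_pair_loss` is a hypothetical name, and the real loss averages this quantity over a batch of embeddings.

```python
import math


def cosine_distance(u, v):
    """1 - cosine similarity between two non-zero vectors."""
    dot = sum(a * b for a, b in zip(u, v))
    norm_u = math.sqrt(sum(a * a for a in u))
    norm_v = math.sqrt(sum(b * b for b in v))
    return 1.0 - dot / (norm_u * norm_v)


def contrastive_pair_loss(u, v, label, margin=0.5):
    """Contrastive loss for one pair.

    label == 1: pull the pair together (penalise any distance).
    label == 0: push the pair apart until the distance exceeds `margin`.
    """
    d = cosine_distance(u, v)
    if label == 1:
        return 0.5 * d**2
    return 0.5 * max(0.0, margin - d) ** 2


# Orthogonal vectors: costly if labelled as a match, free if labelled as a mismatch.
print(contrastive_pair_loss([1.0, 0.0], [0.0, 1.0], label=1))
print(contrastive_pair_loss([1.0, 0.0], [0.0, 1.0], label=0))
```

Negative pairs stop contributing once they are further apart than the margin, which is one reason to try a different loss when the dataset has several positives per anchor.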