# MMContextEncoder — Training Walkthrough
This notebook demonstrates how to fine‑tune the multimodal `MMContextEncoder` in four flavours:
1. **Text‑only**
2. **Pre‑computed numeric embeddings**
   - 2a. Feature‑level tokens
   - 2b. Sample‑level tokens
3. **Random‑initialised baseline**

Few epochs and small batch sizes keep the runtime short; scale them up for real work.

In [2]:
%load_ext autoreload
%autoreload 2
from mmcontext.utils import setup_logging

setup_logging()



<RootLogger root (INFO)>

## 0  Setup & toy data

In [3]:
import os

import datasets
import numpy as np
import pandas as pd
import torch
from datasets import DatasetDict
from sentence_transformers import SentenceTransformer

from mmcontext.models.mmcontextencoder import MMContextEncoder
from mmcontext.simulator import OmicsCaptionSimulator

# simulate tiny dataset
sim = OmicsCaptionSimulator(n_samples=2000, n_genes=12).simulate(preset="pair-binary")
raw_ds = sim.get_hf_dataset()
raw_ds

2025-07-15 11:05:10,149 - root - INFO - Building HF dataset with preset: pair-binary
2025-07-15 11:05:10,149 - root - INFO - Available presets: single, single-class, pair, pair-binary, triplet
Filter: 100%|██████████| 4000/4000 [00:00<00:00, 518215.17 examples/s]
Filter: 100%|██████████| 4000/4000 [00:00<00:00, 551410.50 examples/s]


DatasetDict({
    train: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'sample_idx'],
        num_rows: 3200
    })
    val: Dataset({
        features: ['sentence1', 'sentence2', 'label', 'sample_idx'],
        num_rows: 800
    })
})

## Text‑only

In [4]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses

enc = MMContextEncoder("prajjwal1/bert-tiny", adapter_hidden_dim=None)
ds = enc.prepare_ds(raw_ds, prefix=False, primary_cell_sentence_col="sentence1", caption_col="sentence2")
st = SentenceTransformer(modules=[enc])
train, val = ds["train"], ds["val"]
out_dir = "./models/demo_text_only"

loss = losses.ContrastiveLoss(model=st)
args = SentenceTransformerTrainingArguments(
    output_dir=out_dir,
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
)
trainer = SentenceTransformerTrainer(
    model=st,
    args=args,
    train_dataset=train,
    eval_dataset=val,
    loss=loss,
)
trainer.train()

2025-07-15 11:05:12,143 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device_name: mps
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
	- Avoid using `tokenizers` before the fork if possible
	- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
[34m[1mwandb[0m: Currently logged in as: 





2025-07-15 11:05:16,717 - sentence_transformers.trainer - INFO - Saving model checkpoint to ./models/demo_text_only/checkpoint-50
2025-07-15 11:05:16,717 - sentence_transformers.SentenceTransformer - INFO - Save model to ./models/demo_text_only/checkpoint-50


TrainOutput(global_step=50, training_loss=0.03247499465942383, metrics={'train_runtime': 4.8861, 'train_samples_per_second': 654.919, 'train_steps_per_second': 10.233, 'total_flos': 0.0, 'train_loss': 0.03247499465942383, 'epoch': 1.0})
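
The `global_step=50` in the output above follows directly from the dataset size and the batch size. The sketch below reproduces that arithmetic. It assumes a single device, no gradient accumulation, and that the final partial batch is kept; `expected_steps` is a helper name made up for this note, not part of any library.

```python
import math


def expected_steps(n_train: int, batch_size: int, epochs: int = 1) -> int:
    """Optimizer steps for a run (assumes one device, no gradient accumulation)."""
    steps_per_epoch = math.ceil(n_train / batch_size)  # last partial batch is kept
    return steps_per_epoch * epochs


# Text-only run: 3200 training rows, batch size 64, one epoch
print(expected_steps(3200, 64))  # 50
```

The same formula can be used to sanity-check the later runs against the `global_step` they report.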

## Feature‑level tokens

In [5]:
sim = OmicsCaptionSimulator(n_samples=2000, n_genes=12, use_gene_level=True).simulate()
token_df = sim.get_dataframe()
raw_ds = sim.get_hf_dataset()

2025-07-15 11:05:19,789 - root - INFO - Building HF dataset with preset: pair-binary
2025-07-15 11:05:19,790 - root - INFO - Available presets: single, single-class, pair, pair-binary, triplet
Filter: 100%|██████████| 4000/4000 [00:00<00:00, 518407.32 examples/s]
Filter: 100%|██████████| 4000/4000 [00:00<00:00, 533898.17 examples/s]


In [8]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses

enc = MMContextEncoder("prajjwal1/bert-tiny", adapter_hidden_dim=32, adapter_output_dim=64)
enc.register_initial_embeddings(token_df, data_origin="geneformer")
# Add the prefix to the dataset. It tells the model not to treat this input as a normal string.
# Prefixed inputs go to the omics side of the model, a lookup-only encoder built from the initial embeddings of other models, e.g. Geneformer.
ds = enc.prepare_ds(raw_ds, prefix=True, primary_cell_sentence_col="sentence1", caption_col="sentence2")
train, val = ds["train"], ds["val"]
st = SentenceTransformer(modules=[enc])
out_dir = "./models/demo_feat_tokens"

loss = losses.ContrastiveLoss(model=st)
args = SentenceTransformerTrainingArguments(
    output_dir=out_dir,
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
)
trainer = SentenceTransformerTrainer(
    model=st,
    args=args,
    train_dataset=train,
    eval_dataset=val,
    loss=loss,
)
trainer.train()

2025-07-15 11:07:14,590 - mmcontext.models.omicsencoder - INFO - Loaded embedding matrix with shape (13, 16)
2025-07-15 11:07:14,591 - mmcontext.models.mmcontextencoder - INFO - Registered 13 new numeric samples (total 13). ≈0.000 GiB added. (Assuming float32 precision.)
Prefixing sentence1: 100%|██████████| 3200/3200 [00:00<00:00, 170556.49 examples/s]
Prefixing sentence1: 100%|██████████| 800/800 [00:00<00:00, 194552.28 examples/s]
2025-07-15 11:07:14,715 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device_name: mps




2025-07-15 11:07:28,155 - sentence_transformers.trainer - INFO - Saving model checkpoint to ./models/demo_feat_tokens/checkpoint-400
2025-07-15 11:07:28,156 - sentence_transformers.SentenceTransformer - INFO - Save model to ./models/demo_feat_tokens/checkpoint-400


TrainOutput(global_step=400, training_loss=0.2524380111694336, metrics={'train_runtime': 14.2909, 'train_samples_per_second': 223.919, 'train_steps_per_second': 27.99, 'total_flos': 0.0, 'train_loss': 0.2524380111694336, 'epoch': 1.0})
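
The comments in the cell above say that a prefix marks `sentence1` as an omics input rather than a normal string. A minimal sketch of that idea follows. The prefix string, the whitespace-separated token format, and the function names are all hypothetical; this is not the `MMContextEncoder` implementation, only an illustration of prefix-based routing.

```python
from typing import List, Tuple, Union

PREFIX = "sample_idx:"  # hypothetical marker; the real prefix may differ


def add_prefix(tokens: str, prefix: str = PREFIX) -> str:
    """Mark a string so it is later treated as omics tokens, not free text."""
    return f"{prefix}{tokens}"


def route(value: str, prefix: str = PREFIX) -> Tuple[str, Union[str, List[str]]]:
    """Send prefixed strings to the lookup side and everything else to the text side."""
    if value.startswith(prefix):
        # Assumption: tokens are whitespace-separated IDs to look up in an embedding table
        return "omics", value[len(prefix):].split()
    return "text", value


print(route(add_prefix("gene_1 gene_7 gene_3")))
print(route("a caption describing the sample"))
```

With `prefix=False`, as in the text-only section, every input would take the `"text"` branch of such a router.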

## Sample‑level tokens

In [9]:
sim = OmicsCaptionSimulator(n_samples=2000, n_genes=12, use_gene_level=False).simulate()
token_df = sim.get_dataframe()
raw_ds = sim.get_hf_dataset()

2025-07-15 11:07:59,717 - root - INFO - Building HF dataset with preset: pair-binary
2025-07-15 11:07:59,718 - root - INFO - Available presets: single, single-class, pair, pair-binary, triplet
Filter: 100%|██████████| 4000/4000 [00:00<00:00, 543391.61 examples/s]
Filter: 100%|██████████| 4000/4000 [00:00<00:00, 529116.19 examples/s]


In [11]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses

enc = MMContextEncoder("prajjwal1/bert-tiny", adapter_hidden_dim=32, adapter_output_dim=64)
enc.register_initial_embeddings(token_df, data_origin="pca")
ds = enc.prepare_ds(raw_ds, primary_cell_sentence_col="sentence1", caption_col="sentence2")
train, val = ds["train"], ds["val"]
st = SentenceTransformer(modules=[enc])
out_dir = "./models/demo_sample_tokens"

loss = losses.ContrastiveLoss(model=st)
args = SentenceTransformerTrainingArguments(
    output_dir=out_dir,
    num_train_epochs=1,
    per_device_train_batch_size=64,
    per_device_eval_batch_size=64,
)
trainer = SentenceTransformerTrainer(
    model=st,
    args=args,
    train_dataset=train,
    eval_dataset=val,
    loss=loss,
)
trainer.train()

2025-07-15 11:08:19,641 - mmcontext.models.omicsencoder - INFO - Loaded embedding matrix with shape (2001, 32)
2025-07-15 11:08:19,641 - mmcontext.models.mmcontextencoder - INFO - Registered 2001 new numeric samples (total 2001). ≈0.000 GiB added. (Assuming float32 precision.)
Prefixing sentence1: 100%|██████████| 3200/3200 [00:00<00:00, 168076.80 examples/s]
Prefixing sentence1: 100%|██████████| 800/800 [00:00<00:00, 239093.86 examples/s]
2025-07-15 11:08:19,764 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device_name: mps




2025-07-15 11:08:22,087 - sentence_transformers.trainer - INFO - Saving model checkpoint to ./models/demo_sample_tokens/checkpoint-50
2025-07-15 11:08:22,088 - sentence_transformers.SentenceTransformer - INFO - Save model to ./models/demo_sample_tokens/checkpoint-50


TrainOutput(global_step=50, training_loss=0.25470958709716796, metrics={'train_runtime': 3.5607, 'train_samples_per_second': 898.693, 'train_steps_per_second': 14.042, 'total_flos': 0.0, 'train_loss': 0.25470958709716796, 'epoch': 1.0})
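
The log above reports "≈0.000 GiB added" for the registered embedding matrix. The sketch below shows one way such a figure could be estimated, assuming 4 bytes per float32 value and no storage overhead. Whether the library computes it exactly this way is an assumption, and `embedding_gib` is a name invented for this note.

```python
def embedding_gib(n_rows: int, dim: int, bytes_per_value: int = 4) -> float:
    """Approximate size of a dense embedding matrix in GiB (float32 by default)."""
    return n_rows * dim * bytes_per_value / 2**30


# Matrix shapes taken from the "Loaded embedding matrix" log lines in this notebook
for shape in [(13, 16), (2001, 32)]:
    print(shape, f"{embedding_gib(*shape):.6f} GiB")
```

Both toy matrices are far below a thousandth of a GiB, which is why the log rounds them to 0.000. The estimate only becomes relevant for much larger embedding tables.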

## Random baseline

In [14]:
enc = MMContextEncoder("prajjwal1/bert-tiny", adapter_hidden_dim=None, adapter_output_dim=128, freeze_text_encoder=True)
enc.random_initial_embeddings(list(token_df["token"]))
ds = enc.prepare_ds(raw_ds, primary_cell_sentence_col="sentence1", caption_col="sentence2")
train, val = ds["train"], ds["val"]
st = SentenceTransformer(modules=[enc])

2025-07-15 11:09:11,430 - mmcontext.models.omicsencoder - INFO - Loaded embedding matrix with shape (2001, 64)
2025-07-15 11:09:11,431 - mmcontext.models.mmcontextencoder - INFO - Registered 2001 new numeric samples (total 2001). ≈0.000 GiB added. (Assuming float32 precision.)
Prefixing sentence1: 100%|██████████| 3200/3200 [00:00<00:00, 177431.06 examples/s]
Prefixing sentence1: 100%|██████████| 800/800 [00:00<00:00, 250144.86 examples/s]
2025-07-15 11:09:11,541 - sentence_transformers.SentenceTransformer - INFO - Use pytorch device_name: mps


In [15]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses

out_dir = "./models/demo_random"


loss = losses.ContrastiveLoss(model=st)
args = SentenceTransformerTrainingArguments(
    output_dir=out_dir,
    num_train_epochs=64,
    per_device_train_batch_size=2560,
    per_device_eval_batch_size=2560,
)
trainer = SentenceTransformerTrainer(
    model=st,
    args=args,
    train_dataset=train,
    eval_dataset=val,
    loss=loss,
)
trainer.train()





2025-07-15 11:09:57,111 - sentence_transformers.trainer - INFO - Saving model checkpoint to ./models/demo_random/checkpoint-128
2025-07-15 11:09:57,111 - sentence_transformers.SentenceTransformer - INFO - Save model to ./models/demo_random/checkpoint-128


TrainOutput(global_step=128, training_loss=0.25067639350891113, metrics={'train_runtime': 8.715, 'train_samples_per_second': 23499.648, 'train_steps_per_second': 14.687, 'total_flos': 0.0, 'train_loss': 0.25067639350891113, 'epoch': 64.0})

### Further notes
* Swap in other losses or multi‑positive datasets – see the [Sentence‑Transformers training docs](https://www.sbert.net/docs/training/overview.html).
* Internally, Sentence‑Transformers gives a name to each input feature. Text tokens are `input_ids`. The only other option supported by default is `pixel_values`, which is used here for the tokens of the omics sample.
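
The last note can be pictured as a dispatch on the feature name. The sketch below uses plain dictionaries and a made-up `pick_branch` helper to show the idea. It assumes a feature dict carries exactly one of the two keys and is not taken from the `MMContextEncoder` source.

```python
from typing import Any, Dict

TEXT_KEY = "input_ids"      # name used for text tokens
OMICS_KEY = "pixel_values"  # name reused here for omics tokens


def pick_branch(features: Dict[str, Any]) -> str:
    """Return which encoder branch a feature dict would be sent to."""
    if OMICS_KEY in features:
        return "omics"
    if TEXT_KEY in features:
        return "text"
    raise KeyError(f"expected '{TEXT_KEY}' or '{OMICS_KEY}' in features")


print(pick_branch({"input_ids": [101, 2023, 102]}))  # text
print(pick_branch({"pixel_values": [3, 7, 1]}))      # omics
```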