# MMContextEncoder — quick‑start & usage tour

This notebook uses the **`OmicsCaptionSimulator`** to generate toy data and walks through three ways of running the `MMContextEncoder` inside the Sentence‑Transformers framework:

1. **Text‑only** (no numeric data)
2. **Pre‑computed numeric embeddings**
   - 2 a. feature‑level (gene) tokens
   - 2 b. sample‑level tokens
3. **Random‑initialised numeric embeddings** (baseline)

> *Training* will be covered in a follow‑up notebook. Here we focus on end‑to‑end **`encode`** calls and what comes out.

---

## 0  Setup

In [37]:
%load_ext autoreload
%autoreload 2
from mmcontext.utils import setup_logging

setup_logging()



<RootLogger root (INFO)>

In [41]:
import numpy as np
import pandas as pd
import torch
from sentence_transformers import SentenceTransformer

from mmcontext.models.MMContextEncoder import MMContextEncoder
from mmcontext.simulator import OmicsCaptionSimulator

sim = OmicsCaptionSimulator(n_samples=100, n_genes=10).simulate()
gene_df, sample_df = sim.get_dataframes()
raw_ds = sim.get_hf_dataset()["train"]
raw_ds



Dataset({
    features: ['sample_idx', 'cell_sentence_1', 'cell_sentence_2', 'captions', 'label'],
    num_rows: 160
})

In [44]:
gene_df

|   | token | embedding |
|---|-------|-----------|
| 0 | g1 | [-1.0454764366149902, -0.9293965101242065, 2.4... |
| 1 | g2 | [-1.2294284105300903, -0.9434616565704346, 1.7... |
| 2 | g3 | [-1.030511498451233, -1.0347418785095215, 1.81... |
| 3 | g4 | [0.2129260152578354, 0.042285092175006866, 0.5... |
| 4 | g5 | [-1.7854403257369995, 0.669002890586853, -1.51... |
| 5 | g6 | [-0.46827974915504456, 1.2896097898483276, 0.8... |
| 6 | g7 | [0.09945357590913773, 1.1276657581329346, 0.65... |
| 7 | g8 | [-1.6976884603500366, 0.11778832972049713, -0.... |
| 8 | g9 | [0.019002540037035942, 0.7835861444473267, 1.7... |
| 9 | g10 | [-0.49436676502227783, -0.7283456921577454, 1.... |


In [45]:
raw_ds["cell_sentence_2"]

['g8 g1 g3 g5 g9 g4 g10 g2 g7 g6',
 'g8 g1 g3 g5 g9 g4 g10 g2 g7 g6',
 'g1 g3 g7 g6 g2 g4 g5 g9 g10 g8',
 'g1 g3 g7 g6 g2 g4 g5 g9 g10 g8',
 'g7 g8 g10 g2 g6 g5 g9 g3 g1 g4',
 'g7 g8 g10 g2 g6 g5 g9 g3 g1 g4',
 'g4 g2 g6 g8 g3 g9 g10 g1 g5 g7',
 'g4 g2 g6 g8 g3 g9 g10 g1 g5 g7',
 'g4 g10 g9 g6 g8 g5 g7 g1 g3 g2',
 'g4 g10 g9 g6 g8 g5 g7 g1 g3 g2',
 'g6 g8 g4 g7 g5 g10 g1 g9 g3 g2',
 'g6 g8 g4 g7 g5 g10 g1 g9 g3 g2',
 'g5 g1 g9 g10 g7 g3 g4 g6 g2 g8',
 'g5 g1 g9 g10 g7 g3 g4 g6 g2 g8',
 'g7 g8 g3 g1 g5 g10 g9 g4 g6 g2',
 'g7 g8 g3 g1 g5 g10 g9 g4 g6 g2',
 'g4 g7 g6 g2 g8 g10 g1 g3 g9 g5',
 'g4 g7 g6 g2 g8 g10 g1 g3 g9 g5',
 'g2 g1 g6 g3 g5 g10 g9 g7 g8 g4',
 'g2 g1 g6 g3 g5 g10 g9 g7 g8 g4',
 'g2 g1 g6 g9 g4 g5 g10 g3 g7 g8',
 'g2 g1 g6 g9 g4 g5 g10 g3 g7 g8',
 'g2 g5 g10 g8 g7 g4 g9 g3 g6 g1',
 'g2 g5 g10 g8 g7 g4 g9 g3 g6 g1',
 'g1 g7 g3 g2 g8 g6 g10 g5 g4 g9',
 'g1 g7 g3 g2 g8 g6 g10 g5 g4 g9',
 'g7 g2 g10 g5 g6 g9 g8 g1 g3 g4',
 'g7 g2 g10 g5 g6 g9 g8 g1 g3 g4',
 'g1 g5 g9 g4 g8 g2 …

In [39]:
# `.shape` of a single column counts its rows: one embedding per gene and one per sample
print(f"Gene embeddings shape: {gene_df['embedding'].shape}")
print(f"Sample embeddings shape: {sample_df['embedding'].shape}")

Gene embeddings shape: (10,)
Sample embeddings shape: (100,)
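`Series.shape` only counts the rows of a column, so the output above reports the number of genes (10) and samples (100), not the length of each vector. One way to see the embedding width is to stack the column into a 2-D array. The sketch below uses a small stand-in table with placeholder values and assumes the `embedding` column holds arrays or lists (not strings); with the notebook's objects you would pass `gene_df` or `sample_df` instead.

```python
import numpy as np
import pandas as pd


def embedding_matrix(df, column="embedding"):
    """Stack a column of equal-length vectors into an (n_rows, dim) array."""
    return np.stack([np.asarray(v, dtype=np.float32) for v in df[column]])


# Stand-in for gene_df: 10 tokens with 4-dimensional placeholder vectors.
toy = pd.DataFrame({
    "token": [f"g{i}" for i in range(1, 11)],
    "embedding": [np.full(4, i, dtype=np.float32) for i in range(10)],
})

print(toy["embedding"].shape)       # (10,)   -> rows only
print(embedding_matrix(toy).shape)  # (10, 4) -> rows x embedding dimension
```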


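The Hugging Face dataset has the columns `sample_idx`, `cell_sentence_1`, `cell_sentence_2`, `captions` and `label`. In the rows shown here, `cell_sentence_1` holds a sample ID (e.g. `S1`) and `cell_sentence_2` holds a space‑separated list of gene tokens.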
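Before registering feature‑level embeddings, it can be worth checking that every token in the omics column has a row in the lookup table. A minimal sketch with a hypothetical helper `missing_tokens`; it assumes the sentences are whitespace‑separated token lists like `cell_sentence_2`. In the notebook you would pass `raw_ds["cell_sentence_2"]` and `gene_df["token"]`.

```python
def missing_tokens(sentences, vocabulary):
    """Return tokens that appear in the sentences but not in the vocabulary (hypothetical helper)."""
    known = set(vocabulary)
    seen = {tok for sentence in sentences for tok in sentence.split()}
    return sorted(seen - known)


sentences = ["g8 g1 g3", "g1 g3 g11"]
print(missing_tokens(sentences, ["g1", "g3", "g8"]))  # ['g11']
```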

## 1  MMContextEncoder as a **pure text** model

In [9]:
text_enc = MMContextEncoder(text_encoder_name="prajjwal1/bert-tiny")  # a tiny BERT keeps the demo fast; other Hugging Face text encoders can be swapped in
st_text = SentenceTransformer(modules=[text_enc])

example = [raw_ds["cell_sentence_1"][0], raw_ds["captions"][0]]
print("input →", example)
print("embedding →", st_text.encode(example)[0, :5], "…")  # first 5 values of the first embedding

input → ['S1', 'This is a Neuron from skin with scRNA-seq']
embedding → [ 0.17870897 -0.18881485  0.1618496   0.2277055  -0.17832585] …

The sample ID in `cell_sentence_1` (`S1`) is **treated like ordinary words**: no numeric embeddings were registered, so it is tokenised as plain text.

If you initialise the encoder with `output_token_embeddings=True`, you can retrieve the per‑token vectors:

In [12]:
text_enc_tokens = MMContextEncoder("prajjwal1/bert-tiny", output_token_embeddings=True)
st_tokens = SentenceTransformer(modules=[text_enc_tokens])

res = st_tokens.encode(example, output_value="token_embeddings")
print(len(res))  # a list with length of batch size (2)
res[0].shape  # the first element is a tensor of shape (n_tokens, n_features)



2


torch.Size([4, 128])
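The pooled `sentence_embedding` is derived from these per‑token vectors. The notebook does not state which pooling `MMContextEncoder` uses; the sketch below assumes masked mean pooling, a common choice in Sentence‑Transformers models, and uses random placeholder arrays shaped like the output above (4 tokens, 128 features).

```python
import numpy as np


def masked_mean_pool(token_embeddings, attention_mask):
    """Average token vectors, ignoring padding positions (mask == 0)."""
    mask = attention_mask[..., None].astype(token_embeddings.dtype)  # (batch, n_tokens, 1)
    summed = (token_embeddings * mask).sum(axis=1)                   # (batch, n_features)
    counts = np.clip(mask.sum(axis=1), 1e-9, None)                   # avoid division by zero
    return summed / counts


rng = np.random.default_rng(0)
tokens = rng.normal(size=(2, 4, 128)).astype(np.float32)  # (batch, n_tokens, n_features)
mask = np.array([[1, 1, 1, 1], [1, 1, 0, 0]])             # second item has 2 padding tokens
pooled = masked_mean_pool(tokens, mask)
print(pooled.shape)  # (2, 128)
```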

## 2  Using **pre‑computed** numeric embeddings
### 2 a  Feature‑level (gene) tokens

In [28]:
enc_feat = MMContextEncoder(
    "prajjwal1/bert-tiny", adapter_hidden_dim=32, adapter_output_dim=64, output_token_embeddings=True
)
enc_feat.register_initial_embeddings(gene_df, data_origin="geneformer")

# prefix the dataset so the processor knows which column is omics
pref_ds = enc_feat.prefix_ds(raw_ds, col_id="cell_sentence_2")

st_feat = SentenceTransformer(modules=[enc_feat])
row = pref_ds[0]
print("input →", row["cell_sentence_2"])
encoding = st_feat.encode(row["cell_sentence_2"], output_value="sentence_embedding")
print("Pooled Embedding shape:", encoding.shape)
token_encoding = st_feat.encode(row["cell_sentence_2"], output_value="token_embeddings")
print("Token Embedding shape:", token_encoding.shape)


input → sample_idx:g8 g1 g3 g5 g9 g4 g10 g2 g7 g6
Pooled Embedding shape: (64,)
Token Embedding shape: torch.Size([10, 64])
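In the feature‑level case each gene token in the prefixed string picks up its own row of the registered table, which is consistent with the token output having one row per gene (10). The sketch below mimics that lookup step with a plain dictionary. It assumes the `sample_idx:` prefix added by `prefix_ds` (visible in the printed input) is a marker to strip before splitting on whitespace; the helper `lookup_feature_tokens` and the 8‑dimensional placeholder vectors are hypothetical, and the real processor may work differently.

```python
import numpy as np


def lookup_feature_tokens(prefixed_text, table, prefix="sample_idx:"):
    """Strip the omics prefix, split into tokens and stack their vectors (hypothetical helper)."""
    if not prefixed_text.startswith(prefix):
        raise ValueError(f"expected text starting with {prefix!r}")
    tokens = prefixed_text[len(prefix):].split()
    return np.stack([table[tok] for tok in tokens])  # (n_tokens, input_dim)


rng = np.random.default_rng(0)
table = {f"g{i}": rng.normal(size=8).astype(np.float32) for i in range(1, 11)}
vectors = lookup_feature_tokens("sample_idx:g8 g1 g3 g5 g9 g4 g10 g2 g7 g6", table)
print(vectors.shape)  # (10, 8)
```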





### 2 b  Sample‑level tokens

In [33]:
enc_samp = MMContextEncoder(
    "prajjwal1/bert-tiny", adapter_hidden_dim=32, adapter_output_dim=64, output_token_embeddings=True
)
enc_samp.register_initial_embeddings(sample_df, data_origin="pca")

pref_ds2 = enc_samp.prefix_ds(raw_ds, col_id="cell_sentence_1")
st_samp = SentenceTransformer(modules=[enc_samp])
print("input →", pref_ds2[0]["cell_sentence_1"])
encoding = st_samp.encode(pref_ds2[0]["cell_sentence_1"])
print("Pooled Embedding shape:", encoding.shape)
token_encoding = st_samp.encode(pref_ds2[0]["cell_sentence_1"], output_value="token_embeddings")
print("Token Embedding shape:", token_encoding.shape)


input → sample_idx:S1
Pooled Embedding shape: (64,)
Token Embedding shape: torch.Size([1, 64])





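Each sample ID (here `S1`) is a single token, so the token output has shape (1, 64). Its pre‑computed vector from `sample_df` is returned **unmodified** by the omics branch and then projected to 64 dimensions by the adapter.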
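The adapter's role can be pictured as a small MLP that maps each looked‑up vector to the output size (here `adapter_hidden_dim=32`, `adapter_output_dim=64`). The two‑layer ReLU network below is an assumed architecture for illustration only, with random weights and a placeholder input size of 8; the class `ToyAdapter` is hypothetical and not taken from the `MMContextEncoder` source. With a single sample‑level token, the projection has shape (1, 64), as in the output above.

```python
import numpy as np


class ToyAdapter:
    """Hypothetical two-layer MLP adapter: input_dim -> hidden_dim -> output_dim."""

    def __init__(self, input_dim, hidden_dim=32, output_dim=64, seed=0):
        rng = np.random.default_rng(seed)
        self.w1 = rng.normal(scale=input_dim ** -0.5, size=(input_dim, hidden_dim))
        self.b1 = np.zeros(hidden_dim)
        self.w2 = rng.normal(scale=hidden_dim ** -0.5, size=(hidden_dim, output_dim))
        self.b2 = np.zeros(output_dim)

    def __call__(self, x):
        hidden = np.maximum(x @ self.w1 + self.b1, 0.0)  # ReLU
        return hidden @ self.w2 + self.b2


sample_vector = np.ones((1, 8))  # one sample-level token, passed in unmodified
adapter = ToyAdapter(input_dim=8)
projected = adapter(sample_vector)
print(projected.shape)               # (1, 64)
print(projected.mean(axis=0).shape)  # (64,) after pooling over the single token
```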

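> **Note**  The numeric embedding table is *not* saved with the model; only the adapter weights are. After reloading the model you must call `register_initial_embeddings` again with a compatible table, i.e. vectors of the dimensionality the adapter was trained on, covering the tokens you plan to encode.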
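Because the lookup table has to be supplied again after reloading, a quick check that the replacement table matches what the adapter expects can catch mistakes early. The helper `check_compatible` below is hypothetical; it assumes "compatible" means every required token is present and all vectors have the expected dimensionality, using the same `token` / `embedding` layout as `gene_df`.

```python
import numpy as np
import pandas as pd


def check_compatible(df, required_tokens, expected_dim):
    """Raise ValueError if `df` lacks tokens or has vectors of the wrong length (hypothetical helper)."""
    missing = set(required_tokens) - set(df["token"])
    if missing:
        raise ValueError(f"missing tokens: {sorted(missing)}")
    dims = {len(v) for v in df["embedding"]}
    if dims != {expected_dim}:
        raise ValueError(f"expected dimension {expected_dim}, found {sorted(dims)}")


table = pd.DataFrame({"token": ["g1", "g2"], "embedding": [np.zeros(8), np.ones(8)]})
check_compatible(table, ["g1", "g2"], expected_dim=8)  # passes silently
```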

## 3  Random‑initialised embeddings (baseline)

In [35]:
enc_rand = MMContextEncoder("prajjwal1/bert-tiny", adapter_hidden_dim=32)
enc_rand.random_initial_embeddings(list(gene_df["token"]))
pref_ds3 = enc_rand.prefix_ds(raw_ds, col_id="cell_sentence_2")

st_rand = SentenceTransformer(modules=[enc_rand])
print(st_rand.encode(pref_ds3[0]["cell_sentence_2"]))


[-4.48181853e-02  4.88068089e-02 -7.71359950e-02 -1.94791742e-02
 -1.65305734e-01  2.07038909e-01 -9.51697305e-02  1.00460630e-02
 -2.05945764e-02 -8.27556103e-02  3.99987102e-02  2.23212019e-01
  4.08632506e-04  8.41276348e-02  9.08150151e-02  1.14676468e-01
 -3.14155594e-02 -1.24669306e-01  1.00069672e-01  5.08623570e-02
  1.08218171e-01  2.77469680e-02 -2.09165335e-01  2.63675805e-02
  2.91837787e-05  2.79316418e-02 -1.44141257e-01  1.72094017e-01
  2.09506765e-01  1.42657250e-01  2.66109496e-01  1.25485763e-01
  7.32748508e-02 -1.65309429e-01  4.83123846e-02 -1.13332473e-01
  7.18119442e-02 -2.47363463e-01  2.18929891e-02 -1.42093882e-01
 -1.46550447e-01 -7.67515227e-02  1.83114987e-02  3.88945173e-03
  4.82691042e-02 -3.01351190e-01  5.10653257e-02  3.41373175e-01
 -1.61226124e-01  3.45270857e-02 -2.84832209e-01 -1.15363762e-01
  3.26400856e-03 -1.97847039e-01  1.37544721e-01 -1.67341352e-01
 -1.19984224e-01  1.34880215e-01 -8.69596004e-02 -9.23934951e-02
 -1.14667505e-01 …




Random vectors let you benchmark how much pre‑computed representations help compared with an uninformed baseline (same dimension, same adapters).
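A baseline table with the same layout as `gene_df` can be produced from the token list alone. The notebook does not say which distribution `random_initial_embeddings` draws from; the sketch below assumes a seeded standard normal and a caller‑chosen dimension, and the function `random_embedding_table` is hypothetical. Matching the dimension of the pre‑computed embeddings keeps the adapter input size identical, so the comparison isolates the effect of the vectors themselves.

```python
import numpy as np
import pandas as pd


def random_embedding_table(tokens, dim, seed=0):
    """Hypothetical random baseline: one N(0, 1) vector per token, in gene_df's layout."""
    rng = np.random.default_rng(seed)
    matrix = rng.standard_normal((len(tokens), dim)).astype(np.float32)
    return pd.DataFrame({"token": list(tokens), "embedding": list(matrix)})


baseline = random_embedding_table([f"g{i}" for i in range(1, 11)], dim=8, seed=42)
print(baseline.shape)                       # (10, 2)
print(baseline["embedding"].iloc[0].shape)  # (8,)
```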

## 4  What’s next?
* **Training** → use `SentenceTransformerTrainer` with `pref_ds`. Give the model a pair dataset (`label` = 1/0) and a suitable loss, e.g. `CosineSimilarityLoss`.
* **Saving / loading** → `st_rand.save(path)`, then `SentenceTransformer(path)`. Numeric lookup tables are *not* stored, so re‑register them before inference.
* **Hub upload** → after training, `.push_to_hub()` works as for any other Sentence‑Transformers model.
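For orientation, `CosineSimilarityLoss` in Sentence‑Transformers regresses the cosine similarity of a pair's two embeddings onto its label, using mean squared error by default. The NumPy sketch below computes that objective for a toy batch; the vectors and 1/0 labels are placeholders, and in real training the embeddings would come from the model.

```python
import numpy as np


def cosine_similarity_loss(u, v, labels):
    """Mean squared error between row-wise cosine similarities and pair labels."""
    u = u / np.linalg.norm(u, axis=1, keepdims=True)
    v = v / np.linalg.norm(v, axis=1, keepdims=True)
    cos = (u * v).sum(axis=1)
    return float(np.mean((cos - labels) ** 2))


u = np.array([[1.0, 0.0], [1.0, 0.0]])
v = np.array([[1.0, 0.0], [0.0, 1.0]])
labels = np.array([1.0, 0.0])  # first pair matches, second does not
print(cosine_similarity_loss(u, v, labels))  # 0.0
```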

A dedicated training notebook will cover these steps in detail.