# MMContextEncoder — Training Walkthrough
This notebook demonstrates how to fine‑tune the multimodal `MMContextEncoder` in four flavours:
1. **Text‑only**
2. **Pre‑computed numeric embeddings**
   - 2a. Feature‑level tokens
   - 2b. Sample‑level tokens
3. **Random‑initialised baseline**

A single epoch and small batch sizes keep the runtime short – scale both up for real work.

## 0  Setup & toy data

In [None]:
import os

import datasets
import numpy as np
import pandas as pd
import torch
from datasets import DatasetDict
from sentence_transformers import SentenceTransformer

from mmcontext.models.MMContextEncoder import MMContextEncoder
from mmcontext.simulator import OmicsCaptionSimulator

# Build a deliberately tiny simulated dataset (120 samples, 12 genes) so the demo stays fast.
sim = OmicsCaptionSimulator(n_samples=120, n_genes=12).simulate()
# gene_df / sample_df are the numeric tables registered by the feature- and sample-level cells.
gene_df, sample_df = sim.get_numeric_datasets()
raw_ds = sim.get_hf_dataset()


def _to_caption_pairs(split_ds):
    """Rename ``captions`` to ``cell_sentence_1`` and append an identical ``cell_sentence_2`` column.

    Duplicating the caption gives every row a second text column for the paired loss.
    """
    renamed = split_ds.rename_column("captions", "cell_sentence_1")
    return renamed.add_column("cell_sentence_2", renamed["cell_sentence_1"])


# 80/20 split with a fixed seed; the library's "test" split is exposed here as "val".
_holdout = raw_ds.train_test_split(test_size=0.2, seed=42)
ds = DatasetDict(
    {
        "train": _to_caption_pairs(_holdout["train"]),
        "val": _to_caption_pairs(_holdout["test"]),
    }
)
ds

## 1  Text‑only

In [None]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses

# Text-only run. adapter_hidden_dim=None presumably means no numeric adapter is built;
# confirm in MMContextEncoder.
enc = MMContextEncoder("prajjwal1/bert-tiny", adapter_hidden_dim=None)
st = SentenceTransformer(modules=[enc])
train, val = (ds[name] for name in ("train", "val"))
out_dir = "./models/demo_text_only"

batch_size = 8  # kept tiny for a quick demo; raise for real work
args = SentenceTransformerTrainingArguments(
    output_dir=out_dir,
    num_train_epochs=1,
    per_device_train_batch_size=batch_size,
    per_device_eval_batch_size=batch_size,
)
# NOTE(review): ContrastiveLoss scores (anchor, other) pairs against a 0/1 label column.
# Confirm the simulated dataset provides one. Otherwise, a label-free loss such as
# MultipleNegativesRankingLoss is the usual choice for (anchor, positive) pairs.
loss = losses.ContrastiveLoss(model=st)
trainer = SentenceTransformerTrainer(
    model=st,
    args=args,
    train_dataset=train,
    eval_dataset=val,
    loss=loss,
    # Option of the custom ST fork; with the pip release, use a 'pixel_values' column instead.
    extra_feature_keys=["cell_sentence_2"],
)
trainer.train()

## 2a  Feature‑level tokens

In [None]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses

# Feature-level run: register the gene-level table as initial embeddings, then prefix
# the `omics_tokens` column of both splits.
enc = MMContextEncoder("prajjwal1/bert-tiny", adapter_hidden_dim=32, adapter_output_dim=64)
enc.register_initial_embeddings(gene_df, data_origin="geneformer")
pref = enc.prefix_ds(ds, col_id="omics_tokens")
train = pref["train"]
val = pref["val"]
# Wrap the encoder only after its embeddings are registered.
st = SentenceTransformer(modules=[enc])
out_dir = "./models/demo_feat_tokens"

train_kwargs = dict(
    output_dir=out_dir,
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
)
args = SentenceTransformerTrainingArguments(**train_kwargs)
loss = losses.ContrastiveLoss(model=st)
trainer = SentenceTransformerTrainer(
    model=st,
    args=args,
    train_dataset=train,
    eval_dataset=val,
    loss=loss,
    # Option of the custom ST fork; with the pip release, use a 'pixel_values' column instead.
    extra_feature_keys=["cell_sentence_2"],
)
trainer.train()

## 2b  Sample‑level tokens

In [None]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses

# Sample-level run: register `sample_df` (origin "pca") and prefix the
# `omics_tokens_lvl2` column of both splits.
enc = MMContextEncoder(
    "prajjwal1/bert-tiny",
    adapter_hidden_dim=32,
    adapter_output_dim=64,
)
enc.register_initial_embeddings(sample_df, data_origin="pca")
pref = enc.prefix_ds(ds, col_id="omics_tokens_lvl2")
train, val = pref["train"], pref["val"]
st = SentenceTransformer(modules=[enc])
out_dir = "./models/demo_sample_tokens"

args = SentenceTransformerTrainingArguments(
    output_dir=out_dir,
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
)
loss = losses.ContrastiveLoss(model=st)
trainer_kwargs = {
    "model": st,
    "args": args,
    "train_dataset": train,
    "eval_dataset": val,
    "loss": loss,
    # Option of the custom ST fork; with the pip release, use a 'pixel_values' column instead.
    "extra_feature_keys": ["cell_sentence_2"],
}
trainer = SentenceTransformerTrainer(**trainer_kwargs)
trainer.train()

## 3  Random‑initialised baseline

In [None]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments, losses

# Random baseline: uses the same gene-token vocabulary (`gene_df["token"]`), the same
# prefixed `omics_tokens` column and the same adapter shape as the feature-level run.
# Only the initial embeddings differ (presumably randomly initialised, per the method name).
# Matching adapter_output_dim=64 keeps the two runs architecturally comparable.
enc = MMContextEncoder("prajjwal1/bert-tiny", adapter_hidden_dim=32, adapter_output_dim=64)
enc.random_initial_embeddings(list(gene_df["token"]))
pref = enc.prefix_ds(ds, col_id="omics_tokens")
train, val = pref["train"], pref["val"]
st = SentenceTransformer(modules=[enc])
out_dir = "./models/demo_random"

loss = losses.ContrastiveLoss(model=st)
args = SentenceTransformerTrainingArguments(
    output_dir=out_dir,
    num_train_epochs=1,
    per_device_train_batch_size=8,
    per_device_eval_batch_size=8,
)
trainer = SentenceTransformerTrainer(
    model=st,
    args=args,
    train_dataset=train,
    eval_dataset=val,
    loss=loss,
    extra_feature_keys=["cell_sentence_2"],  # rename to 'pixel_values' for pip release
)
trainer.train()

### Further notes
* Swap in other losses or multi‑positive datasets – see the [Sentence‑Transformers training docs](https://www.sbert.net/docs/training/overview.html).
* The current Sentence‑Transformers release only recognises a second column named **`pixel_values`**. If you are not using my SentenceTransformer fork, rename `cell_sentence_2` (or whichever dataset column you want to use) to `pixel_values` accordingly.