This tutorial shows how to evaluate an MMContext Sentence Transformers model on a single dataset. It assumes you have already created a Hugging Face dataset that contains the cell representations (either cell IDs for numerical embeddings, or cell sentences for `text_only` usage). Such datasets can be created with the pipeline in the https://github.com/mengerj/adata_hf_datasets repository. If you want to start from an AnnData object instead, see the tutorial `pretrained_inference.ipynb`.

Figure 1D in the publication was created with this notebook.

In [93]:
import pandas as pd
from datasets import load_dataset

# --- Evaluation configuration -------------------------------------------------
# Hugging Face Hub namespace and dataset to evaluate on.
repo_name = "jo-mengr"
dataset_name = "hiha_100k"
# Split of the dataset that is evaluated.
split_name = "test"
# adata.obs column holding the ground-truth labels (alternative: "AIFI_L2").
label_key = "AIFI_L1"

In [94]:
# Load the dataset from the Hub, then keep only the split used for evaluation.
dataset = load_dataset(path=f"{repo_name}/{dataset_name}")
test_dataset = dataset[split_name]

In [95]:
from sentence_transformers import SentenceTransformer

# Pretrained MMContext model; trust_remote_code allows code shipped with the model repo to run.
model_name = "jo-mengr/mmcontext-pubmedbert-gs10k"
model = SentenceTransformer(model_name_or_path=model_name, trust_remote_code=True)

# Data type the omics encoder was registered for, and the key of the matching
# initial embeddings (read when registering them and reused as a plotting key below).
data_type = "gs10k"
layer_key = "X_" + data_type
# Set to True for text-only models (skips registering initial embeddings).
text_only = False
# Column fed to the model: "cell_sentence_1" here, "cell_sentence_2" for text-based models.
primary_cell_sentence = "cell_sentence_1"

You are trying to use a model that was created with Sentence Transformers version 5.1.2, but you're currently using version 5.0.0. This might cause unexpected behavior or errors. In that case, try to update to the latest version.
Loaded encoder was registered for 'gs10k' data. Call register_initial_embeddings() with compatible data before using it.


In [92]:
import os

from mmcontext.file_utils import load_test_adata_from_hf_dataset, subset_dataset_by_chunk

# Download the AnnData object referenced by the test split. The Zenodo token comes from
# the ZENODO_TOKEN environment variable (None when unset).
adata, local_path = load_test_adata_from_hf_dataset(
    test_dataset,
    zenodo_token=os.environ.get("ZENODO_TOKEN"),
    save_dir=f"../data/test_adata/{dataset_name}",
)
# Align AnnData and dataset rows (presumably restricts both to the downloaded chunk — verify).
adata, dataset_sub = subset_dataset_by_chunk(adata, test_dataset)

Processing:   0%|          | 0/1 [00:00<?, ?file/s]

  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return self.callback(read_func, elem.name, elem, iospec=iospec)
  obj.co_lnotab,  # for < python 3.10 [not counted in args]


Filter:   0%|          | 0/92874 [00:00<?, ? examples/s]

In [7]:
if not text_only:
    # Omics models need the initial cell embeddings: read them from the AnnData file linked
    # in the dataset and register them with the encoder under the model's data type.
    # NOTE(review): axis="obs" presumably means one embedding row per cell — confirm.
    token_df, _ = model[0].get_initial_embeddings_from_adata_link(
        dataset_sub,
        axis="obs",
        layer_key=layer_key,
        download_dir=f"../data/test_adata/{dataset_name}",
    )
    model[0].register_initial_embeddings(token_df, data_origin=data_type)

Processing:   0%|          | 0/1 [00:00<?, ?file/s]

  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return self.callback(read_func, elem.name, elem, iospec=iospec)
  return self.callback(read_func, elem.name, elem, iospec=iospec)


Use the returned DataFrame to register the embeddings with `register_initial_embeddings()`.


In [8]:
from mmcontext.utils import truncate_cell_sentences

# Prepare the column the model will encode. Both branches yield a single Dataset (not a
# DatasetDict), so the cells below can index rows and columns on it directly.
if text_only:
    # Text-only models read the gene-name cell sentences. `dataset_sub` is already a single
    # split, so it is passed as-is; indexing it with `split_name` would look up a column.
    # NOTE(review): filter_strings presumably drops genes with these prefixes
    # (ribosomal RPS/RPL, mitochondrial MT) before truncating — confirm.
    dataset_ready = truncate_cell_sentences(
        dataset_sub, primary_cell_sentence, max_length=64, filter_strings=["RPS", "RPL", "MT"]
    )
else:
    # Omics models: prefix_ds rewrites the primary column into per-sample tokens such as
    # "sample_idx:<id>" (see the example row below), presumably resolved against the
    # registered initial embeddings.
    dataset_ready = model[0].prefix_ds(dataset_sub, primary_cell_sentence)

  obj.co_lnotab,  # for < python 3.10 [not counted in args]


In [9]:
dataset_to_use = dataset_ready  # a single Dataset (see its repr below); no [split_name] indexing needed

In [10]:
# Inspect the first prepared example.
dataset_to_use[0]

{'sample_idx': 'cfe6b50ea2e611eb91855e3d154f7c3f',
 'cell_sentence_1': 'sample_idx:cfe6b50ea2e611eb91855e3d154f7c3f',
 'cell_sentence_2': 'B2M ACTB RPS27 RPL32 RPS4X RPL30 RPS12 RPS8 S100A4 RPL34 RPS3 IL32 RPS26 RPS27A TMSB4X RPS3A TMSB10 SH3BGRL3 VIM MT-CO1 TRBC2 RPS6 RPS29 RPS13 S100A10 GAPDH RPL21 FTH1 MT-CYB TRAC RPS18 OAZ1 LTB MT-CO2 CD99 LDHB LGALS1 CD3D C12orf75 LGALS3 S100A11 CRIP1 ANXA2 FTL MT-ATP6 S100A6 HLA-DRB5 HLA-DRB1 DOK2 MYC GLIPR2 TXN HSPA5 OPTN PTPN6 HCST PIK3IP1 MT-ND3 ATP6V0B JUN GBP2 CD2 TKT TRAT1 ZFYVE28 CD74 RPL26L1 PRELID1 SOD2 CXCR3 AL162231.1 TALDO1 ITGB1 RTKN2 CCND2 TUBA1A PTGER2 TRADD HS3ST3B1 CD7 TGFB1 APOBEC3G TNFRSF18 TNFRSF4 TNFRSF1B FAM76A CSF3R UTP11 CDKN2C PGM1 PLEKHO1 C1orf56 S100A12 ISG20L2 TAGLN2 IGSF8 CD247 SELL PTPN7 FCMR ADI1 AC073195.1 UBXN2A WDR92 NAGK GCA CWC22 SPATS2L TUBA4A ARL4C PER2 OXNAD1 GOLGA4 DHX30 NPRL2 APPL1 CLDND1 TIGIT KALRN H1FX NUDT16 RPL39L DGKQ LYAR AC093323.1 MRPL1 ANXA5 SAP30 WDR70 GZMA LYRM7 IRF1 PRR7 ZNF879 PAK1IP1 HIST1H1

In [11]:
# Load, on its own, the base text encoder named by the MMContext model's first module.
# NOTE(review): `text_encoder` is not referenced by the cells below.
text_encoder_name = model[0].text_encoder_name
text_encoder = SentenceTransformer(model_name_or_path=text_encoder_name)

In [12]:
# Show the prepared dataset's features and number of rows.
dataset_to_use

Dataset({
    features: ['sample_idx', 'cell_sentence_1', 'cell_sentence_2', 'adata_link'],
    num_rows: 92874
})

In [13]:
# Embed every cell from its primary cell sentence and attach the result to the AnnData object.
# NOTE(review): assumes dataset rows are in the same order as adata.obs.
omics_embeddings = model.encode(sentences=dataset_to_use[primary_cell_sentence])
adata.obsm["mmcontext_emb"] = omics_embeddings

In [14]:
# Number of distinct ground-truth labels, used to size the colour palette. Reads the
# configured `label_key` column so the palette follows the notebook configuration.
n_colours = len(adata.obs[label_key].unique())

In [15]:
import seaborn as sns

# One tab10 colour per ground-truth label.
auto_colors = sns.color_palette(palette="tab10", n_colors=n_colours)

In [16]:
# Display the generated palette.
auto_colors

In [17]:
# Fixed label -> palette-index assignment so each cell type keeps the same colour
# in every plot produced by the evaluator.
label_colors = {
    cell_type: auto_colors[palette_idx]
    for cell_type, palette_idx in (
        ("T cell", 7),
        ("B cell", 0),
        ("NK cell", 1),
        ("Monocyte", 2),
        ("DC", 3),
        ("Platelet", 4),
        ("Progenitor cell", 5),
        ("ILC", 6),
        ("Erythrocyte", 8),
    )
}

In [68]:
from mmcontext.eval import get

# Fetch the LabelSimilarity evaluator class by name and configure it for the full dataset.
EvClass = get("LabelSimilarity")
ev = EvClass(
    # label handling
    auto_filter_labels=False,
    label_colors=label_colors,
    # scoring
    similarity="cosine",
    logit_scale=1,
    score_norm_method=None,
    # UMAP layout
    umap_n_neighbors=10,
    umap_min_dist=0.4,
    # typography
    font_family="Arial",
    annotation_fontsize=16,
)

In [39]:
# Full-dataset cell embeddings, reused by the evaluation and plotting cells below.
full_omics_embeddings = adata.obsm["mmcontext_emb"]

In [20]:
# full_cell_umap = ev._compute_umap(full_omics_embeddings)
# add umap coordinates to adata
# adata.obsm["cell_umap"] = full_cell_umap

In [60]:
# Ground-truth label per cell, the distinct labels used as text queries, and their embeddings.
full_true_labels = adata.obs[label_key]
full_query_labels = full_true_labels.unique().tolist()
full_label_embeddings = model.encode(full_query_labels)

In [65]:
from pathlib import Path

# Compute the label-similarity metrics (per-label AUC/accuracy plus aggregates) on all cells.
result = ev.compute(
    omics_embeddings=full_omics_embeddings,
    label_embeddings=full_label_embeddings,
    query_labels=full_query_labels,
    true_labels=full_true_labels,
    label_key=label_key,
    # output directory, also used for caching
    out_dir=Path("LabelSimilarity") / model_name / dataset_name,
)

  return (emb1_norm @ emb2_norm.T).astype(np.float32)
  return (emb1_norm @ emb2_norm.T).astype(np.float32)
  return (emb1_norm @ emb2_norm.T).astype(np.float32)


In [62]:
# Per-label AUC/accuracy plus aggregate metrics for the full dataset.
result

B cell/auc: 0.9997
B cell/accuracy: 0.9732
DC/auc: 0.9840
DC/accuracy: 0.7206
Erythrocyte/auc: 0.9055
Erythrocyte/accuracy: 0.6000
ILC/auc: 0.9656
ILC/accuracy: 0.5000
Monocyte/auc: 0.9996
Monocyte/accuracy: 0.9991
NK cell/auc: 0.9985
NK cell/accuracy: 0.9996
Platelet/auc: 1.0000
Platelet/accuracy: 1.0000
Progenitor cell/auc: 0.9976
Progenitor cell/accuracy: 0.9595
T cell/auc: 0.9896
T cell/accuracy: 0.7971
mean_auc: 0.9822
std_auc: 0.0292
accuracy: 0.8627
balanced_accuracy: 0.8388
random_baseline_accuracy: 0.1111
accuracy_over_random: 7.7640
n_labels: 9

In [69]:
# Render the full-dataset label-similarity figures as SVG files.
ev.plot(
    omics_embeddings=full_omics_embeddings,
    label_embeddings=full_label_embeddings,
    query_labels=full_query_labels,
    true_labels=full_true_labels,
    label_key=label_key,  # column name (e.g. "celltype")
    out_dir=Path("LabelSimilarity") / model_name / dataset_name / f"{label_key}_combined",
    # output format
    save_format="svg",
    figsize=(2.5, 2.5),
    dpi=300,
    # fonts
    font_size=12,
    font_style="normal",
    font_weight="normal",
    axis_label_size=20,
    axis_tick_size=12,
    # points and legend
    point_size=0.25,
    legend_layout="vertical",
    legend_fontsize=54,
    legend_point_size=16,
    # UMAP layout and label placement
    umap_method="combined",
    label_min_distance=0.2,
    label_spring_strength=0.5,
    label_repulsion_strength=1,
)

  return (emb1_norm @ emb2_norm.T).astype(np.float32)
  return (emb1_norm @ emb2_norm.T).astype(np.float32)
  return (emb1_norm @ emb2_norm.T).astype(np.float32)
  warn(
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()
  plt.tight_layout()


In [70]:
# Evaluate a subset of cells against a finer annotation level: keep the cells whose
# `subset_label_key` value is in `subset_label_values`, then score them against the
# labels found in `annotation_label_key`.
subset_label_values = ["T cell"]  # one or more values, e.g. ["Monocyte", "DC"]
subset_label_key = "AIFI_L1"  # column used to select the subset
annotation_label_key = "AIFI_L2"  # column providing the ground truth within the subset
subset_label_string = "_".join(subset_label_values)  # used in output directory names

adata_subset = adata[adata.obs[subset_label_key].isin(subset_label_values)].copy()
subset_omics_embeddings = adata_subset.obsm["mmcontext_emb"]
subset_true_labels = adata_subset.obs[annotation_label_key]
# Series.unique().tolist() works for categorical and object columns alike
# (`.values` may be a plain NumPy array, which has no `.unique()` method) and yields a
# plain list of label strings, the same form used for the full-dataset query labels.
subset_labels = subset_true_labels.unique().tolist()
label_embeddings_subset = model.encode(subset_labels)

# Dedicated evaluator for the subset, so the full-dataset evaluator `ev` (which carries the
# L1 label_colors) is not overwritten. No label_colors here: the subset labels differ.
ev_subset = EvClass(
    auto_filter_labels=False,
    umap_n_neighbors=10,
    umap_min_dist=0.4,
    similarity="cosine",
    logit_scale=1,
    score_norm_method=None,
    font_family="Arial",
    annotation_fontsize=18,
)

# Compute metrics on the subsetted data
result_subset = ev_subset.compute(
    omics_embeddings=subset_omics_embeddings,
    label_embeddings=label_embeddings_subset,
    query_labels=subset_labels,
    true_labels=subset_true_labels,
    label_key=annotation_label_key,
    out_dir=Path(
        f"LabelSimilarity/{model_name}/{dataset_name}/{annotation_label_key}_subset_{subset_label_string}/results"
    ),
)

# Plot results for the subset
ev_subset.plot(
    omics_embeddings=subset_omics_embeddings,
    out_dir=Path(
        f"LabelSimilarity/{model_name}/{dataset_name}/{annotation_label_key}_subset_{subset_label_string}_combined"
    ),
    label_embeddings=label_embeddings_subset,
    query_labels=subset_labels,
    true_labels=subset_true_labels,
    label_key=annotation_label_key,
    save_format="svg",
    figsize=(2.5, 2.5),
    dpi=300,
    font_size=12,
    axis_tick_size=12,
    font_style="normal",
    font_weight="normal",
    axis_label_size=20,
    point_size=0.25,
    legend_layout="vertical",
    legend_point_size=20,
    umap_method="combined",
    label_min_distance=0.15,
    label_spring_strength=0.35,
    label_repulsion_strength=1.4,
)

  return (emb1_norm @ emb2_norm.T).astype(np.float32)
  return (emb1_norm @ emb2_norm.T).astype(np.float32)
  return (emb1_norm @ emb2_norm.T).astype(np.float32)
  return (emb1_norm @ emb2_norm.T).astype(np.float32)
  return (emb1_norm @ emb2_norm.T).astype(np.float32)
  return (emb1_norm @ emb2_norm.T).astype(np.float32)
  warn(


In [69]:
# Metrics for the subset evaluation.
result_subset

CD8aa/auc: 0.7231
CD8aa/accuracy: 0.0000
DN T cell/auc: 0.4027
DN T cell/accuracy: 0.0000
MAIT/auc: 0.9151
MAIT/accuracy: 0.0000
Memory CD4 T cell/auc: 0.8933
Memory CD4 T cell/accuracy: 0.8253
Memory CD8 T cell/auc: 0.9080
Memory CD8 T cell/accuracy: 0.8940
Naive CD4 T cell/auc: 0.8710
Naive CD4 T cell/accuracy: 0.8027
Naive CD8 T cell/auc: 0.9124
Naive CD8 T cell/accuracy: 0.9197
Proliferating T cell/auc: 0.9876
Proliferating T cell/accuracy: 0.4286
Treg/auc: 0.8976
Treg/accuracy: 0.5911
gdT/auc: 0.8419
gdT/accuracy: 0.0004
mean_auc: 0.8353
std_auc: 0.1578
accuracy: 0.7605
balanced_accuracy: 0.4462
random_baseline_accuracy: 0.1000
accuracy_over_random: 7.6051
n_labels: 10

In [None]:
# Visualise the embeddings
from mmcontext.pl import plot_umap
from mmcontext.utils import consolidate_low_frequency_categories

current_key = label_key
adata_cut = consolidate_low_frequency_categories(adata, [current_key], threshold=50, remove=True)
emb_key = "mmcontext_emb"
plot_umap(
    adata,
    color_key=label_key,
    embedding_key=emb_key,
    save_format="svg",
    save_dir=f"figs/{model_name}/{dataset_name}",
    save_plot=False,
    title="",
)

In [None]:
# Visualise the embeddings
from mmcontext.pl import plot_umap
from mmcontext.utils import consolidate_low_frequency_categories

current_key = label_key
adata_cut = consolidate_low_frequency_categories(adata, [current_key], threshold=1, remove=False)
emb_key = layer_key
plot_umap(
    adata_cut,
    color_key=label_key,
    embedding_key=emb_key,
    save_format="svg",
    nametag="",
    save_dir=f"figs/{model_name}/{dataset_name}",
    save_plot=False,
    title="",
)

In [None]:
from mmcontext.eval.query_annotate import OmicsQueryAnnotator

# Annotate each cell against the full set of ground-truth labels using the stored embeddings.
# NOTE(review): presumably writes adata.obs["best_label"], which the accuracy cell reads.
annotator = OmicsQueryAnnotator(model)
annotator.annotate_omics_data(adata, full_query_labels, emb_key="mmcontext_emb")

  similarity_matrix = data_emb @ label_emb.T
  similarity_matrix = data_emb @ label_emb.T
  similarity_matrix = data_emb @ label_emb.T


In [163]:
# get accuracy of best label vs true label
from sklearn.metrics import accuracy_score

accuracy_score(adata.obs["best_label"], adata.obs[label_key])

0.42632695122315484

In [159]:
# Query the cells with free-text cell-type definitions and plot the per-query scores on a
# UMAP next to the AIFI_L2 labels.
# NOTE: gated on a specific dataset name; with dataset_name = "hiha_100k" this cell is skipped.
if dataset_name == "human_immune_health_atlas_50k_single_no_caption":
    from mmcontext.eval.query_annotate import OmicsQueryAnnotator
    from mmcontext.pl.plotting import plot_query_scores_with_labels_umap

    # Presumably one query per row: "Cell Type" names the label, "Definition" is the query text.
    df = pd.read_csv("../../data/queries/additional_combined.csv")
    labels = df["Cell Type"]
    definitions = df["Definition"]

    annotator = OmicsQueryAnnotator(model)
    annotator.query_with_text(adata, definitions, emb_key="mmcontext_emb")
    # Call the plotting function
    plot_query_scores_with_labels_umap(
        adata=adata,
        queries=definitions,
        labels=labels,
        label_key="AIFI_L2",
        save_dir=f"figs/{model_name}/{dataset_name}/umap_with_labels",
        nametag="",
        figsize=(4, 4),
        point_size=2,
        dpi=300,  # same resolution as the other saved figures in this notebook
        axis_label_size=18,
        axis_tick_size=18,
    )

  similarity_matrix = query_emb @ data_emb.T
  similarity_matrix = query_emb @ data_emb.T
  similarity_matrix = query_emb @ data_emb.T
  warn(
