This tutorial shows how to evaluate an MMContext Sentence Transformers model on a single dataset. It assumes you have already created a Hugging Face dataset that contains the cell representations (either cell IDs for numerical embeddings, or cell sentences for text_only usage). Such datasets can be created with the pipeline available in the https://github.com/mengerj/adata_hf_datasets repository. If you would rather start from an AnnData object, see the tutorial pretrained_inference.ipynb.

Figure 1D in the publication was created with this notebook.

In [None]:
import pandas as pd
from datasets import load_dataset

# Dataset identifier: the dataset is loaded below as "<repo_name>/<dataset_name>".
repo_name = "jo-mengr"
dataset_name = "hiha_100k"
# Split of the loaded dataset that is evaluated.
split_name = "test"
# Column of adata.obs that holds the ground-truth labels used for evaluation.
label_key = "AIFI_L1"  # "AIFI_L2"

In [None]:
# Load the dataset (all of its splits) and keep a handle on the split under evaluation.
hub_dataset_id = f"{repo_name}/{dataset_name}"
dataset = load_dataset(hub_dataset_id)
test_dataset = dataset[split_name]

In [None]:
from sentence_transformers import SentenceTransformer

# Evaluation settings; none of them depends on the loaded model.
text_only = False
data_type = "gs10k"
layer_key = f"X_{data_type}"
primary_cell_sentence = "cell_sentence_1"  # set to cell_sentence_2 for text based models

# Load the MMContext model. trust_remote_code=True allows the custom code shipped
# with the model repository to be executed.
model_name = "jo-mengr/mmcontext-pubmedbert-gs10k"
model = SentenceTransformer(model_name, trust_remote_code=True)

In [None]:
import os

from mmcontext.file_utils import load_test_adata_from_hf_dataset, subset_dataset_by_chunk

# Folder passed to the loader as save_dir.
adata_save_dir = f"../data/test_adata/{dataset_name}"
adata, local_path = load_test_adata_from_hf_dataset(
    test_dataset,
    save_dir=adata_save_dir,
    zenodo_token=os.environ.get("ZENODO_TOKEN"),  # None when the variable is unset
)
# Rebinds adata and produces dataset_sub, the dataset object used by all later cells.
# NOTE(review): presumably both are restricted to the same chunk of cells - confirm in mmcontext.file_utils.
adata, dataset_sub = subset_dataset_by_chunk(adata, test_dataset)

In [None]:
# Only models that are not text-only get initial embeddings registered on their first module.
if not text_only:
    omics_encoder = model[0]
    token_df, _ = omics_encoder.get_initial_embeddings_from_adata_link(
        dataset_sub,
        layer_key=layer_key,
        download_dir=f"../data/test_adata/{dataset_name}",
        axis="obs",
    )
    omics_encoder.register_initial_embeddings(token_df, data_origin=data_type)

In [None]:
from datasets import DatasetDict

from mmcontext.utils import truncate_cell_sentences

# enc.register_initial_embeddings(token_df, data_origin="geneformer")
if text_only:
    # Text-only models: pass the split through truncate_cell_sentences and re-wrap it
    # in a DatasetDict keyed by the split name.
    # NOTE(review): presumably this caps each sentence at 64 entries and drops those
    # matching the filter strings - confirm in mmcontext.utils.
    dataset_split = truncate_cell_sentences(
        dataset_sub[split_name], primary_cell_sentence, max_length=64, filter_strings=["RPS", "RPL", "MT"]
    )
    dataset_ready = DatasetDict({split_name: dataset_split})
else:
    # Other models: let the first module of the model prepare the primary column.
    dataset_ready = model[0].prefix_ds(dataset_sub, primary_cell_sentence)

In [None]:
# Dataset that is encoded below; the split indexing is left commented out, so the
# prepared object is used as-is.
# NOTE(review): in the text_only branch dataset_ready is a DatasetDict keyed by split_name,
# while the later dataset_to_use[0] and dataset_to_use[primary_cell_sentence] lookups look
# like they expect a single split - confirm before running with text_only=True.
dataset_to_use = dataset_ready  # [split_name]

In [None]:
# Show the element at index 0 as a sanity check (the first record, assuming a single-split dataset).
dataset_to_use[0]

In [None]:
# Load the text encoder named by the model's first module as a stand-alone SentenceTransformer.
# NOTE(review): text_encoder is not referenced by any other cell visible here - confirm it is still needed.
text_encoder_name = model[0].text_encoder_name
text_encoder = SentenceTransformer(text_encoder_name)

In [None]:
# Display the notebook representation of the dataset that will be encoded.
dataset_to_use

In [None]:
# Encode the primary cell-sentence column and store the embeddings on the AnnData
# object under obsm["mmcontext_emb"], where the later cells read them back.
cell_inputs = dataset_to_use[primary_cell_sentence]
omics_embeddings = model.encode(cell_inputs)
adata.obsm["mmcontext_emb"] = omics_embeddings

In [None]:
# Number of distinct AIFI_L1 labels; used as the palette size in the next cell.
# NOTE(review): the column is hard-coded instead of using label_key, presumably because
# the label_colors mapping below is specific to AIFI_L1 names - confirm.
n_colours = len(adata.obs["AIFI_L1"].unique())

In [None]:
import seaborn as sns

# Take n_colours colours from seaborn's "tab10" palette.
# NOTE(review): label_colors below indexes this palette up to position 8, so this
# presumably relies on n_colours being at least 9 - confirm.
auto_colors = sns.color_palette("tab10", n_colours)

In [None]:
# Display the palette.
auto_colors

In [None]:
# Fixed palette position for each cell-type label.
palette_index_by_label = {
    "T cell": 7,
    "B cell": 0,
    "NK cell": 1,
    "Monocyte": 2,
    "DC": 3,
    "Platelet": 4,
    "Progenitor cell": 5,
    "ILC": 6,
    "Erythrocyte": 8,
}
# Resolve every position to its colour in auto_colors.
label_colors = {label: auto_colors[idx] for label, idx in palette_index_by_label.items()}

In [None]:
from mmcontext.eval import get

# Fetch the "LabelSimilarity" evaluator class and build the instance used for the full dataset.
EvClass = get("LabelSimilarity")
label_similarity_options = {
    "auto_filter_labels": False,
    "umap_n_neighbors": 10,
    "umap_min_dist": 0.4,
    "similarity": "cosine",
    "logit_scale": 1,
    "score_norm_method": None,
    "label_colors": label_colors,
    "annotation_fontsize": 16,
    "font_family": "Arial",
}
ev = EvClass(**label_similarity_options)

In [None]:
# Embeddings of all cells, read back from adata.obsm; passed to ev.compute / ev.plot below.
# (No UMAP is computed here - the UMAP precomputation in the next cell is commented out.)
full_omics_embeddings = adata.obsm["mmcontext_emb"]

In [None]:
# full_cell_umap = ev._compute_umap(full_omics_embeddings)
# add umap coordinates to adata
# adata.obsm["cell_umap"] = full_cell_umap

In [None]:
# Ground-truth label of every cell, the distinct label names (used as text queries),
# and the model embeddings of those names.
full_true_labels = adata.obs[label_key]
full_query_labels = full_true_labels.unique().tolist()
full_label_embeddings = model.encode(full_query_labels)

In [None]:
from pathlib import Path

# Output directory handed to the evaluator (used for caching).
label_similarity_dir = Path(f"LabelSimilarity/{model_name}/{dataset_name}")
result = ev.compute(
    omics_embeddings=full_omics_embeddings,
    label_embeddings=full_label_embeddings,
    query_labels=full_query_labels,
    true_labels=full_true_labels,
    label_key=label_key,
    out_dir=label_similarity_dir,
)

In [None]:
# Display the value returned by ev.compute for the full dataset.
result

In [None]:
# Where the figure for the full dataset is written.
combined_plot_dir = Path(f"LabelSimilarity/{model_name}/{dataset_name}/{label_key}_combined")

# Data passed to the plot.
plot_inputs = dict(
    omics_embeddings=full_omics_embeddings,
    # cell_umap=full_cell_umap,
    label_embeddings=full_label_embeddings,
    query_labels=full_query_labels,
    true_labels=full_true_labels,
    label_key=label_key,  # column name (e.g. "celltype")
)

# Figure appearance and label-layout settings.
plot_style = dict(
    save_format="svg",
    figsize=(2.5, 2.5),
    dpi=300,
    font_size=12,
    font_style="normal",
    font_weight="normal",
    legend_fontsize=54,
    axis_label_size=20,
    axis_tick_size=12,
    point_size=0.25,
    legend_layout="vertical",
    legend_point_size=16,
    umap_method="combined",
    label_min_distance=0.2,
    label_spring_strength=0.5,
    label_repulsion_strength=1,
)

ev.plot(out_dir=combined_plot_dir, **plot_inputs, **plot_style)

In [None]:
# Option to subset adata based on one or more label values (e.g., "Monocyte" and "DC")
subset_label_values = ["T cell"]  # Change this list to your desired label values
subset_label_key = "AIFI_L1"
annotation_label_key = "AIFI_L2"
subset_label_string = "_".join(subset_label_values)

# Subset the AnnData object for any of the specified label values
adata_subset = adata[adata.obs[subset_label_key].isin(subset_label_values)].copy()
subset_true_labels = adata_subset.obs[annotation_label_key]
# Take the distinct labels from the Series itself and convert them to a plain list,
# the same way the full-dataset cell builds its query labels. Going through `.values`
# yields a NumPy array for non-categorical columns, and ndarray has no `.unique()`.
subset_labels = subset_true_labels.unique().tolist()
label_embeddings_subset = model.encode(subset_labels)
subset_omics_embeddings = adata_subset.obsm["mmcontext_emb"]
# subset_umap_coords = adata_subset.obsm["cell_umap"]

# Create a new LabelSimilarity evaluator instance. This rebinds `ev`; unlike the
# evaluator used for the full dataset it is built without label_colors.
# ev_subset = EvClass(auto_filter_labels=False, umap_n_neighbors=15, umap_min_dist=0.5)
# ev_subset.eb_lfdr_q = 0.01
ev = EvClass(
    auto_filter_labels=False,
    umap_n_neighbors=10,
    umap_min_dist=0.4,
    similarity="cosine",
    logit_scale=1,
    score_norm_method=None,
    font_family="Arial",
    annotation_fontsize=18,
)
# Compute metrics on the subsetted data
result_subset = ev.compute(
    omics_embeddings=subset_omics_embeddings,
    label_embeddings=label_embeddings_subset,
    query_labels=subset_labels,
    true_labels=subset_true_labels,
    label_key=annotation_label_key,
    out_dir=Path(
        f"LabelSimilarity/{model_name}/{dataset_name}/{annotation_label_key}_subset_{subset_label_string}/results"
    ),
)

# Plot results for the subset
ev.plot(
    omics_embeddings=subset_omics_embeddings,
    #    cell_umap=subset_umap_coords,
    out_dir=Path(
        f"LabelSimilarity/{model_name}/{dataset_name}/{annotation_label_key}_subset_{subset_label_string}_combined"
    ),
    label_embeddings=label_embeddings_subset,
    query_labels=subset_labels,
    true_labels=subset_true_labels,
    label_key=annotation_label_key,
    save_format="svg",
    figsize=(2.5, 2.5),
    dpi=300,
    font_size=12,
    axis_tick_size=12,
    font_style="normal",
    font_weight="normal",
    axis_label_size=20,
    point_size=0.25,
    legend_layout="vertical",
    legend_point_size=20,
    umap_method="combined",
    label_min_distance=0.15,
    label_spring_strength=0.35,
    label_repulsion_strength=1.4,
)

In [None]:
# Display the value returned by ev.compute for the subset.
result_subset

In [None]:
# Visualise the embeddings
from mmcontext.pl import plot_umap
from mmcontext.utils import consolidate_low_frequency_categories

current_key = label_key
# Build a filtered copy for plotting.
# NOTE(review): with remove=True and threshold=50 this presumably drops cells whose
# label occurs fewer than 50 times - confirm in mmcontext.utils.
adata_cut = consolidate_low_frequency_categories(adata, [current_key], threshold=50, remove=True)
emb_key = "mmcontext_emb"
# Plot the filtered object (as the next cell does); plotting the unfiltered `adata`
# would leave the consolidation step above without any effect on the figure.
plot_umap(
    adata_cut,
    color_key=label_key,
    embedding_key=emb_key,
    save_format="svg",
    save_dir=f"figs/{model_name}/{dataset_name}",
    save_plot=False,
    title="",
)

In [None]:
# Visualise the embeddings
from mmcontext.pl import plot_umap
from mmcontext.utils import consolidate_low_frequency_categories

current_key = label_key
# This cell plots the representation named by layer_key instead of "mmcontext_emb".
emb_key = layer_key
adata_cut = consolidate_low_frequency_categories(adata, [current_key], threshold=1, remove=False)
umap_options = dict(
    color_key=label_key,
    embedding_key=emb_key,
    save_format="svg",
    nametag="",
    save_dir=f"figs/{model_name}/{dataset_name}",
    save_plot=False,
    title="",
)
plot_umap(adata_cut, **umap_options)

In [None]:
from mmcontext.eval.query_annotate import OmicsQueryAnnotator

# Annotate the cells with the label names as queries, using the stored "mmcontext_emb" embeddings.
# NOTE(review): the next cell reads adata.obs["best_label"], which is presumably written
# by annotate_omics_data - confirm in mmcontext.eval.query_annotate.
annotator = OmicsQueryAnnotator(model)
annotator.annotate_omics_data(adata, full_query_labels, emb_key="mmcontext_emb")

In [None]:
# get accuracy of best label vs true label
from sklearn.metrics import accuracy_score

# scikit-learn's signature is accuracy_score(y_true, y_pred): pass the ground-truth
# column first and the predicted label second. The score is the same either way,
# but the call now matches the documented argument order.
accuracy_score(adata.obs[label_key], adata.obs["best_label"])

In [None]:
queries_csv = "../../data/queries/additional_combined.csv"
# The free-text query step runs only for the hiha_100k dataset and only when the CSV exists.
if dataset_name == "hiha_100k" and os.path.exists(queries_csv):
    from mmcontext.eval.query_annotate import OmicsQueryAnnotator
    from mmcontext.pl.plotting import plot_query_scores_with_labels_umap

    # One row per query: the cell-type name and its free-text definition.
    df = pd.read_csv(queries_csv)
    labels, Definition = df["Cell Type"], df["Definition"]

    # Query the stored cell embeddings with the definitions.
    annotator = OmicsQueryAnnotator(model)
    annotator.query_with_text(adata, Definition, emb_key="mmcontext_emb")

    # Call the plotting function
    query_plot_options = dict(
        save_dir=f"figs/{model_name}/{dataset_name}/umap_with_labels",
        nametag="",
        figsize=(4, 4),
        point_size=2,
        dpi=300,  # Lower DPI for faster generation
        axis_label_size=18,
        axis_tick_size=18,
    )
    plot_query_scores_with_labels_umap(
        adata=adata,
        queries=Definition,
        labels=labels,
        label_key="AIFI_L2",
        **query_plot_options,
    )