## Train a new mmcontext model (demonstrated on proteomics data)


In [None]:
from mmcontext.utils import setup_logging

# Call mmcontext's logging setup helper once, before any other cell runs.
setup_logging()

In [None]:
import anndata as ad

from mmcontext.file_utils import download_file_from_share_link

# A protein dataset from the scPerturb collection on figshare:
# https://plus.figshare.com/articles/dataset/scPerturb_Single-Cell_Perturbation_Data_RNA_and_protein_h5ad_files/24160713?file=42428325
data_link = "https://plus.figshare.com/ndownloader/files/42428325"
local_path = "Frangiehlzar2021_protein.h5ad"
# Download the file behind data_link and store it at local_path
download_file_from_share_link(share_link=data_link, save_path=local_path)
# Load the downloaded .h5ad file into an AnnData object
adata = ad.read_h5ad(local_path)

Of course, one can create a more sophisticated description; this is just an example.

In [None]:
# Build the per-cell description text from one row of adata.obs
def make_description(row):
    """Return a short free-text description of a cell built from its metadata.

    ``row`` must support lookup by the keys ``perturbation``, ``perturbation_2``,
    ``tissue_type``, ``cancer``, ``disease`` and ``celltype``; a missing key
    raises the lookup error of the underlying container. The values are
    interpolated as-is into four consecutive sentences.
    """
    perturbation_part = (
        f"The first perturbation is {row['perturbation']} "
        f"and the second perturbation is {row['perturbation_2']}."
    )
    tissue_part = f" The tissue is {row['tissue_type']} and it has cancer yes or no: {row['cancer']}."
    disease_part = f" The disease is {row['disease']}."
    celltype_part = f" The celltype is {row['celltype']}."
    # Each later fragment carries its own leading space, so a plain concatenation suffices
    return "".join([perturbation_part, tissue_part, disease_part, celltype_part])


# Add a "description" column to adata.obs by applying make_description to every row (axis=1)
adata.obs["description"] = adata.obs.apply(make_description, axis=1)
# Also copy the obs index into a "sample_idx" column; it is passed as the sentence key
# when the datasets are constructed further down
adata.obs["sample_idx"] = adata.obs.index

In [None]:
import numpy as np

# Split the cells randomly into train (~80%) and val (~20%).
# A seeded generator is used so that the same cells end up in the same split
# every time the notebook is run; change SPLIT_SEED to draw a different split.
SPLIT_SEED = 42
rng = np.random.default_rng(SPLIT_SEED)
# Boolean mask with one entry per cell: True -> train, False -> val
adata.obs["split"] = rng.random(len(adata)) < 0.8
adata_train = adata[adata.obs["split"]].copy()
adata_val = adata[~adata.obs["split"]].copy()

In [None]:
# Show the (n_obs, n_vars) shape of the validation split
adata_val.shape

In [None]:
# Show the (n_obs, n_vars) shape of the training split
adata_train.shape

In [None]:
import scanpy as sc

# Apply the same preprocessing to each split independently (train first, then val):
# total-count normalisation followed by a log1p transform, both modifying the object in place.
for split_adata in (adata_train, adata_val):
    sc.pp.normalize_total(split_adata, inplace=True)
    sc.pp.log1p(split_adata)

In [None]:
# Per the original note, these datasets contain only 24 proteins, so the preprocessed
# expression matrix itself is stored under .obsm["X_prot"] and used as the embedding.
processed_paths = {
    "train": "Frangiehlzar2021_protein_pp_train.h5ad",
    "val": "Frangiehlzar2021_protein_pp_val.h5ad",
}
for split_key, split_adata in (("train", adata_train), ("val", adata_val)):
    split_adata.obsm["X_prot"] = split_adata.X
    # Persist the split; the same path is later passed to the dataset constructor as adata_link
    split_adata.write_h5ad(processed_paths[split_key])

In [None]:
from adata_hf_datasets import AnnDataSetConstructor
from datasets import DatasetDict

# Build one "multiplets"-format dataset (with descriptions as captions) per split.
# A sentence key names the adata.obs column used to represent a sample. For numeric data
# the sample index is used; the embeddings are later registered with the model and linked
# to these indices.
split_inputs = [
    ("train", adata_train, processed_paths["train"]),
    ("val", adata_val, processed_paths["val"]),
]

ds_dict = DatasetDict()
for split_name, adata_split, split_path in split_inputs:
    # A fresh constructor per split, so each split yields its own dataset
    constructor = AnnDataSetConstructor(resolve_negatives=True, dataset_format="multiplets")
    constructor.add_anndata(
        adata_split,
        adata_link=split_path,
        caption_key="description",
        sentence_keys=["sample_idx"],
        # In this case all cells are from the same batch, but providing a batch key
        # can improve batch integration by negative sampling
        batch_key="library_preparation_protocol",
    )
    ds = constructor.get_dataset()
    ds_dict[split_name] = ds

In [None]:
# `ds` still holds the dataset from the last loop iteration of the previous cell
# (the "val" split); show its adata_link column
ds["adata_link"]

## Configure the Model
The custom sentence transformers model allows training with the sentence transformers Trainer. The numeric data has to be registered with the model, so that the tokens in cell_sentence_1 (cell tokens, e.g. sample indices) are linked to their respective numeric vectors, which serve as the initial representation of each sample.
This is achieved by building a lookup table (cell token -> id) and a frozen embedding layer (id -> numeric vector).

In [None]:
from sentence_transformers import SentenceTransformer

from mmcontext.mmcontextencoder import MMContextEncoder

# Encoder settings, collected in one place and unpacked into the constructor below.
# NOTE(review): freeze_text_encoder=True is combined with unfreeze_last_n_layers=2 —
# presumably everything except the last two text-encoder layers stays frozen; confirm
# against the MMContextEncoder documentation.
encoder_config = {
    "text_encoder_name": "sentence-transformers/all-MiniLM-L6-v2",
    "adapter_hidden_dim": 128,
    "adapter_output_dim": 64,
    "freeze_text_encoder": True,
    "unfreeze_last_n_layers": 2,
    "output_token_embeddings": False,
    "train_lookup": False,
    "joint_adapter_hidden_dim": None,
    "text_model_kwargs": None,
    "use_text_adapter": True,
}
enc = MMContextEncoder(**encoder_config)

# Wrap the encoder as the only module of a SentenceTransformer model
model = SentenceTransformer(modules=[enc])

In [None]:
# Fetch the initial embeddings for the samples in ds_dict from the "X_prot" layer
# (written to .obsm of each split earlier). Only the first return value is used here.
# NOTE(review): download_dir presumably caches the files referenced by the dataset's
# adata_link column — confirm against the mmcontext documentation.
token_df, _ = model[0].get_initial_embeddings_from_adata_link(
    ds_dict,
    layer_key="X_prot",
    download_dir="data_cache",
    axis="obs",  # since we get embeddings from adata.obsm. We could also use "varm" and for example use an embedding for each protein
)
# Register the fetched embeddings with the encoder under the data origin "prot"
model[0].register_initial_embeddings(token_df, data_origin="prot")

In [None]:
# The model expects a certain prefix on the cell tokens; display the prefix
# configured on the encoder's processor.
model[0].processor.prefix

In [None]:
# You could add the prefix manually or use the helper below.
# The result of prefix_ds is captured: the original call discarded it, so if the helper
# returns a new dataset rather than modifying ds_dict in place (NOTE(review): confirm
# against the mmcontext documentation), the prefixed tokens would never reach the
# datasets used for training. The None check keeps an in-place implementation working.
prefixed_ds = model[0].prefix_ds(ds_dict, columns_to_prefix=["cell_sentence_1"])
if prefixed_ds is not None:
    ds_dict = prefixed_ds
ds_dict

In [None]:
# Final step: rename the main column to "anchor" and drop the columns that are not
# used for training.
# This is admittedly a bit cumbersome, but the setup allows flexible training:
# cell- or feature-level tokens, text-based cell sentences or numeric embeddings,
# and negatives resolved to whichever column was chosen for training.
# The same dataset can therefore be reused across training runs and only
# modified differently, at the price of some setup work.
columns_to_drop = ["sample_idx", "adata_link", "negative_1_idx"]
ds_final = ds_dict.rename_column("cell_sentence_1", "anchor").remove_columns(columns_to_drop)
ds_final

In [None]:
from sentence_transformers import SentenceTransformerTrainer, SentenceTransformerTrainingArguments

# Training configuration, grouped by concern (keyword order does not matter).
args = SentenceTransformerTrainingArguments(
    run_name="protein_test",
    # Duration and batch sizes
    num_train_epochs=1,
    per_device_train_batch_size=128,
    per_device_eval_batch_size=128,
    # Optimisation
    learning_rate=1e-5,
    warmup_ratio=0.1,
    max_grad_norm=1.0,
    # Both mixed-precision modes are switched off
    fp16=False,
    bf16=False,
    # Step-based evaluation and checkpointing, every 100 steps each
    eval_strategy="steps",
    eval_steps=100,
    save_strategy="steps",
    save_steps=100,
    save_total_limit=1,
    # Logging interval in steps
    logging_steps=10,
)

In [None]:
from sentence_transformers.evaluation import TripletEvaluator
from sentence_transformers.losses import MultipleNegativesRankingLoss

# Training objective, constructed around the model
loss = MultipleNegativesRankingLoss(model)

# Evaluator over (anchor, positive, negative_1) triplets taken from the validation split
val_split = ds_final["val"]
evaluator = TripletEvaluator(
    anchors=val_split["anchor"],
    positives=val_split["positive"],
    negatives=val_split["negative_1"],
    name="val",
)

In [None]:
# Assemble the trainer from the pieces defined in the previous cells: the model,
# the training arguments, the train/val splits of ds_final, the loss and the evaluator.
trainer = SentenceTransformerTrainer(
    model=model,
    args=args,
    train_dataset=ds_final["train"],
    eval_dataset=ds_final["val"],
    loss=loss,
    evaluator=evaluator,
)

In [None]:
# Start the training run with the configuration assembled above
trainer.train()