In [None]:
# start coding here
import anndata
import pandas as pd
import torch
from cellwhisperer.jointemb.cellwhisperer_lightning import (
    TranscriptomeTextDualEncoderLightning,
)
from cellwhisperer.config import get_path, model_path_from_name
from cellwhisperer.utils.inference import (
    score_transcriptomes_vs_texts,
    rank_terms_by_score,
    prepare_terms,
)

In [None]:
import cellwhisperer

In [None]:
from typing import Tuple
from cellwhisperer.jointemb.processing import TranscriptomeTextDualEncoderProcessor
from cellwhisperer.jointemb.cellwhisperer_lightning import (
    TranscriptomeTextDualEncoderLightning,
)
from pathlib import Path
from cellwhisperer.utils.model_io import load_cellwhisperer_model

In [None]:
# Load the AnnData object provided by the Snakemake rule; the rows of its `.X`
# are later passed to the model as per-cell transcriptome embeddings.
adata = anndata.read_h5ad(filename=snakemake.input.adata)

In [None]:
# Load the trained model together with its tokenizer and transcriptome
# processor. The checkpoint location is supplied by the Snakemake rule.
# TODO: replace model path
modelpath = snakemake.input.model
(
    pl_model,
    tokenizer,
    transcriptome_processor,
) = load_cellwhisperer_model(modelpath)

In [None]:
# Positional row number (0 .. n_obs-1) of every cell; later cells use it to
# select rows of `adata.X` by position for each cluster.
adata.obs["index_int"] = list(range(adata.obs.shape[0]))

In [None]:
# Mean embedding of the cells in each Leiden cluster (one entry per cluster).
# `observed=True` restricts the grouping to clusters that actually contain
# cells: if "leiden" is categorical with unused categories (e.g. after
# subsetting), the default would also produce empty groups whose mean is an
# all-NaN vector. Selecting the "index_int" column first avoids applying the
# function to the grouping column itself (deprecated in recent pandas).
# NOTE(review): assumes `adata.X` is a dense (n_obs, dim) array of per-cell
# embeddings — confirm; a sparse matrix would yield np.matrix rows instead.
grouped_embeddings = adata.obs.groupby("leiden", observed=True)["index_int"].apply(
    lambda row_positions: adata.X[row_positions.to_numpy()].mean(axis=0)
)

In [None]:
import numpy as np

# Stack the per-cluster mean embeddings into a single tensor (one row per
# cluster) and move it to the model's device.
# NOTE(review): `mean_embeddings` is not referenced by the later cells shown here.
mean_embeddings = torch.from_numpy(
    np.stack(grouped_embeddings.to_numpy())
).to(device=pl_model.device)

In [None]:
# Score every term against the cells of each Leiden cluster and collect the
# per-cluster term rankings into one table with a "leiden" column.
# Loop-invariant inputs (the term table read from disk and the model's logit
# scale) are prepared once instead of once per cluster.
terms_df = prepare_terms(snakemake.input.terms)
logit_scale = pl_model.model.discriminator.temperature.exp()

dfs = []
# `observed=True`: skip categories with no cells, which would otherwise send an
# empty (0-row) embedding tensor to the model.
for leiden, group in adata.obs.groupby("leiden", observed=True):
    group_embeds = torch.from_numpy(adata.X[group["index_int"].to_numpy()]).to(
        pl_model.device
    )

    scores, _ = score_transcriptomes_vs_texts(
        transcriptome_input=group_embeds,
        text_list_or_text_embeds=terms_df["term"].to_list(),
        logit_scale=logit_scale,
        model=pl_model.model,
        transcriptome_processor=transcriptome_processor,
        average_mode="embeddings",
        score_norm_method="zscore",
    )
    # Hand rank_terms_by_score a fresh copy of the terms each time, as the
    # original per-cluster prepare_terms call did, so any in-place modification
    # it might make cannot leak between clusters.
    # `.assign` tags the cluster without writing into a frame we don't own.
    dfs.append(rank_terms_by_score(scores, terms_df.copy()).assign(leiden=leiden))
similarity_scores_df = pd.concat(dfs)

In [None]:
# Write the per-cluster term rankings (pandas index included, the to_csv
# default) to the output path declared by the Snakemake rule.
similarity_scores_df.to_csv(path_or_buf=snakemake.output.csv)