In [None]:
# Score every cell against a free-text search term with a CellWhisperer model, then compare the
# per-condition score distributions (ECDF plot, violin plot, two-sample KS test).
# Paths, params and wildcards come from the `snakemake` object, which is not defined in this
# notebook (presumably injected by Snakemake's `notebook:` directive).
import numpy as np
# NOTE(review): TranscriptomeTextDualEncoderLightning, pearsonr and tqdm appear unused in this notebook.
from cellwhisperer.jointemb.cellwhisperer_lightning import TranscriptomeTextDualEncoderLightning
from cellwhisperer.utils.model_io import load_cellwhisperer_model
from cellwhisperer.utils.inference import score_transcriptomes_vs_texts
import pandas as pd
from scipy.stats import pearsonr, ks_2samp
from tqdm.auto import tqdm
import torch
import anndata
import matplotlib
import matplotlib.pyplot as plt
import seaborn as sns

# Project-wide plotting style supplied as a pipeline input.
matplotlib.style.use(snakemake.input.mpl_style)


In [None]:
# Load the cell dataset and the CellWhisperer model together with its tokenizer and transcriptome processor.
dataset_path, model_path = snakemake.input.dataset, snakemake.input.model
data = anndata.read_h5ad(dataset_path)
pl_model, tokenizer, transcriptome_processor = load_cellwhisperer_model(model_path)

In [None]:
# Optionally restrict the analysis to one cluster; an empty wildcard means "keep all cells".
target_cluster = snakemake.wildcards.target_cluster
if target_cluster != "":
    # AnnData boolean slicing returns a view; .copy() materializes the subset.
    data = data[data.obs.cluster_label == target_cluster].copy()
    # Fail here on a typo'd/absent cluster instead of producing empty plots and a KS-test error later.
    if data.n_obs == 0:
        raise ValueError(
            f"No cells with cluster_label == {target_cluster!r} in {snakemake.input.dataset}"
        )

In [None]:
# Score each cell's precomputed transcriptome embedding against the embedding of the search term.
# np.asarray strips any ndarray subclass (e.g. an AnnData array view) before handing it to torch.
transcriptome_embeds = torch.from_numpy(np.asarray(data.obsm["transcriptome_embeds"]))

# Pure inference: don't build an autograd graph for the text encoder or the scoring.
with torch.no_grad():
    text_embeds = pl_model.model.embed_texts([snakemake.params.search_term])

    scores, _ = score_transcriptomes_vs_texts(
        transcriptome_input=transcriptome_embeds,
        text_list_or_text_embeds=text_embeds,
        logit_scale=pl_model.model.discriminator.temperature.exp().item(),
        average_mode=None,
        batch_size=64,
        score_norm_method=None,
        grouping_keys=None,
    )

# NOTE(review): assumes `scores` is laid out (n_texts, n_cells), so row 0 holds the scores for the
# single search term — confirm against score_transcriptomes_vs_texts.
cell_scores = scores[0].detach().cpu().numpy()
if cell_scores.shape[0] != data.n_obs:
    raise ValueError(
        f"Expected one score per cell ({data.n_obs}), got {cell_scores.shape[0]}; check the score layout."
    )
# Index by cell name so the alignment with `data.obs` is explicit.
series = pd.Series(cell_scores, index=data.obs_names)


In [None]:
# Private (deep) copy of the per-cell metadata, so added columns do not modify `data`.
obs = data.obs.copy(deep=True)

In [None]:
# Attach the per-cell scores positionally (plain array, so no index alignment is involved).
obs = obs.assign(stem_cell_response=series.to_numpy())

In [None]:
# Number of cells per condition.
obs["condition"].value_counts()

In [None]:
import seaborn as sns

# Fixed colour per condition.
palette = {
    "inflamed": "orange",
    "healthy": "SkyBlue",
    "non-inflamed": "LightGreen",
}

# Empirical CDF of the score, one curve per condition.
fig, ax = plt.subplots(figsize=(1.4, 1.3))
sns.ecdfplot(data=obs, x="stem_cell_response", hue="condition", ax=ax, palette=palette)

In [None]:
# Violin plot of scores for the two compared conditions; this figure is written to the pipeline output.
compared_conditions = ["inflamed", "non-inflamed"]
fig, ax = plt.subplots(figsize=(3, 1))
sns.violinplot(
    data=obs[obs["condition"].isin(compared_conditions)],
    x="stem_cell_response",
    y="condition",
    order=compared_conditions,
    color="gray",
    legend=False,
    ax=ax,
)

fig.savefig(snakemake.output.plot)

In [None]:
# Rows of the two conditions that are compared in the violin plot and the KS test.
obs.loc[obs["condition"].isin(["inflamed", "non-inflamed"])]

In [None]:
from scipy.stats import ks_2samp

# Two-sample Kolmogorov–Smirnov test: are the scores of inflamed and non-inflamed cells
# drawn from the same distribution?
inflamed_scores = obs.loc[obs["condition"] == "inflamed", "stem_cell_response"]
non_inflamed_scores = obs.loc[obs["condition"] == "non-inflamed", "stem_cell_response"]

# ks_2samp rejects empty samples; name the missing condition (e.g. after cluster filtering) instead.
for condition_label, condition_scores in [
    ("inflamed", inflamed_scores),
    ("non-inflamed", non_inflamed_scores),
]:
    if condition_scores.empty:
        raise ValueError(f"No cells with condition == {condition_label!r}; cannot run the KS test.")

stat, pval = ks_2samp(inflamed_scores, non_inflamed_scores)

print(
    f"KS two-sample test (inflamed n={len(inflamed_scores)}, "
    f"non-inflamed n={len(non_inflamed_scores)}): statistic={stat:.4g}, p-value={pval:.4g}"
)