In [None]:
from pathlib import Path

import anndata
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns
import torch

from cellwhisperer.utils.inference import score_transcriptomes_vs_texts

In [None]:
# Load the Snakemake inputs: the query-variant table, the precomputed text
# embeddings (forced onto the CPU) and the cell/transcriptome embeddings.
query_variants = pd.read_csv(snakemake.input.query_variants, index_col=0)
text_embeddings = torch.load(
    snakemake.input.text_embeddings, map_location=torch.device("cpu")
)

# np.load on an .npz archive returns an NpzFile that keeps the file handle open;
# use it as a context manager so the handle is released as soon as the array has
# been read into memory (torch.from_numpy shares that in-memory buffer).
with np.load(snakemake.input.dataset) as cell_dataset:
    cell_embeddings = torch.from_numpy(cell_dataset["transcriptome_embeds"])

# The correlation matrix built later is labelled with `query_variants.index`, so
# there must be exactly one text embedding per query-variant row. Fail early with
# a clear message instead of a shape error deep in the plotting cells.
assert len(query_variants) == text_embeddings.shape[0], (
    f"{len(query_variants)} query variants but "
    f"{text_embeddings.shape[0]} text embeddings"
)

In [None]:
# Dimensions of the cell (transcriptome) embedding tensor; `.size()` returns the same torch.Size as `.shape`.
cell_embeddings.size()

In [None]:
# Shape of the loaded text embeddings (expected: one row per query variant — TODO confirm against query_variants).
text_embeddings.shape

In [None]:
# Score every query-variant text embedding against every cell embedding with
# CellWhisperer. No grouping, no averaging and no score normalization is requested.

# Temperature of cellwhisperer_clip_v1, i.e. `model.discriminator.temperature.exp()`.
CLIP_V1_LOGIT_SCALE = 16.7482

scores, grouping_keys = score_transcriptomes_vs_texts(
    transcriptome_input=cell_embeddings,
    text_list_or_text_embeds=text_embeddings,
    logit_scale=CLIP_V1_LOGIT_SCALE,
    average_mode=None,  # alternative: "embeddings", useful together with grouping_keys
    grouping_keys=None,  # alternative: group e.g. by cell type to reduce computation
    score_norm_method=None,  # alternatives: "zscore", "softmax", "01norm"
)

In [None]:
# Quick look at the pairwise Pearson correlation between the rows of `scores`.
torch.corrcoef(scores)

In [None]:
# Distinct original queries, in order of first appearance.
query_variants.index.unique()

In [None]:
# Assign one grey shade per distinct original query (the index of query_variants),
# so clustermap rows can be annotated with the query each variant came from.
unique_values = query_variants.index.drop_duplicates()
palette = sns.color_palette("Greys", len(unique_values))

color_mapping = dict(zip(unique_values, palette))

# Positional list of RGB tuples, one per query_variants row. It is built directly
# rather than via `Index.map`, which turns a mapping onto tuples into a MultiIndex.
# A plain sequence (not a Series keyed by the index) is kept because the index may
# contain duplicates, and pandas cannot reindex on a duplicated axis.
row_colors = [color_mapping[value] for value in query_variants.index]

palette

In [None]:
# Dimensions of the score tensor returned by score_transcriptomes_vs_texts.
scores.size()

In [None]:
# Display the query-variant table; its index (the original query) is used for the row colours and the correlation-matrix labels.
query_variants

In [None]:
# Pairwise Pearson correlation between the rows of `scores` (presumably one row
# per query variant — consistent with labelling both axes by query_variants.index).
# Convert the torch result explicitly to a NumPy array: handing a raw tensor to
# pd.DataFrame relies on version-specific pandas handling of `__array__` objects
# (older pandas iterates it element-wise and can end up with object dtype).
scores_coeff = pd.DataFrame(
    data=scores.corrcoef().detach().cpu().numpy(),
    index=query_variants.index,
    columns=query_variants.index,
)
scores_coeff

In [None]:
# Debug display of the configured output path for the clustermap figure (used by the final cell).
snakemake.output.plot

In [None]:
# Per-variant clustermap of the score correlations, rows annotated by the original query.

clustergrid = sns.clustermap(scores_coeff, row_colors=row_colors, figsize=(6, 6))

# Save through the returned ClusterGrid rather than pyplot's "current figure".
# Derive the SVG path by swapping only the file suffix: a plain str.replace would
# rewrite every ".png" anywhere in the path, and would silently overwrite the main
# plot when the path has no ".png" at all.
plot_path = Path(snakemake.output.plot)
svg_path = plot_path.with_suffix(".svg")
clustergrid.savefig(plot_path)
if svg_path != plot_path:
    clustergrid.savefig(svg_path)