In [None]:
# Imports: numerics/data handling, CellWhisperer model I/O and scoring utilities
import numpy as np
from cellwhisperer.jointemb.cellwhisperer_lightning import TranscriptomeTextDualEncoderLightning
from cellwhisperer.utils.model_io import load_cellwhisperer_model
from cellwhisperer.utils.inference import score_transcriptomes_vs_texts
import pandas as pd
from scipy.stats import pearsonr
from tqdm.auto import tqdm
import torch

In [None]:
# Adapted from the cellwhisperer repository: src/post_clip_processing/notebooks/cellwhisperer_annotate_clusters.py.ipynb

In [None]:
# Load the trained CellWhisperer model together with its text tokenizer and transcriptome processor.
model_checkpoint_path = snakemake.input.model
(
    pl_model,
    tokenizer,
    transcriptome_processor,
) = load_cellwhisperer_model(model_checkpoint_path)

In [None]:
# Load the processed dataset (keyed archive: orig_ids, transcriptome_embeds, ...).
# NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which can execute code;
# only use it on trusted pipeline outputs.
processed = np.load(snakemake.input.processed_dataset, allow_pickle=True)

# GSVA score table: one row per gene-set term, one column per sample, plus a "library" column.
# The term names are presumably stored in the default unnamed index column ("Unnamed: 0").
gsva_raw = pd.read_parquet(snakemake.input.gsva_results)
gsva_results = gsva_raw.set_index("Unnamed: 0")

# Strip GO identifiers like " (GO:0008150)" so the index holds plain term names for text scoring.
term_names = gsva_results.index.str.replace(r'\s*\(GO:\d+\)', '', regex=True)
gsva_results.index = term_names

# Split off the per-term library annotation so only numeric sample columns remain.
gsva_library = gsva_results["library"]
gsva_results = gsva_results.drop(columns="library")

In [None]:
# Every GSVA sample must have a matching embedding in the processed dataset.
gsva_ids_without_embedding = set(gsva_results.columns).difference(set(processed["orig_ids"]))
assert not gsva_ids_without_embedding, "It is expected that the GSVA subset is a subset of the full dataset"

In [None]:
# Select the embedding rows that belong to the GSVA samples, in GSVA column order.

orig_ids = pd.Index(processed["orig_ids"])
# With duplicate IDs a label lookup returns several rows per sample, which would misalign
# embeddings with GSVA columns (only caught much later by a bare shape assert). Fail here instead.
assert orig_ids.is_unique, "orig_ids in the processed dataset must be unique to select embeddings by ID"

# Map each original ID to its row position in the embedding matrix.
mapping = pd.Series(index=orig_ids, data=np.arange(len(orig_ids)))
# .loc forces label-based lookup even if the IDs happen to be integers.
indices = mapping.loc[gsva_results.columns].to_numpy()
assert len(indices) == gsva_results.shape[1]
indices


In [None]:
# Gather the selected embedding rows (fancy indexing yields a new array) and expose them as a torch tensor.
all_embeds = processed["transcriptome_embeds"]
transcriptome_embeds = torch.from_numpy(all_embeds[indices])

In [None]:
# Score every GSVA term (as text) against every selected transcriptome embedding.
# This is pure inference:
#  - eval() switches modules such as dropout to inference behaviour. It is a no-op if the
#    loader already did this, which cannot be verified from here.
#  - no_grad() avoids building an autograd graph while the text encoder embeds all term names.
pl_model.eval()
with torch.no_grad():
    scores, _ = score_transcriptomes_vs_texts(
        transcriptome_input=transcriptome_embeds,
        text_list_or_text_embeds=gsva_results.index.to_list(),
        average_mode=None,  # compute all vs. all
        logit_scale=pl_model.model.discriminator.temperature.exp(),
        model=pl_model.model,
        score_norm_method=None,
    )

In [None]:
# Persist the term x sample score matrix, labelled like the GSVA table it is compared against.
# scores may be a torch tensor (possibly on GPU or tracking gradients); convert it to a plain
# CPU numpy array so pandas can ingest it.
if isinstance(scores, torch.Tensor):
    scores_np = scores.detach().cpu().numpy()
else:
    scores_np = np.asarray(scores)

# Exact tuple comparison also catches a dimensionality mismatch, with a readable message.
assert tuple(scores_np.shape) == gsva_results.shape, (
    f"score matrix shape {tuple(scores_np.shape)} does not match GSVA table shape {gsva_results.shape}"
)
scores_df = pd.DataFrame(data=scores_np, index=gsva_results.index, columns=gsva_results.columns)
scores_df.to_parquet(snakemake.output.cw_transcriptome_term_scores)