In [None]:
# start coding here
import scanpy as sc
import numpy as np
import pickle
import pandas as pd
from tqdm.notebook import tqdm
from cellwhisperer.utils.processing import ensure_raw_counts_adata
import random

from zero_shot_validation_scripts.dataset_preparation import load_and_preprocess_dataset

import openai
import anthropic
import logging

In [None]:
# Connect logging to file snakemake.log.progress (prompts and raw LLM replies are
# logged there by the prediction loop).
# `force=True` removes any root handlers that an earlier import or the kernel may
# already have installed; without it, basicConfig is a silent no-op in that case
# and the progress log stays empty.
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s - %(levelname)s - %(message)s",
    handlers=[logging.FileHandler(snakemake.log.progress)],
    force=True,
)

In [None]:
# Per-job configuration from the Snakemake rule: which obs column holds the labels,
# which dataset to load, and which LLM to query.
metadata_col, dataset_name = (
    snakemake.wildcards.metadata_col,
    snakemake.wildcards.dataset,
)
model = snakemake.params.model  # a name containing "claude" selects the Anthropic SDK

In [None]:
# Choose the SDK from the model name: Anthropic for "claude" models, otherwise the
# OpenAI client pointed at `api_base_url`.
# Both clients get the same retry budget (the SDK default is 2), so a transient
# rate-limit/server error does not abort a long per-instance annotation run.
if "claude" in model.lower():
    client = anthropic.Anthropic(api_key=snakemake.params.api_key, max_retries=5)
else:
    client = openai.OpenAI(
        api_key=snakemake.params.api_key,
        base_url=snakemake.params.api_base_url,
        max_retries=5,
    )

In [None]:
### loading the geneformer normalization factors
# NOTE(review): pickle.load can execute arbitrary code; this must only ever point
# at a trusted pipeline artifact.
with open(snakemake.input.gene_normalizers, "rb") as normalizers_file:
    gene_normalizers = pickle.load(normalizers_file)

# Load the dataset, then check/adjust it with ensure_raw_counts_adata (defined in
# cellwhisperer.utils.processing; presumably guarantees raw counts in .X, confirm).
adata = load_and_preprocess_dataset(
    read_count_table_path=snakemake.input.read_count_table,
    dataset_name=dataset_name,
)
ensure_raw_counts_adata(adata)

In [None]:
# Genes present both in the normalizer table and in adata.var, in the normalizer
# table's key order, together with their gene symbols.
# Only the `gene_name` column decides membership: NaNs in unrelated `var` columns
# must not silently drop genes. Genes without a symbol are still excluded, because
# their symbols are joined into the LLM prompt later.
gene_symbols_by_id = (
    adata.var["gene_name"].reindex(list(gene_normalizers.keys())).dropna()
)
common_genes = gene_symbols_by_id.index
common_gene_symbols = gene_symbols_by_id.values
assert len(common_genes) > 0, "no overlap between gene normalizers and adata.var"

In [None]:
# Align the per-gene normalization factors with `common_genes` (the same order used
# for the gene axis of the subset below) and exponentiate them. Presumably the
# pickle stores log-scale factors; TODO confirm.
# NOTE: this rebinds `gene_normalizers` from a dict to an ndarray, so re-running
# this cell on its own fails.
gene_normalizers = np.exp(
    np.array(list(map(gene_normalizers.__getitem__, common_genes)))
)
gene_normalizers.max()

In [None]:
# Keep only cells whose label is usable (neither missing nor the literal string
# "nan"), restrict the genes to `common_genes`, and materialize the view.
adata_no_nans = adata[
    adata.obs[metadata_col].notna() & (adata.obs[metadata_col] != "nan"),
    common_genes,
].copy()

In [None]:
# compute the normalized gene expressions by [celltype] group and aggregate (code borrowed from `compute_top_genes.py.ipynb`)
# An "instance" is either one label class (params.average_by_class) or a single
# cell (obs index). For each instance: sum counts over its cells, divide by the
# per-gene normalizers, and keep the symbols of the top_n_genes highest genes.
# `observed=True`: with a categorical label column, pandas' default
# (observed=False, deprecated) can also yield empty groups for unobserved
# categories. Their all-zero sums would produce arbitrary "top genes".

top_genes_by_instance = {}
grouper = (
    adata_no_nans.obs.groupby(metadata_col, observed=True)
    if snakemake.params.average_by_class
    else adata_no_nans.obs.groupby(adata_no_nans.obs.index, observed=True)
)
for index, group in grouper:
    X_group = adata_no_nans[group.index].X
    # Sparse matrices are densified; dense arrays have no `.toarray` and pass through.
    try:
        X_group = X_group.toarray()
    except AttributeError:
        pass
    assert len(X_group.shape) == 2
    X_normed = X_group.sum(axis=0) / gene_normalizers
    assert len(X_normed.shape) == 1
    # Descending order of normalized expression, truncated to the requested count.
    top_gene_indices = np.argsort(X_normed)[::-1][: snakemake.params.top_n_genes]

    top_genes_by_instance[index] = common_gene_symbols[top_gene_indices]

In [None]:
# Distinct labels in order of first appearance; these are the answer options
# listed in the LLM prompt.
candidates = list(dict.fromkeys(adata_no_nans.obs[metadata_col].tolist()))

In [None]:
# Query the LLM once per instance. The prompt lists the candidate labels and the
# instance's top marker genes; the raw reply text is stored in `predictions`.
predictions = {}
# Dedicated, seeded RNG so the candidate orderings (and therefore the prompts) are
# reproducible across runs. An optional `seed` rule param overrides the default.
candidate_rng = random.Random(getattr(snakemake.params, "seed", 0))

for instance, top_genes in tqdm(top_genes_by_instance.items()):
    # LLMs are known to have position-preferences in such tasks. We average this out
    # by shuffling before every generation run. A shuffled copy is used so that
    # `candidates` itself is not mutated, which keeps this cell re-runnable.
    shuffled_candidates = candidate_rng.sample(candidates, k=len(candidates))
    prompt = snakemake.params.prompt.format(
        candidates=", ".join(shuffled_candidates), markers=", ".join(top_genes)
    )
    logging.info(f"Prompt for instance {instance}: {prompt}")

    if "claude" in model.lower():
        chat_completion = client.messages.create(
            model=model,
            temperature=0,
            max_tokens=50,
            messages=[{"role": "user", "content": prompt}],
        )
        predictions[instance] = chat_completion.content[0].text
    else:  # OpenAI API
        # NOTE(review): unlike the Claude branch, no temperature/max_tokens is passed
        # here, so the API defaults apply. Confirm this asymmetry is intended.
        chat_completion = client.chat.completions.create(
            messages=[
                {
                    "role": "user",
                    "content": prompt,
                }
            ],
            model=model,
        )
        predictions[instance] = chat_completion.choices[0].message.content
    logging.info(f"Prediction for instance {instance}: {predictions[instance]}")

In [None]:
# Score the replies. A prediction is "valid" if it names one of the candidate labels
# and "correct" if it names the instance's true label. Matching is case-insensitive
# and ignores surrounding whitespace (LLM replies often end with a newline).
# Missing replies (None) compare as invalid/incorrect.
df = pd.Series(predictions, name="predicted_labels").to_frame()
normalized_predictions = df["predicted_labels"].str.strip().str.lower()
df["valid_prediction"] = normalized_predictions.isin(
    [v.strip().lower() for v in candidates]
)
# The true label is the group key itself when averaging by class; otherwise it is
# looked up per cell (aligned on the obs index).
df["label"] = (
    df.index if snakemake.params.average_by_class else adata_no_nans.obs[metadata_col]
)
df["is_correct"] = normalized_predictions == df["label"].str.strip().str.lower()
df["valid_prediction"].value_counts()

In [None]:
# Tally of correct vs. incorrect predictions.
df.is_correct.value_counts()

In [None]:
# Write the per-instance predictions, creating the output directory first.
# Creating it up front (exist_ok=True) avoids the old retry-on-OSError pattern.
# That pattern masked the real error (e.g. a permission problem) with
# FileExistsError, and it crashed on paths without a directory component.
import os

os.makedirs(os.path.dirname(snakemake.output.predictions) or ".", exist_ok=True)
df.to_csv(snakemake.output.predictions)