In [None]:
import json
import os
import re
from getpass import getpass
from pathlib import Path

import sys
sys.path.append("../benchmarks")

import numpy as np
import pandas as pd
import torch
from openai import OpenAI
from torch.nn.functional import softmax

from benchmark import load_benchmark
from src.mlp import MLP

api_url = "https://api.openai.com/v1"
# Prefer the environment variable; otherwise prompt with getpass so the key is
# masked while typing instead of being echoed into (and saved with) the
# notebook's cell output, as plain input() does.
api_key = os.environ.get("OPENAI_API_KEY") or getpass("OPENAI_API_KEY: ")
summary_model = "gpt-4o-2024-11-20"
embedding_model = "text-embedding-3-large"
benchmark_path = Path("../benchmarks/benchmark-difficult")

## Load Benchmark Data

In [None]:
benchmark = load_benchmark(benchmark_path)

# Fail fast if the benchmark table lacks a column that the summary, embedding
# and classification cells below read.
_required_benchmark_cols = {"screen_file", "perturbation", "cell", "organism", "phenotype", "gene", "hit"}
_missing_benchmark_cols = _required_benchmark_cols - set(benchmark.columns)
assert not _missing_benchmark_cols, f"benchmark is missing columns: {sorted(_missing_benchmark_cols)}"

## Load Precomputed Embeddings; Summarize and Embed Cell/Phenotype Terms

In [None]:
def _load_embedding_dict(npy_path):
    """Load a dict that was stored with ``np.save`` as a 0-d object array.

    ``allow_pickle=True`` is needed to read object arrays; only use it on
    trusted, locally produced files.
    """
    return np.load(npy_path, allow_pickle=True).item()


# Per-organism gene embeddings (unsummarized vs. summarized). Inner dicts are
# indexed by genome IDENTIFIER_ID values in the classification cell.
gene_emb_map = dict(
    mouse=_load_embedding_dict("data/embeddings/genes_mouse.npy"),
    human=_load_embedding_dict("data/embeddings/genes_human.npy"),
)
summ_gene_emb_map = dict(
    mouse=_load_embedding_dict("data/embeddings/summarized_genes_mouse.npy"),
    human=_load_embedding_dict("data/embeddings/summarized_genes_human.npy"),
)

# Perturbation-method embeddings, indexed by title-cased method name downstream.
method_emb = _load_embedding_dict("data/embeddings/methods.npy")
summ_method_emb = _load_embedding_dict("data/embeddings/summarized_methods.npy")

In [None]:
# Summarize every distinct cell-line and phenotype term with the chat model.
# Results are cached in data/eval/summaries.npy as {column: {term: summary}}.
# The cache is extended incrementally: only terms not already present are sent
# to the API, so new benchmark screens do not require deleting the cache.
SUMMARIES_CACHE = Path("data/eval/summaries.npy")
summary_prompt_files = {
    "cell": "prompts/summary-cell.json",
    "phenotype": "prompts/summary-phenotype.json",
}


def _fill_summary_prompt(prompt_template, term):
    """Substitute ``term`` into the ``{...}`` placeholder and parse the chat messages.

    Assumes the placeholder sits inside a JSON string literal, on a line with no
    other braces (the greedy pattern spans from the first ``{`` to the last
    ``}`` of a line) -- TODO confirm against prompts/*.json.
    """
    # JSON-escape the term so quotes/backslashes keep the template parseable, and
    # pass a callable so re.sub does not treat backslashes as group references.
    escaped_term = json.dumps(term)[1:-1]
    return json.loads(re.sub(r"\{.*\}", lambda _match: escaped_term, prompt_template))


experiments = benchmark.drop_duplicates("screen_file")

if SUMMARIES_CACHE.exists():
    summaries = np.load(SUMMARIES_CACHE, allow_pickle=True).item()
else:
    summaries = dict()

client = None
for col, prompt_file in summary_prompt_files.items():
    col_summaries = summaries.setdefault(col, dict())
    # Cover every term the classification cell looks up (not just one row per
    # screen_file), and never pay for the same term twice.
    missing_terms = [t for t in benchmark[col].dropna().unique() if t not in col_summaries]
    if not missing_terms:
        continue
    if client is None:
        client = OpenAI(base_url=api_url, api_key=api_key)
    prompt_template = Path(prompt_file).read_text(encoding="utf-8")
    for term in missing_terms:
        completion = client.chat.completions.create(
            messages=_fill_summary_prompt(prompt_template, term),
            model=summary_model,
            seed=42,
            n=1,
            temperature=0,
            max_tokens=2048,
        )
        col_summaries[term] = completion.choices[0].message.content
    # Persist after each column so a failure later does not discard finished work.
    SUMMARIES_CACHE.parent.mkdir(parents=True, exist_ok=True)
    np.save(SUMMARIES_CACHE, summaries)

In [None]:
# Embed every summarized term and its summary with the embedding model.
# Results are cached as {column: {term: np.ndarray}} in the two files below; like
# the summary cache, they are extended incrementally with terms not yet embedded.
TERM_EMB_CACHE = Path("data/eval/term_embeddings.npy")
SUMMARY_EMB_CACHE = Path("data/eval/summary_embeddings.npy")


def _load_dict_cache(path):
    """Return the dict stored at ``path`` via ``np.save``, or an empty dict if absent."""
    return np.load(path, allow_pickle=True).item() if path.exists() else dict()


term_embeddings = _load_dict_cache(TERM_EMB_CACHE)
summary_embeddings = _load_dict_cache(SUMMARY_EMB_CACHE)

client = None
for col, term_summaries in summaries.items():
    col_term_embs = term_embeddings.setdefault(col, dict())
    col_summ_embs = summary_embeddings.setdefault(col, dict())
    missing_terms = [
        term for term in term_summaries
        if term not in col_term_embs or term not in col_summ_embs
    ]
    if not missing_terms:
        continue
    if client is None:
        client = OpenAI(base_url=api_url, api_key=api_key)
    for term in missing_terms:
        out = client.embeddings.create(
            input=[term, term_summaries[term]],
            model=embedding_model,
        )
        # Match results to inputs via the response's `index` field rather than
        # relying on list order.
        by_index = {item.index: np.asarray(item.embedding) for item in out.data}
        col_term_embs[term] = by_index[0]
        col_summ_embs[term] = by_index[1]
    # Persist after each column so a failure later does not discard finished work.
    SUMMARY_EMB_CACHE.parent.mkdir(parents=True, exist_ok=True)
    np.save(SUMMARY_EMB_CACHE, summary_embeddings)
    np.save(TERM_EMB_CACHE, term_embeddings)

## Run Classification Over Embeddings

In [None]:
def _load_protein_coding_genome(tsv_path):
    """Read a genome TSV, keep only PROTEIN_CODING rows, and lower-case OFFICIAL_SYMBOL.

    The index is reset after filtering. Symbols are lower-cased because the
    classification cell compares them against ``gene.lower()``.
    """
    table = pd.read_csv(tsv_path, sep="\t")
    is_coding = table["Gene_Type"] == "PROTEIN_CODING"
    table = table[is_coding].reset_index(drop=True)
    table["OFFICIAL_SYMBOL"] = table["OFFICIAL_SYMBOL"].str.lower()
    return table


human_genome = _load_protein_coding_genome("../genomes/genome_homo_sapiens.tsv")
mouse_genome = _load_protein_coding_genome("../genomes/genome_mus_musculus.tsv")

# Organism name (as used in the benchmark's "organism" column) -> genome table.
genome_map = {"human": human_genome, "mouse": mouse_genome}

In [None]:
# Run directory ids used in the checkpoint paths built by load_model below.
summ_model = "q33613fx"  # classifier trained on summarized embeddings
unsumm_model = "v2zz1ph7"  # classifier trained on unsummarized embeddings

def load_model(use_summarized):
    """Load the last checkpoint of a trained MLP classifier, ready for inference.

    Args:
        use_summarized: if True, load run ``summ_model`` under
            ``runs/classifier-summarized``; otherwise run ``unsumm_model`` under
            ``runs/classifier-unsummarized``.

    Returns:
        An ``MLP`` whose weights come from the checkpoint (mapped to CPU),
        switched to eval mode.
    """
    ckpt_path = (
        f"runs/classifier-summarized/{summ_model}/VirtualCRISPR/{summ_model}/checkpoints/last.ckpt"
        if use_summarized
        else f"runs/classifier-unsummarized/{unsumm_model}/VirtualCRISPR/{unsumm_model}/checkpoints/last.ckpt"
    )
    # NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    # Drop the "classifier.mlp." part of each parameter name (presumably added by
    # the training module wrapping the MLP) so keys match the bare MLP.
    weights = {
        name.replace("classifier.mlp.", ""): value
        for name, value in checkpoint["state_dict"].items()
    }

    # 3072 * 4: four concatenated embeddings (method, cell, phenotype, gene), each
    # presumably 3072-dim -- TODO confirm against the embedding model.
    model = MLP(input_dim=3072*4, reduction_factor=2, n_hidden=4, output_dim=2)
    model.load_state_dict(weights)
    model.eval()
    return model

In [None]:
def _symbol_to_id_map(genome):
    """Return {OFFICIAL_SYMBOL: IDENTIFIER_ID} for a genome table.

    For duplicated symbols the first row wins, which is the same row the former
    per-gene ``.loc[...].iloc[0]`` lookup selected.
    """
    first_rows = genome.drop_duplicates("OFFICIAL_SYMBOL", keep="first")
    return dict(zip(first_rows["OFFICIAL_SYMBOL"], first_rows["IDENTIFIER_ID"]))


# Build the symbol lookup once per organism instead of scanning the whole genome
# table for every benchmark gene.
symbol_to_id = {organism: _symbol_to_id_map(genome) for organism, genome in genome_map.items()}

# Score every benchmark (screen, gene) pair with both classifier variants.
# results[label] = (true hit labels, predicted hit probabilities), row-aligned.
results = dict()
for label, use_summarized in [("Summarized", True), ("Unsummarized", False)]:
    print(f"================= {label} =================")
    cls = load_model(use_summarized)

    # Choose the embedding sources once per variant rather than per gene.
    if use_summarized:
        method_src, text_src, gene_src = summ_method_emb, summary_embeddings, summ_gene_emb_map
    else:
        method_src, text_src, gene_src = method_emb, term_embeddings, gene_emb_map

    preds = []
    trues = []
    groups = benchmark.groupby(["screen_file", "perturbation", "cell", "organism", "phenotype"])
    for (_screen, method, cell, organism, phenotype), hits in groups:
        me = torch.as_tensor(method_src[method.title()], dtype=torch.float32)
        pe = torch.as_tensor(text_src["phenotype"][phenotype], dtype=torch.float32)
        ce = torch.as_tensor(text_src["cell"][cell], dtype=torch.float32)
        organism_ids = symbol_to_id[organism]
        organism_gene_embs = gene_src[organism]

        X = []
        for gene, hit in zip(hits["gene"], hits["hit"]):
            gene_id = organism_ids[gene.lower()]
            ge = torch.as_tensor(organism_gene_embs[gene_id], dtype=torch.float32)
            # Feature order (method, cell, phenotype, gene) is kept as before; it
            # presumably must match the order used in training.
            X.append(torch.concat([me, ce, pe, ge], dim=0))
            trues.append(hit)
        X = torch.stack(X)

        with torch.inference_mode():
            probs = softmax(cls(X), dim=1)
        # Column 1 is treated as the "hit" class probability.
        preds.extend(probs[:, 1].numpy())
    trues = np.asarray(trues)
    preds = np.asarray(preds)
    results[label] = (trues, preds)

In [None]:
from sklearn.metrics import roc_curve, auc, precision_recall_curve, average_precision_score
import matplotlib.pyplot as plt


def _safe_div(num, den):
    """Return ``num / den``, or NaN when ``den`` is 0 (e.g. no predicted positives)."""
    return num / den if den else float("nan")


def _binary_metrics(trues, preds_bin):
    """Confusion-matrix metrics for 0/1 labels ``trues`` and 0/1 predictions ``preds_bin``."""
    tp = int(((trues == 1) & (preds_bin == 1)).sum())
    tn = int(((trues == 0) & (preds_bin == 0)).sum())
    fp = int(((trues == 0) & (preds_bin == 1)).sum())
    fn = int(((trues == 1) & (preds_bin == 0)).sum())
    return {
        "n-binary-pred": tp + tn + fp + fn,
        "f1": _safe_div(2 * tp, 2 * tp + fp + fn),
        "ppv": _safe_div(tp, tp + fp),
        "npv": _safe_div(tn, fn + tn),
        "sensitivity": _safe_div(tp, tp + fn),
        "specificity": _safe_div(tn, fp + tn),
        "fpr": _safe_div(fp, fp + tn),
    }


metrics = dict()
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(10, 5))

for label, (trues, preds) in results.items():
    fpr, tpr, thresholds = roc_curve(trues, preds)
    auroc = auc(fpr, tpr)
    ax1.plot(fpr, tpr, label=f"{label} (AUROC={auroc:0.3f})")

    precision, recall, _ = precision_recall_curve(trues, preds)
    # Trapezoidal area under the PR curve (kept for comparability); it can be
    # optimistic, so average precision ("ap") is reported alongside.
    auprc = auc(recall, precision)
    ax2.plot(recall, precision, label=f"{label} (AUPRC={auprc:0.3f})")

    # Youden's J: the ROC operating point maximizing TPR - FPR.
    optimal_threshold = thresholds[np.argmax(tpr - fpr)]
    # roc_curve counts scores >= threshold as positive; ">" would drop the
    # samples that sit exactly on the chosen threshold.
    preds_bin = (preds >= optimal_threshold).astype(int)

    metrics[label] = {
        **_binary_metrics(trues, preds_bin),
        "auroc": auroc,
        "auprc": auprc,
        "ap": average_precision_score(trues, preds),
    }


# Chance-level references. Prevalence comes from the first result set; every
# variant is scored on the same benchmark groups, so their labels coincide.
reference_trues = np.asarray(next(iter(results.values()))[0])
pr_chance = reference_trues.mean()
ax1.plot([0, 1], [0, 1], label="Random (AUROC=0.500)", linestyle="dashed", color="red", alpha=0.5)
ax2.plot([0, 1], [pr_chance, pr_chance], label=f"Random (AUPRC={pr_chance:0.3f})", linestyle="dashed", color="red", alpha=0.5)

ax1.legend(title="Embeddings")
ax1.set(
    xlabel="1 - Specificity",
    ylabel="Sensitivity",
    title="CRISPR Classifier: ROC",
)

ax2.legend(title="Embeddings")
ax2.set(
    xlabel="Recall",
    ylabel="Precision",
    title="CRISPR Classifier: PR",
)
fig.tight_layout()

In [None]:
# One row per embedding variant, with the headline metrics rendered as LaTeX.
latex_table = (
    pd.DataFrame(metrics)
    .T
    .loc[:, ["auroc", "auprc", "f1", "fpr"]]
    .to_latex(float_format="%.2f")
)
print(latex_table)