In [None]:
import getpass
import json
import os
import re
import sys
from collections import defaultdict
from pathlib import Path

sys.path.append("../benchmarks")

import numpy as np
import pandas as pd
import torch
from openai import OpenAI
from sklearn.metrics.pairwise import cosine_similarity

from benchmark import load_benchmark
from src.mlp import MLP

# --- Configuration ---
# Endpoint and credentials for the OpenAI client.
api_url = "https://api.openai.com/v1"
# Prefer the OPENAI_API_KEY environment variable; otherwise prompt without echo
# so the key is never rendered (and saved) in the notebook's cell output.
api_key = os.environ.get("OPENAI_API_KEY") or getpass.getpass("OPENAI_API_KEY")
# Chat model used to write term summaries, and the model used to embed text.
summary_model = "gpt-4o-2024-11-20"
embedding_model = "text-embedding-3-large"
# Directory handed to load_benchmark().
benchmark_path = Path("../benchmarks/benchmark-difficult")

## Load Benchmark Data

In [None]:
# Screen files to keep from the benchmark; every other screen is dropped.
whitelist = list((
    "SCREEN_18_HITS_ONLY.tsv",
    "SCREEN_18_HITS_ONLY_FOR_INVERSE.tsv",
))

In [None]:
# Load the benchmark table, then keep only rows whose screen file is whitelisted.
# The index is rebuilt so it runs 0..n-1 after filtering.
benchmark = (
    load_benchmark(benchmark_path)
    .loc[lambda frame: frame["screen_file"].isin(whitelist)]
    .reset_index(drop=True)
)

## Generate Embeddings (of Summaries)

In [None]:
def _load_pickled_dict(path):
    """Load a .npy file holding a single pickled object and unwrap it with .item().

    NOTE: allow_pickle=True executes pickle on load; only use with trusted files.
    """
    return np.load(path, allow_pickle=True).item()


# Per-organism gene embeddings (raw and summarized variants).
gene_emb_map = dict(
    mouse=_load_pickled_dict("data/embeddings/genes_mouse.npy"),
    human=_load_pickled_dict("data/embeddings/genes_human.npy"),
)
summ_gene_emb_map = dict(
    mouse=_load_pickled_dict("data/embeddings/summarized_genes_mouse.npy"),
    human=_load_pickled_dict("data/embeddings/summarized_genes_human.npy"),
)
# Perturbation-method embeddings (raw and summarized variants).
method_emb = _load_pickled_dict("data/embeddings/methods.npy")
summ_method_emb = _load_pickled_dict("data/embeddings/summarized_methods.npy")

In [None]:
# One row per screen file; each row supplies a "cell" and a "phenotype" term.
experiments = benchmark.drop_duplicates("screen_file")
if not os.path.exists("data/eval/summaries.npy"):
    client = OpenAI(base_url=api_url, api_key=api_key)
    # Column name -> JSON prompt template used to summarize that column's terms.
    summary_prompt_files = {
        "cell": "prompts/summary-cell.json",
        "phenotype": "prompts/summary-phenotype.json",
    }

    # summaries[col][term] -> summary text returned by the chat model.
    summaries = dict()
    for col, prompt_file in summary_prompt_files.items():
        summaries[col] = dict()
        with open(prompt_file) as f:
            prompt_template = f.read()
        for _idx, experiment in experiments.iterrows():
            term = experiment[col]
            if not isinstance(term, str):
                raise TypeError(
                    f"Expected a string in column {col!r}, got {type(term).__name__}: {term!r}"
                )
            if term in summaries[col]:
                # Several screens can share a term; do not pay for the same summary twice.
                continue
            # The template is JSON text, so the term must be JSON-escaped before it is
            # spliced in (quotes/backslashes would otherwise break json.loads).
            escaped_term = json.dumps(term)[1:-1]
            # A callable replacement stops re.sub from interpreting backslashes or
            # group references inside the inserted text.
            filled_template = re.sub(r"\{.*\}", lambda _match: escaped_term, prompt_template)
            prompt = json.loads(filled_template)
            completion = client.chat.completions.create(
                messages=prompt,
                model=summary_model,
                seed=42,
                n=1,
                temperature=0,
                max_tokens=2048,
            )
            summary = completion.choices[0].message.content
            summaries[col][term] = summary
    # Make sure the cache directory exists so the API results are not lost on save.
    os.makedirs("data/eval", exist_ok=True)
    np.save("data/eval/summaries.npy", summaries)
else:
    # Cached run: reuse the stored summaries (pickled dict; load trusted files only).
    summaries = np.load("data/eval/summaries.npy", allow_pickle=True).item()

In [None]:
# Embed every raw term and its summary once, caching both dictionaries on disk.
# The cache is used only when BOTH files are present; otherwise both are rebuilt.
if (
    os.path.exists("data/eval/summary_embeddings.npy") and
    os.path.exists("data/eval/term_embeddings.npy")
):
    summary_embeddings = np.load("data/eval/summary_embeddings.npy", allow_pickle=True).item()
    term_embeddings = np.load("data/eval/term_embeddings.npy", allow_pickle=True).item()
else:
    # Both maps are keyed as [column][term] -> embedding vector.
    term_embeddings = {}
    summary_embeddings = {}
    client = OpenAI(base_url=api_url, api_key=api_key)
    for col in summaries:
        term_embeddings[col] = {}
        summary_embeddings[col] = {}
        for term in summaries[col]:
            summary = summaries[col][term]
            # A single request embeds the raw term (item 0) and its summary (item 1).
            out = client.embeddings.create(
                model=embedding_model,
                input=[term, summary],
            )
            term_emb = np.asarray(out.data[0].embedding)
            summ_emb = np.asarray(out.data[1].embedding)
            term_embeddings[col][term] = term_emb
            summary_embeddings[col][term] = summ_emb
    np.save("data/eval/summary_embeddings.npy", summary_embeddings)
    np.save("data/eval/term_embeddings.npy", term_embeddings)

## Project and Contrast Embeddings

In [None]:
# Read each genome table and keep only protein-coding genes. The index is reset
# so that row labels coincide with row positions.
human_genome = (
    pd.read_csv("../genomes/genome_homo_sapiens.tsv", sep="\t")
    .loc[lambda table: table["Gene_Type"] == "PROTEIN_CODING"]
    .reset_index(drop=True)
)
mouse_genome = (
    pd.read_csv("../genomes/genome_mus_musculus.tsv", sep="\t")
    .loc[lambda table: table["Gene_Type"] == "PROTEIN_CODING"]
    .reset_index(drop=True)
)

# Organism name (as used in the benchmark's "organism" column) -> genome table.
genome_map = dict(human=human_genome, mouse=mouse_genome)

In [None]:
# Run identifiers used by load_models() to build the checkpoint paths.
unsumm_model = "u30kvgtv"  # trained on unsummarized inputs
summ_model = "27y1lds0"    # trained on summarized inputs

def load_models(use_summarized):
    """Rebuild the experiment and gene projection MLPs from a saved checkpoint.

    Args:
        use_summarized: When truthy, read the checkpoint of the run named by
            ``summ_model``; otherwise the one named by ``unsumm_model``.

    Returns:
        Tuple ``(exp_proj, gene_proj)`` of ``MLP`` instances with the checkpoint
        weights loaded and ``eval()`` already called.
    """
    if use_summarized:
        ckpt_path = f"runs/contrastive-summarized/{summ_model}/VirtualCRISPR/{summ_model}/checkpoints/last.ckpt"
    else:
        ckpt_path = f"runs/contrastive-unsummarized/{unsumm_model}/VirtualCRISPR/{unsumm_model}/checkpoints/last.ckpt"
    checkpoint = torch.load(ckpt_path, map_location="cpu")

    # Split the checkpoint weights by projection head, dropping the module prefix
    # so the keys line up with a standalone MLP's state dict.
    exp_weights = {}
    gene_weights = {}
    for key, value in checkpoint["state_dict"].items():
        if "exp_proj" in key:
            exp_weights[key.replace("contraster.exp_proj.", "")] = value
        if "gene_proj" in key:
            gene_weights[key.replace("contraster.gene_proj.", "")] = value

    # Experiment head: its input is three concatenated 3072-d vectors.
    exp_proj = MLP(input_dim=3072*3, reduction_factor=3, n_hidden=2, output_dim=512)
    exp_proj.load_state_dict(exp_weights)
    exp_proj.eval()

    # Gene head: its input is a single 3072-d vector.
    gene_proj = MLP(input_dim=3072, reduction_factor=2, n_hidden=2, output_dim=512)
    gene_proj.load_state_dict(gene_weights)
    gene_proj.eval()

    return exp_proj, gene_proj

In [None]:
# Evaluate both model variants: for every experiment, rank all protein-coding genes
# by cosine similarity between the projected experiment embedding and the projected
# gene embeddings, then report the share of the top-n genes that are labelled hits.
for use_summarized in [True, False]:
    print(f"================= {'Summarized' if use_summarized else 'Unsummarized'} =================")
    exp_proj, gene_proj = load_models(use_summarized)
    gene_embs = summ_gene_emb_map if use_summarized else gene_emb_map
    # organism -> projected gene embeddings. These depend only on the organism and
    # the loaded model, so they are computed once instead of once per experiment.
    gene_out_cache = dict()
    accs = defaultdict(list)
    for (_, method, cell, organism, phenotype), _hits in benchmark.groupby(["screen_file", "perturbation", "cell", "organism", "phenotype"]):
        method = method.title()
        genome = genome_map[organism]

        if use_summarized:
            me = torch.as_tensor(summ_method_emb[method], dtype=torch.float32)
            pe = torch.as_tensor(summary_embeddings["phenotype"][phenotype], dtype=torch.float32)
            ce = torch.as_tensor(summary_embeddings["cell"][cell], dtype=torch.float32)
        else:
            me = torch.as_tensor(method_emb[method], dtype=torch.float32)
            pe = torch.as_tensor(term_embeddings["phenotype"][phenotype], dtype=torch.float32)
            ce = torch.as_tensor(term_embeddings["cell"][cell], dtype=torch.float32)
        # Experiment input is [method, cell, phenotype] concatenated, as a batch of one.
        exp_emb = torch.concat([me, ce, pe], dim=0).unsqueeze(dim=0)

        if organism not in gene_out_cache:
            # Gather embeddings in genome row order so position i maps to genome row i.
            ges = np.asarray([gene_embs[organism][gene_id] for gene_id in genome["IDENTIFIER_ID"]])
            ges = torch.as_tensor(ges, dtype=torch.float32)
            with torch.inference_mode():
                gene_out_cache[organism] = gene_proj(ges)
        gene_out = gene_out_cache[organism]

        with torch.inference_mode():
            exp_out = exp_proj(exp_emb)
        sims = cosine_similarity(exp_out, gene_out).squeeze()
        idxs = sims.argsort()[::-1] # most similar first
        # argsort yields row positions, so index positionally rather than by label.
        ranked_genes = genome.iloc[idxs]
        # Symbols of this experiment's hits, lower-cased for case-insensitive matching.
        _hits = _hits.loc[_hits["hit"] == 1, "gene"]
        _hits = _hits.str.lower()
        for n in [5, 10, 50]:
            pred = ranked_genes.iloc[:n]["OFFICIAL_SYMBOL"].str.lower()
            # Fraction of the top-n predictions that are hits.
            acc = pred.isin(_hits).sum() / n
            accs[n].append(acc)
    # Mean over experiments, reported as a percentage.
    for n, _accs in accs.items():
        print(f"Acc@{n}: {np.asarray(_accs).mean()*100:0.1f}%")