In [None]:
import getpass
import os
import sys

import numpy as np
import pandas as pd
import torch
from omegaconf import OmegaConf

from preprocessing.generate_embeddings import generate_embeddings
from preprocessing.generate_summaries import generate_summaries
sys.path.append("src")
from classifier_model import ClassifierCRISPR

### Define Screen Parameters

In [None]:
# Screen definition: the perturbation applied, the organism / cell context, the phenotype being
# scored, and the candidate genes. The text fields are embedded (optionally after summarization)
# further down.
use_summarized = False  # True selects the summarized-text classifier and summarized embeddings

perturbation = "inhibition"  # looked up title-cased in the precomputed method embeddings
organism = "human"  # key into genome_map and the precomputed gene-embedding maps
cell = "CD8+ T cells"
phenotype = "decreased cytokine secretion"
# Strip surrounding whitespace from spreadsheet cells so symbols match the reference genome.
gene_list = pd.read_excel("../custom/gene_list.xlsx")["gene"].str.strip()

output_csv = "runs/custom-v1.csv"

### Define Constants

In [None]:
# API endpoint and models used to summarize and embed the free-text screen description.
api_url = "https://api.openai.com/v1"
# Use the key from the environment when it is set and non-empty; otherwise prompt with getpass
# so the secret is not echoed into the notebook. Export it so anything reading the environment
# variable sees the same key.
api_key = os.environ.get("OPENAI_API_KEY") or getpass.getpass("OPENAI_API_KEY")
os.environ["OPENAI_API_KEY"] = api_key
summary_model = "gpt-4o-2024-11-20"
embedding_model = "text-embedding-3-large"

# Run IDs of the trained classifiers; checkpoint and config paths are derived from them.
summ_id = "q33613fx"
unsumm_id = "v2zz1ph7"

summ_model = f"runs/classifier-summarized/{summ_id}/VirtualCRISPR/{summ_id}/checkpoints/last.ckpt"
unsumm_model = f"runs/classifier-unsummarized/{unsumm_id}/VirtualCRISPR/{unsumm_id}/checkpoints/last.ckpt"
summ_config = f"runs/classifier-summarized/{summ_id}/config.yaml"
unsumm_config = f"runs/classifier-unsummarized/{unsumm_id}/config.yaml"

# Prompt files passed to generate_summaries when use_summarized is True.
cell_summary_prompt = "prompts/summary-cell.json"
phenotype_summary_prompt = "prompts/summary-phenotype.json"

# Reference genome table per organism, used to map gene symbols to identifiers.
genome_map = {
    "human": "../genomes/genome_homo_sapien.tsv",
    "mouse": "../genomes/genome_mus_musculus.tsv",
}

# Precomputed gene embeddings per organism, for the original and the summarized variants.
precomputed_original_gene_map = {
    "human": "data/embeddings/genes_human.npy",
    "mouse": "data/embeddings/genes_mouse.npy",
}
precomputed_summarized_gene_map = {
    "human": "data/embeddings/summarized_genes_human.npy",
    "mouse": "data/embeddings/summarized_genes_mouse.npy",
}

# Precomputed perturbation-method embeddings for the two variants.
precomputed_original_method_path = "data/embeddings/methods.npy"
precomputed_summarized_method_path = "data/embeddings/summarized_methods.npy"

### Align Genes to Reference Genome

In [None]:
# Keep only protein-coding genes and index the genome by lower-cased symbol, so that the
# user's gene list can be matched case-insensitively.
genome = pd.read_csv(genome_map[organism], sep="\t")
genome = genome[genome["Gene_Type"] == "PROTEIN_CODING"].reset_index(drop=True)
genome_symbols = genome["OFFICIAL_SYMBOL"].str.lower()
assert not genome_symbols.duplicated().any(), "reference genome has case-insensitive duplicate symbols"
genome.index = genome_symbols

gene_keys = gene_list.str.lower()
gene_mask = gene_keys.isin(genome.index)
# Drop every copy of a gene listed more than once. Duplicates are detected on the same
# lower-cased key used for the lookup, otherwise "CD8A" and "cd8a" would both survive
# and yield duplicate output rows.
dupe_mask = gene_keys.duplicated(keep=False)
mask = gene_mask & ~dupe_mask

aligned_genes = genome.loc[gene_keys[mask], "IDENTIFIER_ID"]
print(f"Filter from {len(mask)} to {len(aligned_genes)}")
assert len(aligned_genes) > 0, "no genes in gene_list matched the reference genome"

### (Pre-)Compute Embeddings

In [None]:
# Optionally replace the raw cell/phenotype descriptions with LLM summaries, then embed them.
# Results go into new names so `cell` and `phenotype` stay untouched: re-running this cell
# never summarizes an already-summarized string.
if use_summarized:
    cell_text = generate_summaries(
        terms=[cell], prompt_file=cell_summary_prompt, api_url=api_url, api_key=api_key, model=summary_model
    )[0]
    phenotype_text = generate_summaries(
        terms=[phenotype], prompt_file=phenotype_summary_prompt, api_url=api_url, api_key=api_key, model=summary_model
    )[0]
else:
    cell_text = cell
    phenotype_text = phenotype

cell_emb = generate_embeddings(terms=[cell_text], api_url=api_url, api_key=api_key, model=embedding_model)[0]
phenotype_emb = generate_embeddings(terms=[phenotype_text], api_url=api_url, api_key=api_key, model=embedding_model)[0]

In [None]:
# Choose the embedding files that match the classifier variant selected above.
if use_summarized:
    method_path = precomputed_summarized_method_path
    gene_path = precomputed_summarized_gene_map[organism]
else:
    method_path = precomputed_original_method_path
    gene_path = precomputed_original_gene_map[organism]

# Each .npy file holds one pickled object (a mapping, given how it is indexed below), hence
# allow_pickle=True and .item(). Unpickling can execute code, so only load trusted files.
precomputed_methods = np.load(method_path, allow_pickle=True).item()
precomputed_genes = np.load(gene_path, allow_pickle=True).item()

# Method embeddings are keyed by title-cased perturbation name, gene embeddings by identifier.
method_emb = precomputed_methods[perturbation.title()]
gene_embs = list(map(precomputed_genes.__getitem__, aligned_genes))

### Run Classification Over Embeddings

In [None]:
# Load checkpoint and config from the SAME run: the constructor arguments must describe the
# architecture the weights were trained with.
ckpt_path = summ_model if use_summarized else unsumm_model
config_path = summ_config if use_summarized else unsumm_config

# NOTE(review): PyTorch >= 2.6 defaults torch.load to weights_only=True, which may reject
# checkpoints that contain non-tensor objects. Confirm against the installed torch version.
checkpoint = torch.load(ckpt_path, map_location="cpu")
# Checkpoint keys carry a leading "classifier." (presumably the attribute name on the training
# wrapper). Strip only that prefix; str.replace would also mangle keys containing it elsewhere.
prefix = "classifier."
sd = {(k[len(prefix):] if k.startswith(prefix) else k): v for k, v in checkpoint["state_dict"].items()}
cfg = OmegaConf.load(config_path)["model"]["init_args"]

cls = ClassifierCRISPR(**cfg)
cls.load_state_dict(sd)
cls.eval();  # trailing semicolon suppresses the module repr in the cell output

In [None]:
# torch.tensor infers dtype from its input (float64 NumPy data -> float64 tensor), so cast every
# input explicitly to the classifier's parameter dtype to keep the forward pass consistent.
model_dtype = next(cls.parameters()).dtype
n_genes = len(gene_embs)

# One row per gene. The method/cell/phenotype embeddings are shared by the whole screen, so each
# is tiled n_genes times. Assumes these embeddings are 1-D vectors — TODO confirm.
gene_embs = torch.tensor(np.array(gene_embs), dtype=model_dtype)
cell_embs = torch.tensor(cell_emb, dtype=model_dtype).repeat(n_genes, 1)
method_embs = torch.tensor(method_emb, dtype=model_dtype).repeat(n_genes, 1)
phenotype_embs = torch.tensor(phenotype_emb, dtype=model_dtype).repeat(n_genes, 1)

In [None]:
# Score all genes in a single batched forward pass. inference_mode disables autograd tracking.
# The classifier returns a pair; only its second element (used as probabilities) is kept.
model_inputs = dict(
    method_emb=method_embs,
    cell_emb=cell_embs,
    phenotype_emb=phenotype_embs,
    gene_emb=gene_embs,
)
with torch.inference_mode():
    _, probs = cls(**model_inputs)

In [None]:
# One row per aligned gene: its lower-cased symbol (the genome index), its identifier, and the
# classifier's predicted probability.
output = aligned_genes.reset_index()
# Assumes probs is (n_genes, 2) with column 1 the positive class — TODO confirm against
# ClassifierCRISPR. Convert to NumPy explicitly rather than handing pandas a torch tensor.
output["PREDICTED_PROB"] = probs[:, 1].cpu().numpy()
# to_csv does not create missing parent directories.
output_dir = os.path.dirname(output_csv)
if output_dir:
    os.makedirs(output_dir, exist_ok=True)
output.to_csv(output_csv, index=False)
output