### _in silico_ perturbation by cell type prompting

In [None]:
import os
import torch
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt

# for flex attention: let TorchDynamo suppress compilation errors instead of raising
# (presumably so that execution falls back to eager mode -- confirm against the PyTorch docs)
import torch._dynamo
torch._dynamo.config.suppress_errors = True

# Device handed to the inference context below.
# NOTE(review): the device index is hard-coded; this assumes at least two CUDA devices are visible.
DEVICE = torch.device('cuda:1')
# Default figure size for scanpy plots.
sc.set_figure_params(figsize=(4, 4))

from cellarium.ml.utilities.inference.cellarium_gpt_inference import \
    CellariumGPTInferenceContext, \
    GeneNetworkAnalysisBase

In [None]:
# Locations of the model checkpoint, the reference AnnData and the gene info table.
# NOTE(review): these are absolute, machine-specific paths; adjust them for your environment.
ROOT_PATH = "/mnt/cellariumgpt-xfer/mb-ml-dev-vm"
CHECKPOINT_PATH = "/mnt/cellariumgpt-xfer/100M_long_run/run_001/lightning_logs/version_3/checkpoints/epoch=5-step=504000.ckpt"
REF_ADATA_PATH = os.path.join(ROOT_PATH, "data", "extract_0.h5ad")
GENE_INFO_PATH = os.path.join(ROOT_PATH, "gene_info", "gene_info.tsv")

# Inference context built from the checkpoint; the cells below use it to generate
# gene tokens from metadata prompts, compute marginal statistics and predict metadata.
ctx = CellariumGPTInferenceContext(
    cellarium_gpt_ckpt_path=CHECKPOINT_PATH,
    ref_adata_path=REF_ADATA_PATH,
    gene_info_tsv_path=GENE_INFO_PATH,
    device=DEVICE,
    attention_backend="mem_efficient"
)

### Internal consistency check

Has the model learned to generate cell-type-specific gene expression patterns?

Experiment: prompt the model with a given cell type, query the gene expression, feed the predicted expression back as a prompt, and ask the model to predict the cell type.

In [None]:
# Gene IDs to query, truncated to the first 5000 entries of the validation list.
query_gene_ids_path = os.path.join(ROOT_PATH, "data", "cellariumgpt_validation", "query_gene_ids.csv")
query_gene_ids = pd.read_csv(query_gene_ids_path, header=None).values.flatten().tolist()[:5000]  # NOTE

# Metadata prompt for the generation step.
# NOTE(review): "neuron" + "blood" is an unusual combination -- confirm it is intentional.
assay = "10x 3' v3"
suspension_type = "cell"
prompt_metadata_dict = {
    "cell_type": "neuron",
    "tissue": "blood"
}
total_mrna_umis = 5_000

with torch.inference_mode():

    # Step 1: prompt with metadata only (no perturbed genes) and query gene expression.
    tokens_dict, context_indices, _, _ = ctx.generate_gene_tokens_by_metadata(
        assay=assay,
        suspension_type=suspension_type,
        prompt_metadata_dict=prompt_metadata_dict,
        total_mrna_umis=total_mrna_umis,
        query_gene_ids=query_gene_ids,
        perturb_gene_ids=None
    )

    gene_marginal_means_nq, gene_marginal_std_nq = ctx.get_marginal_mean_std_from_tokens(tokens_dict, context_indices)

    # Step 2: make a one-cell AnnData, inject the expression predicted by CellariumGPT,
    # and predict the cell type from it. Subset the reference AnnData *before* copying,
    # so only the (1 x n_query_genes) slice is materialized rather than the whole dataset.
    prop_adata_feedback = ctx._adata[0, query_gene_ids].copy()

    prop_adata_feedback.X[0, :] = gene_marginal_means_nq[0, :].cpu().numpy()
    prop_adata_feedback.obs["total_mrna_umis"] = [total_mrna_umis]

    metadata_prediction_dict = ctx.predict_metadata(prop_adata_feedback)

    # Rank the predicted classes for the single cell by descending probability.
    metadata_key = 'cell_type'
    probs_k = metadata_prediction_dict[metadata_key][0]
    sort_order = np.argsort(probs_k)[::-1]

    # Report the top 10 classes (or fewer if fewer classes are available).
    for k in range(min(10, len(sort_order))):
        prob = probs_k[sort_order[k]]
        name = ctx.metadata_ontology_infos[metadata_key]["labels"][sort_order[k]]
        print(f"{name}: {prob:.2f}")

- We have two different generalization issues:
  - Large prompts (not an issue for this task)
  - Sensitivity to total mRNA UMIs (we need to find the mean/median total UMI count for the cell type we're prompting with)

### Perform _in silico_ deletion

In [None]:
from tqdm.notebook import tqdm


# Full list of query gene IDs (not truncated here); the first 10,000 of them
# are used as the genes to perturb.
query_gene_ids_path = os.path.join(ROOT_PATH, "data", "cellariumgpt_validation", "query_gene_ids.csv")
query_gene_ids = pd.read_csv(query_gene_ids_path, header=None).values.flatten().tolist()
perturb_gene_ids = query_gene_ids[:10_000]

# Metadata prompt describing the cell in which genes are perturbed.
assay = "10x 3' v3"
suspension_type = "cell"
prompt_metadata_dict = {
    "cell_type": "CD8-positive, alpha-beta T cell",
    "tissue": "blood"
}
total_mrna_umis = 5_000


# Number of query genes per chunk when the query gene list is split up below.
chunk_size = 20

def yield_chunks(lst, n):
    """Yield consecutive slices of ``lst`` holding at most ``n`` items each.

    The last slice is shorter when ``len(lst)`` is not a multiple of ``n``;
    an empty ``lst`` yields nothing.
    """
    yield from (lst[start:start + n] for start in range(0, len(lst), n))

# Materialize the chunks as a list so that tqdm can show the total number of steps.
query_gene_ids_chunks = list(yield_chunks(query_gene_ids, chunk_size))

# Per-chunk results, moved to the CPU as NumPy arrays; a later cell concatenates
# them along axis 1.
gene_marginal_means_nq_chunks = []
gene_marginal_std_nq_chunks = []

with torch.inference_mode():

    for query_gene_ids_chunk in tqdm(query_gene_ids_chunks):

        # Every chunk is queried against the same full set of perturbed genes.
        tokens_dict, context_indices, _, _ = ctx.generate_gene_tokens_by_metadata(
            assay=assay,
            suspension_type=suspension_type,
            prompt_metadata_dict=prompt_metadata_dict,
            total_mrna_umis=total_mrna_umis,
            query_gene_ids=query_gene_ids_chunk,
            perturb_gene_ids=perturb_gene_ids
        )

        # NOTE(review): downstream cells treat row 0 as the unperturbed control and
        # rows 1: as the perturbation responses -- confirm against the inference API.
        gene_marginal_means_nq, gene_marginal_std_nq = \
            ctx.get_marginal_mean_std_from_tokens(tokens_dict, context_indices)

        gene_marginal_means_nq_chunks.append(gene_marginal_means_nq.cpu().numpy())
        gene_marginal_std_nq_chunks.append(gene_marginal_std_nq.cpu().numpy())

In [None]:
def clean_string(s):
    """Turn a metadata label into a file-name-friendly token.

    Spaces and hyphens become underscores, commas are dropped, and apostrophes
    are replaced by the letter "p" (e.g. "10x 3' v3" -> "10x_3p_v3").
    """
    substitutions = str.maketrans({" ": "_", "-": "_", ",": None, "'": "p"})
    return s.translate(substitutions)

# File stem encoding the prompt configuration (cell type, tissue, assay,
# suspension type and total UMIs).
filename = (
    f"in_silico_del_response_matrix__{clean_string(prompt_metadata_dict['cell_type'])}__"
    f"{clean_string(prompt_metadata_dict['tissue'])}__"
    f"{clean_string(assay)}__"
    f"{clean_string(suspension_type)}__"
    f"{total_mrna_umis}"
)

# Stitch the per-chunk results back together along axis 1 (the chunked query-gene axis).
gene_marginal_means_nq = np.concatenate(gene_marginal_means_nq_chunks, axis=1)
gene_marginal_std_nq = np.concatenate(gene_marginal_std_nq_chunks, axis=1)

# Everything needed to reload and interpret the response matrix later.
output = {
    "perturb_gene_ids": perturb_gene_ids,
    "query_gene_ids": query_gene_ids,
    "gene_marginal_means_nq": gene_marginal_means_nq,
    "gene_marginal_std_nq": gene_marginal_std_nq,
    "assay": assay,
    "suspension_type": suspension_type,
    "prompt_metadata_dict": prompt_metadata_dict,
    "total_mrna_umis": total_mrna_umis
}

import pickle

# Create the output directory first: opening the file for writing would otherwise
# raise FileNotFoundError and discard the results of the long-running loop above.
# NOTE(review): this path is relative to the current working directory, whereas the
# exploration cell loads from ROOT_PATH/cellariumgpt_playground/output -- confirm they match.
response_pkl_path = f"./output/in_silico_del_via_meta_prompting/{filename}.pkl"
os.makedirs(os.path.dirname(response_pkl_path), exist_ok=True)

with open(response_pkl_path, "wb") as f:
    pickle.dump(output, f)

### Explore _in silico_ deletion

In [None]:
import pickle

output_pkl_path = os.path.join(
    ROOT_PATH, "cellariumgpt_playground", "output", "in_silico_del_via_meta_prompting",
    "in_silico_del_response_matrix__CD8_positive_alpha_beta_T_cell__blood__10x_3p_v3__cell__5000.pkl"
)
# NOTE: unpickling can execute arbitrary code -- only load files you produced yourself.
# The context manager makes sure the file handle is closed after loading.
with open(output_pkl_path, "rb") as f:
    output_dict = pickle.load(f)

# Older pickles were written without the prompt configuration. Fill in the missing
# entries only, so that values stored in newer pickles are not overridden.
output_dict.setdefault("assay", "10x 3' v3")
output_dict.setdefault("suspension_type", "cell")
output_dict.setdefault("prompt_metadata_dict", {
    "cell_type": "CD8-positive, alpha-beta T cell",
    "tissue": "blood"
})
output_dict.setdefault("total_mrna_umis", 5_000)


# output_pkl_path = os.path.join(
#     ROOT_PATH, "cellariumgpt_playground", "output", "in_silico_del_via_meta_prompting",
#     "in_silico_del_response_matrix__cardiac_muscle_cell__heart__10x_3p_v3__nucleus__10000.pkl"
# )
# output_dict = pickle.load(open(output_pkl_path, "rb"))

# # I forgot to include those in the pickle initially ...
# output_dict["assay"] = "10x 3' v3"
# output_dict["suspension_type"] = "nucleus"
# output_dict["prompt_metadata_dict"] = {
#     "cell_type": "cardiac muscle cell",
#     "tissue": "heart"
# }
# output_dict["total_mrna_umis"] = 10_000

In [None]:
# Generate an AnnData containing just the metadata: a single dummy observation whose
# `obs` row carries the prompt configuration (only `.obs` is passed on below).
adata_prop = sc.AnnData(
    X=np.zeros((1, 1)),
    obs=pd.DataFrame({
        "cell_type": [output_dict["prompt_metadata_dict"]["cell_type"]],
        "tissue": [output_dict["prompt_metadata_dict"]["tissue"]],
        "assay": [output_dict["assay"]],
        "suspension_type": [output_dict["suspension_type"]],
        "total_mrna_umis": [output_dict["total_mrna_umis"]],
        "disease": "N/A",
        "development_stage": "N/A",
        "sex": "N/A",
    })
)

gene_info_tsv_path = os.path.join(ROOT_PATH, "gene_info", "gene_info.tsv")

# Row 0 of the stored matrices is used as the control; rows 1: are the perturbation
# responses, transposed here so that rows index query genes and columns index perturbations.
raw_response_qp = output_dict["gene_marginal_means_nq"][1:, :].T
# Control statistics of the perturbed genes.
# Assumes the perturbed genes are the first len(perturb_gene_ids) query genes, which holds
# when perturb_gene_ids was built as a prefix of query_gene_ids -- verify for other pickles.
prompt_marginal_mean_p = output_dict["gene_marginal_means_nq"][0, :len(output_dict["perturb_gene_ids"])]
prompt_marginal_std_p = output_dict["gene_marginal_std_nq"][0, :len(output_dict["perturb_gene_ids"])]
# Control statistics of all query genes.
query_marginal_mean_q = output_dict["gene_marginal_means_nq"][0, :]
query_marginal_std_q = output_dict["gene_marginal_std_nq"][0, :]

# log fold change response: rescale every perturbation column so that it sums to the
# control library size, then take the log ratio against the control means.
# NOTE(review): a zero column sum or zero mean yields inf/NaN here -- confirm the
# marginal means are strictly positive.
control_cell_library_size = query_marginal_mean_q.sum()
raw_response_qp = (raw_response_qp / raw_response_qp.sum(0)[None, :]) * control_cell_library_size
normalized_response_qp = np.log(raw_response_qp) - np.log(query_marginal_mean_q)[:, None]

# Analysis object used by all remaining cells (processing, adjacency, communities, embedding).
network_ctx = GeneNetworkAnalysisBase(
    adata_obs=adata_prop.obs,
    gene_info_tsv_path=gene_info_tsv_path,
    query_var_names=output_dict["query_gene_ids"],
    prompt_var_names=output_dict["perturb_gene_ids"],
    response_qp=normalized_response_qp,
    prompt_marginal_mean_p=prompt_marginal_mean_p,
    prompt_marginal_std_p=prompt_marginal_std_p,
    query_marginal_mean_q=query_marginal_mean_q,
    query_marginal_std_q=query_marginal_std_q,
    verbose=True
)

In [None]:
# Process the response matrix: no response normalization, L2 feature normalization.
# All amplitude/TPM thresholds are 0 (presumably this disables gene filtering -- confirm).
# TODO: do we need prompt_z_score or query_z_score?
network_ctx.process(
    response_normalization_strategy="none",
    feature_normalization_strategy="l2",
    query_response_amp_min_pct=0,
    min_prompt_gene_tpm=0,
    min_query_gene_tpm=0)

In [None]:
# Build the adjacency matrix with the "shifted_correlation" strategy,
# 10 neighbors and no self loops.
# NOTE(review): the meaning of `beta` is not visible from this notebook -- see the method's docs.
network_ctx.compute_adjacency_matrix(
    adjacency_strategy="shifted_correlation",
    n_neighbors=10,
    beta=6.,
    self_loop=False)

In [None]:
# Leiden community detection at resolution 5.0.
# (`communites` is the spelling used by the method being called; it is not changed here.)
network_ctx.compute_leiden_communites(
    resolution=5.0)

In [None]:
# Number of distinct values in the Leiden membership assignment.
len(np.unique(network_ctx.leiden_membership))

In [None]:
# Compute the spectral dimension; it is plotted in the next cell.
network_ctx.compute_spectral_dimension(n_lambda_for_estimation=10)

In [None]:
# Draw the spectral dimension plot on an explicitly created axis.
fig, ax = plt.subplots()

network_ctx.plot_spectral_dimension(ax=ax)

#### Embedding

In [None]:
import pymde

# Compute the MDE embedding that is plotted at the end of the notebook.
# NOTE(review): this uses the device string "cuda", whereas the inference context was
# created on DEVICE ('cuda:1') -- confirm that a different GPU is intended here.
network_ctx.make_mde_embedding(
    n_neighbors=10,
    # repulsive_penalty=pymde.penalties.Log,
    init="quadratic",
    device="cuda")

In [None]:
# Hand-picked gene symbols for the "SNAP-n" highlight set.
snap_n_gene_symbols = [
    'GAP43',
    'NRXN3',
    'HOMER1',
    'IL1RAPL2',
    'EPHA3',
    'RIMS1',
    'SV2B',
    'TRIM9',
    'SVOP',
    'RPH3A',
    'SYT12',
    'SYT1',
    'R3HDM2',
    'PDE4B',
    'DCC',
    'SLC4A10',
    'DNM3',
    'GRM1',
    'EGR4',
    'JUNB',
    'TFDP2'
]

# Keep only symbols present among the query genes, then map them to gene IDs.
snap_n_gene_symbols = [x for x in snap_n_gene_symbols if x in network_ctx.query_gene_symbols]
snap_n_gene_ids = [network_ctx.gene_symbol_to_gene_id_map[x] for x in snap_n_gene_symbols]

# Hand-picked gene symbols for the "Muscle" highlight set.
muscle_gene_symbols = [
    'TTN',
    'MYL3',
    'MYL4',
    'MYL7',
    'TNNC1',
    'TNNI1',
]

# Same filtering and ID mapping as for the SNAP-n set.
muscle_gene_symbols = [x for x in muscle_gene_symbols if x in network_ctx.query_gene_symbols]
muscle_gene_ids = [network_ctx.gene_symbol_to_gene_id_map[x] for x in muscle_gene_symbols]

def get_gene_familities(network_ctx: GeneNetworkAnalysisBase, prefix_list: list[str]) -> tuple[list[str], list[str]]:
    """Collect the query genes whose symbol starts with one of the given prefixes.

    Matches are gathered prefix by prefix, in the order of ``prefix_list``; within
    a prefix they follow the order of ``network_ctx.query_gene_symbols``. A symbol
    matching several prefixes appears once per matching prefix.

    Returns:
        A ``(gene_ids, gene_symbols)`` pair of equally long lists, where the symbols
        are looked up again from the IDs via ``network_ctx.gene_id_to_gene_symbol_map``.
    """
    matched_symbols = []
    for prefix in prefix_list:
        matched_symbols.extend(
            symbol for symbol in network_ctx.query_gene_symbols if symbol.startswith(prefix)
        )

    gene_ids = []
    gene_symbols = []
    for symbol in matched_symbols:
        gene_id = network_ctx.gene_symbol_to_gene_id_map[symbol]
        gene_ids.append(gene_id)
        gene_symbols.append(network_ctx.gene_id_to_gene_symbol_map[gene_id])
    return gene_ids, gene_symbols

# Gene families selected purely by symbol prefix.
mito_gene_ids, mito_gene_symbols = get_gene_familities(network_ctx, ["MT-"])
ribo_gene_ids, ribo_gene_symbols = get_gene_familities(network_ctx, ["RPS", "RPL"])
hla_gene_ids, hla_gene_symbols = get_gene_familities(network_ctx, ["HLA"])
ifi_gene_ids, ifi_gene_symbols = get_gene_familities(network_ctx, ["IFI"])

# Gene sets to highlight in the embedding plot: label -> (gene IDs, gene symbols, color).
# Commented-out entries are alternative sets that are currently not highlighted.
highlight_gene_sets = {
    "Mito": (mito_gene_ids, mito_gene_symbols, 'red'),
    "Ribo": (ribo_gene_ids, ribo_gene_symbols, 'blue'),
    # "SNAP-n": (snap_n_gene_ids, snap_n_gene_symbols, 'green'),
    # "HLA": (hla_gene_ids, hla_gene_symbols, 'green'),
    # "IFI": (ifi_gene_ids, ifi_gene_symbols, 'orange'),
    "Muscle": (muscle_gene_ids, muscle_gene_symbols, 'purple'),
}

# disable
# highlight_gene_sets = None

In [None]:
# Display the symbols that matched the "MT-" prefix.
mito_gene_symbols

In [None]:
# Marginal mean values, as stored on network_ctx, for the "MT-" genes.
network_ctx.query_marginal_mean_q[[network_ctx.query_gene_id_to_idx_map[gene_id] for gene_id in mito_gene_ids]]

In [None]:
# Plot the MDE embedding, highlighting the gene sets selected above.
network_ctx.plot_mde_embedding(highlight_gene_sets=highlight_gene_sets)