### Does CellariumGPT recapitulate the empirical mean expression for a given cell type?

Can we turn this into a quantitative benchmark?

In [None]:
import os
import torch
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt

sc.set_figure_params(figsize=(4, 4))

# All model inference in this notebook runs on the first CUDA device.
DEVICE = torch.device("cuda", 0)

from cellarium.ml.utilities.inference.cellarium_gpt_inference import (
    CellariumGPTInferenceContext,
    GeneNetworkAnalysisBase,
)

# NOTE(review): a "reset matplotlib params" step was noted here but never implemented;
# sc.set_figure_params above is the only figure configuration in this cell.


In [None]:
ROOT_PATH = "/home/mehrtash/data"

# Checkpoint under evaluation. Alternative checkpoints in the same directory
# (previously listed here as commented-out paths):
#   epoch=1-step=29161.ckpt, epoch=1-step=28244.ckpt,
#   epoch=2-step=43129.ckpt, epoch=3-step=53770.ckpt
CHECKPOINT_PATH = os.path.join(
    ROOT_PATH, "compute_optimal_checkpoints", "epoch=6-step=63560.ckpt")

# Reference AnnData and gene-info table handed to the inference context.
REF_ADATA_PATH = os.path.join(ROOT_PATH, "data", "extract_0.h5ad")
GENE_INFO_PATH = os.path.join(ROOT_PATH, "gene_info", "gene_info.tsv")

ctx = CellariumGPTInferenceContext(
    cellarium_gpt_ckpt_path=CHECKPOINT_PATH,
    ref_adata_path=REF_ADATA_PATH,
    gene_info_tsv_path=GENE_INFO_PATH,
    device=DEVICE,
    attention_backend="mem_efficient",
    verbose=False,
)

In [None]:
# Load the AnnData holding the cell types used for validation.
val_adata = sc.read_h5ad(
    os.path.join(
        ROOT_PATH,
        "data",
        "cellariumgpt_artifacts",
        "cell_types_for_validation_filtered.h5ad",
    )
)

In [None]:
# Display the complete obs table of the validation AnnData (no row/column truncation).
with pd.option_context("display.max_rows", None, "display.max_columns", None):
    display(val_adata.obs)

In [None]:
# Select one validation cell type and build the model prompt from its metadata.
cell_index = 12
adata = val_adata[cell_index, :].copy()

metadata_prompt_masks_dict, metadata_dict = ctx.process_user_metadata(
    # assay="Smart-seq v4",
    assay=adata.obs['assay'].values[0],
    suspension_type=adata.obs['suspension_type'].values[0],
    prompt_metadata_dict={
        'cell_type': adata.obs['cell_type'].values[0],
        'disease': adata.obs['disease'].values[0],
        'tissue': adata.obs['tissue'].values[0],
        'sex': adata.obs['sex'].values[0],
        # 'development_stage': adata.obs['development_stage'].values[0],
    },
    total_mrna_umis=adata.obs['total_mrna_umis'].values[0],
    # total_mrna_umis=100_000,
)

# Replace obs with the processed metadata (one row for the single selected cell).
# Keep the cell's original obs name as the index: a default RangeIndex is not
# string-valued, so AnnData would convert it implicitly (with a warning) and the
# original cell name would be lost.
obs_df = pd.DataFrame(
    {key: [value] for key, value in metadata_dict.items()},
    index=adata.obs_names,
)
adata.obs = obs_df

In [None]:
# Rescale each cell's counts so that its row sum equals its total_mrna_umis.
# The per-cell sums and targets are kept as (n_cells, 1) columns so the scaling
# broadcasts row-wise. A plain `.sum(axis=1)` has shape (n_cells,), which only
# broadcasts correctly when adata holds a single cell.
# NOTE(review): assumes adata.X is dense, as later cells index it directly.
row_sums_n1 = np.asarray(adata.X.sum(axis=1)).reshape(-1, 1)
target_umis_n1 = adata.obs['total_mrna_umis'].to_numpy(dtype=np.float64)[:, None]
adata.X = adata.X * target_umis_n1 / row_sums_n1

In [None]:
# Index genes by their feature_id column; model queries use these var names.
adata.var = adata.var.set_index('feature_id')

In [None]:
# Display obs of the selected cell (replaced by the processed metadata earlier).
adata.obs

In [None]:
# tpm_q = adata.X[0]
# tpm_q = 1_000_000 * tpm_q / tpm_q.sum()
# keep_q = tpm_q > 10

# adata.var['keep'] = keep_q

In [None]:
# adata = adata[: , adata.var['keep']]

In [None]:
# protein_coding_genes_set = set(
#     ctx.gene_info_df[ctx.gene_info_df["Gene Biotype"] == "protein_coding"]["ENSEMBL Gene ID"].values)

In [None]:
# adata = adata[:, adata.var.index.isin(protein_coding_genes_set)]

In [None]:
# Display the AnnData summary for the selected cell (var now indexed by feature_id).
adata

In [None]:
from more_itertools import chunked
from tqdm.auto import tqdm

# Query the model for every gene in `adata`, in chunks of `query_chunk_size` genes,
# using the cell's most highly expressed genes as context.
query_var_names = adata.var_names
query_chunk_size = 128

# Context: the n_high genes with the largest summed counts in this cell.
# NOTE(review): 4096 - 5 presumably reserves 5 prompt tokens in a 4096-token
# context — confirm against the model configuration.
# NOTE(review): the queried genes overlap the context genes; confirm whether
# generate_tokens_from_adata excludes query genes from the context, otherwise the
# model sees the observed values of the genes it is asked to predict.
n_high = 4096 - 5
rand_adata = adata.copy()
top_gene_indices = rand_adata.X.sum(axis=0).argsort()[::-1][:n_high]
rand_adata = rand_adata[:, top_gene_indices]
# rand_adata = rand_adata[:, np.random.choice(rand_adata.var_names, 4000, replace=False)]
# rand_adata.X = rand_adata.X.astype(int).astype(np.float32)

# Metadata fields revealed in the prompt; these match the fields passed in
# prompt_metadata_dict to process_user_metadata (development_stage is hidden).
context_prompt_masks = {
    "cell_type": True,
    "tissue": True,
    "development_stage": False,
    "disease": True,
    "sex": True,
}

chunks = list(chunked(query_var_names, query_chunk_size))
gene_logits_chunks = []
for query_var_names_chunk in tqdm(chunks):
    tokens_dict, context_indices = ctx.generate_tokens_from_adata(
        adata=rand_adata,
        obs_index=None,
        query_var_names=query_var_names_chunk,
        # a fresh copy per call, as the original passed a new dict literal each time
        metadata_prompt_masks_dict=dict(context_prompt_masks),
    )

    with torch.inference_mode():
        # NOTE(review): 2000 presumably bounds the count values scored; it matches
        # max_counts=2000 used when computing means — confirm.
        gene_logits_nqk = ctx.get_gene_value_logits_from_tokens(tokens_dict, context_indices, 2000)
        gene_logits_chunks.append(gene_logits_nqk.cpu().numpy())

# Chunks were split along the query-gene axis (axis 1), so join them back along it.
gene_logits = np.concatenate(gene_logits_chunks, axis=1)

In [None]:
# Move the concatenated logits to DEVICE and take the slice for the first (only) cell.
# Axis names follow the n/q/k suffixes (presumably cells, query genes, count values).
gene_logits_nqk = torch.tensor(gene_logits, device=DEVICE)
gene_logits_qk = gene_logits_nqk.select(0, 0)

In [None]:
# max_counts = 2000
# total_prob_mass = 0.5
# symmetric_range_pad = 10

# # first, find the mode of the counts distribution for each gene
# gene_logits_mode_q = torch.argmax(gene_logits_qk, dim=1)

# # symmetric lower and upper counts about the mode for each gene
# x_lo_qm = torch.clamp(
#     gene_logits_mode_q[:, None] - torch.arange(0, max_counts, device=DEVICE)[None, :], min=0)
# x_hi_qm = torch.clamp(
#     gene_logits_mode_q[:, None] + torch.arange(0, max_counts, device=DEVICE)[None, :], max=max_counts - 1)

# # compute the CDF of counts for each gene
# pdf_qk = gene_logits_qk.exp()
# cdf_qk = pdf_qk.cumsum(dim=1)
# q_indices = torch.arange(cdf_qk.size(0), device=DEVICE)
# symm_prob_mass_qm = (
#     cdf_qk[q_indices[:, None], x_hi_qm]  # add total prob mass at the right point (inclusive)
#     - cdf_qk[q_indices[:, None], x_lo_qm]  # subtract total prob mass at the left point (inclusive)
#     + pdf_qk[q_indices[:, None], x_lo_qm]  # add back the prob mass of the left point
# )
# mask_qm = symm_prob_mass_qm > total_prob_mass
# range_q = torch.clamp(mask_qm.float().argmax(dim=-1) + symmetric_range_pad, max=max_counts - 1)
# x_lo_q = x_lo_qm[q_indices, range_q]
# x_hi_q = x_hi_qm[q_indices, range_q]

In [None]:
# fixed_gene_logits_qk = gene_logits_qk.clone()
# kill_mask_qk = torch.zeros_like(gene_logits_qk, dtype=torch.bool)
# counts_qk = torch.arange(max_counts, device=DEVICE)[None, :].expand(gene_logits_qk.size(0), -1)
# kill_mask_qk[counts_qk > x_hi_q[:, None]] = True
# kill_mask_qk[counts_qk < x_lo_q[:, None]] = True

In [None]:
# id = np.where((adata.var['feature_name'] == "LINC00486"))[0].item()

In [None]:
# gene_logits_qk[kill_mask_qk] = -1000000
# gene_logits_qk = gene_logits_qk - torch.logsumexp(gene_logits_qk, dim=-1, keepdim=True)

In [None]:
# Per-gene mean and standard deviation of the predicted count distribution;
# a leading axis of size 1 restores the cell dimension expected by the helper.
gene_marginal_mean_nq, gene_marginal_std_nq = ctx.calculate_gene_mean_std_from_logits(
    gene_logits_qk.unsqueeze(0),
    max_counts=2000,
    use_logsumexp=False,
)

In [None]:
# Attach the predicted per-gene mean (first cell) to a copy of the gene table.
expr_q = gene_marginal_mean_nq[0].cpu().numpy()
var = adata.var.assign(expr=expr_q)

In [None]:
# Order genes from highest to lowest predicted mean expression.
var = var.sort_values(by='expr', ascending=False)

In [None]:
# Display the 50 genes with the highest predicted mean ('expr', sorted descending).
var.head(50)

In [None]:
# Compare this cell's counts with the model's per-gene marginal means.
# The model means are rescaled to sum to total_mrna_umis, the same total the
# observed counts were rescaled to earlier, so both share one library-size scale.
model_mean_q = gene_marginal_mean_nq.cpu().numpy().flatten()
model_mean_q = model_mean_q * adata.obs['total_mrna_umis'].values[0] / model_mean_q.sum()

# np.asarray(...).ravel() yields a 1-D vector whether X is an ndarray or np.matrix.
log_obs_q = np.log1p(np.asarray(adata.X[0]).ravel())
log_model_q = np.log1p(model_mean_q)
# Pearson correlation on the log1p scale: a single-number summary of agreement.
log_pearson_r = np.corrcoef(log_obs_q, log_model_q)[0, 1]

plt.scatter(log_obs_q, log_model_q, s=1)
# axline spans the full visible range; a fixed [0, 6] segment stops short of genes
# whose log1p expression exceeds 6.
plt.axline((0, 0), slope=1, color='red')
plt.xlabel('log1p(observed counts)')
plt.ylabel('log1p(model mean, renormalized)')
plt.title(f'Pearson r (log1p) = {log_pearson_r:.3f}')

In [None]:
# Plot the predicted count distribution of a single gene (PTPRC).
# Rows of gene_logits_qk follow adata.var_names order, so the positional index in
# adata.var selects the matching row. `.item()` raises ValueError unless exactly one
# gene has this feature_name.
gene_index = np.where((adata.var['feature_name'] == "PTPRC"))[0].item()
# gene_index = 4045
gene_logit = gene_logits_qk[gene_index, :].cpu().numpy()

# The plotted values are exp(logits) (probabilities when the logits are normalized),
# not the logits themselves, so the y-axis is labelled accordingly.
plt.plot(np.exp(gene_logit))
plt.xlim((0, 20))
plt.xlabel('counts')
plt.ylabel('exp(logit)')
# plt.ylim((-10, 0))

In [None]:
# Sum of the predicted per-gene means over all queried genes (compare with
# total_mrna_umis in adata.obs).
np.sum(expr_q)

In [None]:
# Re-display obs of the selected cell (e.g. to read total_mrna_umis).
adata.obs

In [None]:
# Total exp(logit) mass of the selected gene over counts 0..499 (the probability
# mass of that range if the logits are normalized log-probabilities).
np.sum(np.exp(gene_logit)[:500])

In [None]:
# Mean of the selected gene's predicted count distribution after truncating it at
# each cutoff and renormalizing the retained mass, to see how the mean depends on
# where the distribution is cut.
cutoffs = np.arange(10, 1000, 50)
gene_prob_k = np.exp(gene_logit)
means = []
for cutoff in cutoffs:
    probs = gene_prob_k[:cutoff]
    probs = probs / probs.sum()
    mean = (probs * np.arange(cutoff)).sum()
    means.append(mean)

In [None]:
# Truncated (renormalized) mean versus truncation cutoff for the selected gene;
# the axes are labelled so the curve can be read without the producing cell.
plt.plot(cutoffs, means)
plt.xlabel('count cutoff')
plt.ylabel('truncated, renormalized mean count')