In [None]:
import os
import torch
import pickle
import warnings
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt

from cellarium.ml.utilities.inference.cellarium_gpt_inference import CellariumGPTInferenceContext
from cellarium.ml.utilities.linreg import batch_linear_regression

In [None]:
# Global parameters: device selection, input/output locations and the target cell.
cuda_device_index: int = 0  # interpolated into the "cuda:<index>" device string
cell_index: int = 0  # positional row of the validation AnnData that is analysed

checkpoint_path: str = "/home/mehrtash/data/100M_long_run/run_001/lightning_logs/version_3/checkpoints/epoch=5-step=504000.ckpt"
ref_adata_path: str = "/home/mehrtash/data/data/extract_0.h5ad"
gene_info_tsv_path: str = "/home/mehrtash/data/gene_info/gene_info.tsv"
validation_adata_path: str = "/home/mehrtash/data/data/cellariumgpt_artifacts/cell_types_for_validation_filtered.h5ad"
output_path: str = "/home/mehrtash/data/data/linear_response/100M_long_run_last_checkpoint"

# Parameters forwarded to the gene expression range determination step.
query_chunk_size: int = 1_000
total_prob_mass: float = 0.5
max_counts: int = 1000
symmetric_range_pad: int = 1
max_query_genes = None  # int or None; None keeps every gene of the validation AnnData
total_mrna_umis = None  # int or None; None falls back to the selected cell's own value

# Parameters of the linear response analysis.
n_points: int = 5
query_chunk_size_linear_response: int = 64

In [None]:
# Make sure the output directory is present; an already existing directory is not an error.
os.makedirs(name=output_path, exist_ok=True)

In [None]:
# Load the validation AnnData and select the single cell (row of `obs`) to analyse.
val_adata = sc.read_h5ad(validation_adata_path)
print(f"Total number of cells in validation anndata: {val_adata.shape[0]}")

# Check both bounds: a negative index would otherwise pass the check and make
# `iloc` silently pick a cell counted from the end (or fail later with an
# unrelated IndexError when it is too negative).
assert 0 <= cell_index < val_adata.shape[0], "cell_index is out of range"
print(f"Selected cell index: {cell_index}")

# Metadata row of the selected cell; later cells read its prompt fields.
val_adata_row = val_adata.obs.iloc[cell_index]
print(val_adata_row)

In [None]:
print("Loading CellariumGPT checkpoint ...")

# The inference context is placed on the CUDA device selected by `cuda_device_index`.
device = torch.device(f'cuda:{cuda_device_index}')

# Build the inference context from the checkpoint plus the reference AnnData
# and gene-info table configured in the parameters.
ctx = CellariumGPTInferenceContext(
    device=device,
    cellarium_gpt_ckpt_path=checkpoint_path,
    ref_adata_path=ref_adata_path,
    gene_info_tsv_path=gene_info_tsv_path,
    attention_backend="mem_efficient",
    verbose=False,
)

In [None]:
# Candidate genes are the `feature_id` column of the validation AnnData's `var` table.
all_query_gene_ids = val_adata.var['feature_id'].values
print(f"Total number of query genes from validation AnnData: {len(all_query_gene_ids)}")

# `max_query_genes` optionally truncates the list to its leading entries.
if max_query_genes is None:
    query_gene_ids = all_query_gene_ids
    print(f"Using all {len(all_query_gene_ids)} query genes for linear response analysis.")
else:
    query_gene_ids = all_query_gene_ids[:max_query_genes]
    print(f"Limiting to {max_query_genes} first query genes for linear response analysis.")

In [None]:
# Prompt conditioning is read from the `obs` row of the selected validation cell.
assay = val_adata_row.assay
suspension_type = val_adata_row.suspension_type
prompt_metadata_dict = {
    "cell_type": val_adata_row.cell_type,
    "tissue": val_adata_row.tissue,
    "disease": val_adata_row.disease,
    "sex": val_adata_row.sex,
}

# Resolve the total mRNA UMI count: use the cell's own value unless the
# parameter cell supplied an explicit override.
# NOTE: this rebinds `total_mrna_umis`, so re-running this cell without
# re-running the parameter cell takes the "Overriding" branch.
if total_mrna_umis is None:
    total_mrna_umis = float(val_adata_row.total_mrna_umis)
    print(f"Using total_mrna_umis from validation AnnData: {total_mrna_umis:.3f}")
else:
    # Both values use `.3f`; the previous `3f` spec meant "minimum width 3"
    # and therefore printed the default six decimals.
    print(f"Overriding total_mrna_umis from validation AnnData: {val_adata_row.total_mrna_umis:.3f} with {total_mrna_umis:.3f}")

In [None]:
print("Determining gene expression range ...")

# Ask the inference context for a per-gene expression range under the selected
# cell's metadata. Later cells read the 'x_lo_q' / 'x_hi_q' entries of the
# result as the dose bounds and store the remaining entries in the output.
gex_range_dict = ctx.predict_gene_expression_range_for_metadata(
    # prompt
    assay=assay,
    suspension_type=suspension_type,
    prompt_metadata_dict=prompt_metadata_dict,
    total_mrna_umis=total_mrna_umis,
    # queried genes and chunking
    query_gene_ids=query_gene_ids,
    query_chunk_size=query_chunk_size,
    # range determination settings
    total_prob_mass=total_prob_mass,
    max_counts=max_counts,
    symmetric_range_pad=symmetric_range_pad,
)

In [None]:
# Lower/upper dose bounds per perturbed gene, converted to host NumPy arrays.
dose_lower_bounds = gex_range_dict['x_lo_q'].cpu().numpy()
dose_upper_bounds = gex_range_dict['x_hi_q'].cpu().numpy()

# The same gene list is used both as the perturbed set and as the queried set.
dose_response_dict = ctx.generate_gene_dose_response_for_metadata(
    # prompt
    assay=assay,
    suspension_type=suspension_type,
    prompt_metadata_dict=prompt_metadata_dict,
    total_mrna_umis=total_mrna_umis,
    # genes
    query_gene_ids=query_gene_ids,
    perturb_gene_ids=query_gene_ids,
    # dose grid
    x_lo_p=dose_lower_bounds,
    x_hi_p=dose_upper_bounds,
    n_points=n_points,
    # evaluation settings
    query_chunk_size=query_chunk_size_linear_response,
    max_counts=max_counts,
)

In [None]:
print("Performing linear regression ...")

# Query and perturbed gene sets are the same list, hence the same size.
n_query_vars = len(query_gene_ids)
n_perturb_vars = len(query_gene_ids)

doses_pi = dose_response_dict['doses_pi']
responses_mean_pqi = dose_response_dict['responses_mean_pqi']

# Regressor: the doses with a new leading axis, repeated once per query gene.
# Target: the mean responses with their first two axes swapped, so the query
# axis leads. (Presumably both end up as (query, perturb, point) -- TODO confirm.)
regression_x = np.repeat(doses_pi[None, :, :], n_query_vars, axis=-3)
regression_y = responses_mean_pqi.transpose(1, 0, 2)

slope_qp, intercept_qp, r_squared_qp = batch_linear_regression(
    x_bn=regression_x,
    y_bn=regression_y,
)

In [None]:
print("Generating output ...")

# Entries of the expression-range result that are stored in the output; each
# one is moved to the CPU and converted to a NumPy array before pickling.
gex_range_keys = (
    "x_lo_q",
    "x_hi_q",
    "range_q",
    "gene_logits_qk",
    "gene_logits_mode_q",
    "gene_marginal_mean_q",
    "gene_marginal_std_q",
)
gex_range_arrays = {key: gex_range_dict[key].cpu().numpy() for key in gex_range_keys}

# Everything needed to interpret the run: parameters, prompt, ranges, raw
# dose-response data and the fitted regression. "cell_index" is listed once
# (it was previously repeated in the literal with the same value).
output_dict = {

    # global parameters
    "checkpoint_path": checkpoint_path,
    "validation_adata_path": validation_adata_path,
    "cell_index": cell_index,

    # gene expression range determination parameters
    "total_prob_mass": total_prob_mass,
    "max_counts": max_counts,
    "symmetric_range_pad": symmetric_range_pad,
    "max_query_genes": max_query_genes,

    # linear response analysis parameters
    "n_points": n_points,

    # prompt metadata
    "assay": assay,
    "suspension_type": suspension_type,
    "total_mrna_umis": total_mrna_umis,
    "prompt_metadata_dict": prompt_metadata_dict,

    # expression range
    **gex_range_arrays,

    # dose response (the perturbed genes are the query genes)
    "query_gene_ids": query_gene_ids,
    "perturb_gene_ids": query_gene_ids,
    "doses_pi": dose_response_dict['doses_pi'],
    "responses_mean_pqi": dose_response_dict['responses_mean_pqi'],

    # linear regression
    "slope_qp": slope_qp,
    "intercept_qp": intercept_qp,
    "r_squared_qp": r_squared_qp,
}

print("Saving output ...")
# One pickle per analysed cell, named after its index.
output_file_name = os.path.join(output_path, f"linear_response_cell_index_{cell_index}.pkl")

with open(output_file_name, 'wb') as f:
    pickle.dump(output_dict, f)

print("Script finished successfully.")