In [None]:
import os
import torch
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt

from cellarium.ml import CellariumModule, CellariumPipeline

import torch._dynamo
# Ask TorchDynamo to suppress its own errors rather than raise them
# (NOTE(review): per the PyTorch docs this falls back to eager execution -- confirm for the installed version).
torch._dynamo.config.suppress_errors = True

# Target device for torch; hard-coded to CUDA device index 6, so adjust for the machine at hand.
DEVICE = torch.device('cuda:6')
# Default figure size for scanpy plots.
sc.set_figure_params(figsize=(4, 4))

In [None]:
# Root directory under which this notebook reads its inputs ("data/...", "gene_info/...")
# and writes its processed outputs.
ROOT_PATH = "/mnt/cellariumgpt-xfer/mb-ml-dev-vm"
# Checkpoint directory of a training run (presumably PyTorch Lightning, judging by the path -- confirm).
CHECKPOINTS_PATH = "/mnt/cellariumgpt-xfer/100M_long_run/run_001/lightning_logs/version_0/checkpoints"

In [None]:
# Load an AnnData extract.
# Its `.obs` categories and `.var_names` are used by the following cells as the reference
# vocabulary (ontology term lists, model gene IDs).
# NOTE: later cells rebind `adata` inside their loops; re-run this cell before re-running
# any cell that expects `adata` to be this reference extract.
adata_path = os.path.join(ROOT_PATH, "data", "extract_0.h5ad")
adata = sc.read_h5ad(adata_path)

In [None]:
# Categorical metadata vocabularies taken from the reference extract's obs table.
ref_obs = adata.obs

# For each obs column, record its categories under both "names" and "labels".
# The labels are simply a copy of the names (no human-readable labels were looked up);
# "suspension_type" is included for uniformity even though it may not have an ontology.
gene_ontology_infos = {}
for obs_column in ("assay_ontology_term_id", "suspension_type"):
    categories = ref_obs[obs_column].cat.categories
    gene_ontology_infos[obs_column] = {
        "names": list(categories),
        "labels": list(categories),
    }

In [None]:
# Gene IDs, gene symbols, and lookup maps between them.
# `model_var_names` are the var_names of the currently loaded `adata`.
model_var_names = np.asarray(adata.var_names)
model_var_names_set = set(model_var_names)
var_name_to_index_map = {var_name: i for i, var_name in enumerate(model_var_names)}

gene_info_tsv_path = os.path.join(ROOT_PATH, "gene_info", "gene_info.tsv")
gene_info_df = pd.read_csv(gene_info_tsv_path, sep="\t")

# gene symbol -> gene ID.
# Rows with a missing symbol are skipped. The test must be `pd.notna`: NaN compares unequal
# to everything (including itself), so `gene_symbol != float('nan')` is always True and
# would let a NaN key into the map (and NaN "symbols" into the inverse map below).
gene_symbol_to_gene_id_map = {
    gene_symbol: gene_id
    for gene_symbol, gene_id in zip(gene_info_df['Gene Symbol'], gene_info_df['ENSEMBL Gene ID'])
    if pd.notna(gene_symbol)
}

# gene ID -> gene symbol; if several symbols share a gene ID, the last one in the table wins.
gene_id_to_gene_symbol_map = {gene_id: gene_symbol for gene_symbol, gene_id in gene_symbol_to_gene_id_map.items()}

# Model genes without a known symbol fall back to their own gene ID.
for gene_id in model_var_names:
    gene_id_to_gene_symbol_map.setdefault(gene_id, gene_id)

In [None]:
# Manifest of the validation extracts; the first CSV column becomes the index.
manifest_path = os.path.join(ROOT_PATH, "data", "cellariumgpt_validation", "manifest.csv")
validation_df = pd.read_csv(manifest_path, index_col=0)

# Shift the index by one (so that a 0-based manifest index starts at 1).
validation_df.index = validation_df.index + 1

# Show the full table, without pandas' row/column truncation.
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    display(validation_df)


#### Process a given validation AnnData

In [None]:
from tqdm.notebook import tqdm

# Validation extracts to process; each index selects the file `extract_{val_idx}.h5ad`
# (presumably matching the manifest rows shown above -- TODO confirm).
# Alternative selection kept for reference:
# val_idx_list = [58, 66, 69, 53, 67, 92, 107, 108]
val_idx_list = [92, 93, 100, 104, 40, 52, 50, 79]
N_TOP_HVG = 5000

validation_dir = os.path.join(ROOT_PATH, "data", "cellariumgpt_validation")

for val_idx in tqdm(val_idx_list):

    val_adata_path = os.path.join(validation_dir, f"extract_{val_idx}.h5ad")
    val_adata = sc.read_h5ad(val_adata_path)

    # keep the untouched matrix around; X is normalized in place below
    val_adata.layers['counts'] = val_adata.X.copy()

    # normalize each cell to 10k total, then log1p
    sc.pp.normalize_total(val_adata, target_sum=1e4)
    sc.pp.log1p(val_adata)

    # per-gene mean of the log-normalized values (mean log TPKC), computed over all genes
    # before subsetting; np.asarray(...).flatten() handles both dense and sparse X
    val_adata.var['mean_log_tpkc'] = np.asarray(val_adata.X.mean(axis=0)).flatten()

    # make an embedding on the top highly variable genes
    sc.pp.highly_variable_genes(val_adata, n_top_genes=N_TOP_HVG)
    # .copy() materializes the subset: the steps below modify the object in place, which on
    # a view only works through an implicit copy (with a warning) and keeps the full-size
    # parent object alive.
    val_adata = val_adata[:, val_adata.var['highly_variable']].copy()

    sc.pp.scale(val_adata, max_value=10)
    sc.pp.pca(val_adata, n_comps=50)
    sc.pp.neighbors(val_adata, n_pcs=50, n_neighbors=30)
    sc.tl.umap(val_adata)

    # the processed file contains only the HVG columns
    val_adata.write_h5ad(os.path.join(validation_dir, f"extract_{val_idx}__processed.h5ad"))

#### Pool all the HVGs together

- For duplicates, choose the higher mean log TPKC
- Sort by mean log TPKC and choose:
  - 5000 genes for prompting
  - 15000 genes for querying

In [None]:
from tqdm.notebook import tqdm

val_idx_list = [92, 93, 100, 104, 40, 52, 50, 79]
N_PROMPT_GENES = 5000

validation_dir = os.path.join(ROOT_PATH, "data", "cellariumgpt_validation")

# Column-wise accumulator: one entry per (dataset, highly variable gene) pair.
hvg_dict = {'dataset_index': [], 'gene_id': [], 'log_tpkc': []}

for val_idx in tqdm(val_idx_list):

    # Dedicated names on purpose: rebinding the notebook-global `adata` / `adata_path` here
    # would silently replace the reference extract that earlier cells depend on.
    processed_adata_path = os.path.join(validation_dir, f"extract_{val_idx}__processed.h5ad")
    processed_adata = sc.read_h5ad(processed_adata_path)

    # the processed files only contain the HVG columns, so every var row is an HVG
    n_hvg = processed_adata.shape[1]
    hvg_dict['dataset_index'].extend([val_idx] * n_hvg)
    hvg_dict['gene_id'].extend(processed_adata.var.index.tolist())
    hvg_dict['log_tpkc'].extend(processed_adata.var['mean_log_tpkc'].tolist())

hvg_df = pd.DataFrame(hvg_dict)

# One row per gene: after sorting by decreasing log_tpkc, the first row of each gene_id group
# is the one with the highest log_tpkc; the other rows are dropped.
# The result is ordered by gene_id (groupby sorts its keys).
hvg_df = hvg_df.sort_values('log_tpkc', ascending=False).groupby('gene_id').first().reset_index()

In [None]:
# Attach a gene symbol to every pooled gene (genes missing from the map get NaN),
# then order the table from highest to lowest log_tpkc.
hvg_df = (
    hvg_df
    .assign(gene_symbol=hvg_df['gene_id'].map(gene_id_to_gene_symbol_map))
    .sort_values(by='log_tpkc', ascending=False)
)

In [None]:
# Query genes: every pooled gene, in the current row order of hvg_df.
query_gene_symbols = hvg_df['gene_symbol'].tolist()
query_gene_ids = hvg_df['gene_id'].tolist()

# Prompt genes: the first N_PROMPT_GENES entries of the same ordering.
prompt_gene_symbols = query_gene_symbols[:N_PROMPT_GENES]
prompt_gene_ids = query_gene_ids[:N_PROMPT_GENES]

In [None]:
# Histogram of the pooled genes' log_tpkc values; the count axis is logarithmic.
fig, ax = plt.subplots()
ax.hist(hvg_df['log_tpkc'], bins=100, log=True)
ax.set(
    xlabel="log_tpkc",
    ylabel="number of genes",
    title="Pooled HVGs",
);

In [None]:
# Persist the pooled HVG table and the prompt/query gene lists.
validation_dir = os.path.join(ROOT_PATH, "data", "cellariumgpt_validation")

hvg_df.to_csv(os.path.join(validation_dir, "hvg_df.csv"), index=False)

# Each gene list is written as a single column: one value per line, no header, no index.
gene_list_files = [
    ("prompt_gene_ids.csv", prompt_gene_ids),
    ("query_gene_ids.csv", query_gene_ids),
    ("prompt_gene_symbols.csv", prompt_gene_symbols),
    ("query_gene_symbols.csv", query_gene_symbols),
]
for file_name, gene_list in gene_list_files:
    pd.DataFrame(gene_list).to_csv(os.path.join(validation_dir, file_name), index=False, header=False)

### Making metacells

In [None]:
from tqdm.notebook import tqdm

val_idx_list = [92, 93, 100, 104, 40, 52, 50, 79]

# Cell types with fewer cells than this are skipped.
MIN_CELLS_PER_TYPE = 10
# Upper bound on the number of cells written per (dataset, cell type).
MAX_CELLS_FOR_METACELL = 100

validation_dir = os.path.join(ROOT_PATH, "data", "cellariumgpt_validation")
metacells_dir = os.path.join(validation_dir, "metacells")
# Make sure the output directory exists before anything is written (no-op if it already does).
os.makedirs(metacells_dir, exist_ok=True)

for val_idx in tqdm(val_idx_list):

    # processed file: HVG subset with PCA, used to pick cells;
    # original file: what is actually written out for the picked cells
    adata_path = os.path.join(validation_dir, f"extract_{val_idx}__processed.h5ad")
    orig_adata_path = os.path.join(validation_dir, f"extract_{val_idx}.h5ad")
    adata = sc.read_h5ad(adata_path)
    orig_adata = sc.read_h5ad(orig_adata_path)

    # keep only the cell types with at least MIN_CELLS_PER_TYPE cells
    cell_type_count_dict = adata.obs['cell_type'].value_counts().to_dict()
    filtered_cell_type_count_dict = dict()

    for cell_type, count in cell_type_count_dict.items():
        if count < MIN_CELLS_PER_TYPE:
            print(f"Skipping cell type {cell_type} in {val_idx} because cell type {cell_type} has only {count} cells")
        else:
            filtered_cell_type_count_dict[cell_type] = count

    # for each cell type, select the cell with median total_mrna_umis, and then select up to
    # MAX_CELLS_FOR_METACELL cells closest to it in PCA space.
    # NOTE: i_cell_type (used in the output file name) is the position of the cell type in the
    # value_counts() ordering after filtering, not a stable cell type identifier.
    for i_cell_type, cell_type in enumerate(filtered_cell_type_count_dict.keys()):
        cell_type_adata = adata[adata.obs['cell_type'] == cell_type]
        umi_sorted_obs_names = cell_type_adata.obs['total_mrna_umis'].sort_values().index
        # .copy(): an obs column is added below, which must not be written into a view
        cell_type_adata = cell_type_adata[umi_sorted_obs_names].copy()

        # cells are now ordered by total_mrna_umis, so the middle row is the median cell
        median_idx = cell_type_adata.shape[0] // 2
        median_cell_id = cell_type_adata.obs['original_cell_id'].iloc[median_idx]
        median_pc_k = cell_type_adata.obsm['X_pca'][median_idx, :]

        # rank all cells by the Euclidean distance between unit-normalized PC vectors;
        # for unit vectors this is a monotonic function of the cosine distance, so the
        # ordering is the same as ranking by cosine distance
        median_pc_unit_k = median_pc_k / np.linalg.norm(median_pc_k)
        pcs_nk = cell_type_adata.obsm['X_pca']
        other_cells_unit_nk = pcs_nk / np.linalg.norm(pcs_nk, axis=1)[:, None]
        cell_type_adata.obs['distance_to_median'] = np.linalg.norm(other_cells_unit_nk - median_pc_unit_k, axis=1)
        cell_type_adata = cell_type_adata[cell_type_adata.obs['distance_to_median'].sort_values().index]
        # sanity check: the median cell has distance 0 to itself and must come first
        assert cell_type_adata.obs['original_cell_id'].iloc[0] == median_cell_id

        # select up to MAX_CELLS_FOR_METACELL (slicing already clamps to the number of cells)
        cell_type_adata = cell_type_adata[:MAX_CELLS_FOR_METACELL]
        # pull the same cells from the original extract; rows keep the original extract's order
        orig_cell_type_adata = orig_adata[orig_adata.obs['original_cell_id'].isin(cell_type_adata.obs['original_cell_id'])]
        assert len(orig_cell_type_adata) == len(cell_type_adata)

        # write to disk
        orig_cell_type_adata.write_h5ad(
            os.path.join(metacells_dir, f"extract_{val_idx}__{i_cell_type}.h5ad"))
            os.path.join(ROOT_PATH, "data", "cellariumgpt_validation", "metacells", f"extract_{val_idx}__{i_cell_type}.h5ad"))