### Metadata evaluation

In [None]:
import os
import torch
import numpy as np
import scanpy as sc
import pandas as pd
import matplotlib.pyplot as plt

# for flex attention
import torch._dynamo
# NOTE(review): presumably this makes TorchDynamo fall back to eager execution
# instead of raising when compilation fails -- confirm against the PyTorch docs.
torch._dynamo.config.suppress_errors = True

# Device handed to the inference context when it is constructed.
DEVICE = torch.device('cuda:0')
# Default figure size for all scanpy plots in this notebook.
sc.set_figure_params(figsize=(4, 4))

from cellarium.ml.utilities.inference.cellarium_gpt_inference import CellariumGPTInferenceContext

In [None]:
# Filesystem locations: working root, trained checkpoint, and the two
# auxiliary files the inference context is constructed from.
ROOT_PATH = "/mnt/cellariumgpt-xfer/mb-ml-dev-vm"
GENE_INFO_PATH = os.path.join(ROOT_PATH, "gene_info", "gene_info.tsv")
REF_ADATA_PATH = os.path.join(ROOT_PATH, "data", "extract_0.h5ad")
CHECKPOINT_PATH = "/mnt/cellariumgpt-xfer/100M_long_run/run_001/lightning_logs/version_1/checkpoints/epoch=2-step=252858.ckpt"

# Single inference context reused by every prediction cell below.
ctx = CellariumGPTInferenceContext(
    device=DEVICE,
    attention_backend="mem_efficient",
    cellarium_gpt_ckpt_path=CHECKPOINT_PATH,
    ref_adata_path=REF_ADATA_PATH,
    gene_info_tsv_path=GENE_INFO_PATH,
)

### LuCA

#### Process the AnnData

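- Subset to lung samples from the 10x 3' v2/v3 assays, then to LUAD (lung adenocarcinoma)
- Subset to highly variable genes (currently disabled: the HVG cell below is commented out, and a random subset of cells is saved instead)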

In [None]:
# Read the LuCA h5ad file from the data directory into memory.
adata_path = os.path.join(ROOT_PATH, "data", "luca", "5d57179e-17d8-416f-aa55-9c3dbc3c29fc.h5ad")
adata = sc.read_h5ad(adata_path)

In [None]:
# revert to raw counts
adata.X = adata.layers['count'].copy()
# per-cell total UMI count, computed over all genes
adata.obs['total_mrna_umis'] = np.asarray(adata.X.sum(axis=1)).flatten()

# free up some memory -- do this *before* subsetting: deleting a layer from a
# view makes AnnData materialize the view first (copying the unwanted layer too)
del adata.layers['counts_length_scaled']

# remove unwanted assays
included_assays = ["10x 3' v2", "10x 3' v3"]
included_tissues = ["lung"]
keep_cells_mask = adata.obs['assay'].isin(included_assays) & adata.obs['tissue'].isin(included_tissues)

# explicit copy: downstream cells get a real AnnData instead of a view that
# keeps the full atlas alive
adata = adata[keep_cells_mask].copy()

In [None]:
# Display the AnnData summary as the cell output.
adata

In [None]:
# UMAP of the lung adenocarcinoma cells only, colored by the `cell_type_tumor` annotation.
sc.pl.umap(adata[adata.obs['disease'] == 'lung adenocarcinoma'], color='cell_type_tumor', gene_symbols='feature_name', vmin=0, vmax=2)

In [None]:
# Keep only the cells whose `disease` annotation is lung adenocarcinoma (LUAD).
adata = adata[adata.obs['disease'] == 'lung adenocarcinoma']

In [None]:
# sc.pp.normalize_total(adata, target_sum=1e4)
# sc.pp.log1p(adata)

# N_TOP_GENES = 10_000

# sc.pp.highly_variable_genes(adata, n_top_genes=N_TOP_GENES, flavor='seurat_v3', n_bins=20)
# sc.pl.highly_variable_genes(adata)

In [None]:
# subset to a smaller number of cells for testing
N_RAND_CELLS = 1_000

# fixed seed so the same cells are drawn on every run
rng = np.random.default_rng(42)
rand_cell_indices = rng.choice(adata.n_obs, size=N_RAND_CELLS, replace=False)

# materialize the selection and make sure X holds the raw counts
adata_rand = adata[rand_cell_indices].copy()
adata_rand.X = adata_rand.layers['count'].copy()

In [None]:
# Persist the random-cell subset next to the source file, with a `__processed` suffix.
adata_rand.write_h5ad(
    os.path.join(ROOT_PATH, "data", "luca", "5d57179e-17d8-416f-aa55-9c3dbc3c29fc__processed.h5ad"))

#### Load the processed LuCA AnnData and make predictions

In [None]:
# reload the processed (random-cell) LuCA subset written above
adata_path = os.path.join(ROOT_PATH, "data", "luca", "5d57179e-17d8-416f-aa55-9c3dbc3c29fc__processed.h5ad")
adata = sc.read_h5ad(adata_path)

# remove genes that we don't have in the vocabulary
adata_var_names = adata.var_names
adata_var_names_in_model_mask = [var_name in ctx.model_var_names_set for var_name in adata_var_names]
adata = adata[:, adata_var_names_in_model_mask]

# subset genes (fixed seed for reproducibility); never request more genes than
# survived the vocabulary filter -- sampling without replacement would raise
N_RAND_GENES = 10_000
rng = np.random.default_rng(42)
n_genes_to_draw = min(N_RAND_GENES, adata.n_vars)
adata = adata[:, rng.choice(adata.n_vars, n_genes_to_draw, replace=False)]

In [None]:
# release cached GPU memory before running the model
torch.cuda.empty_cache()

# metadata predictions for the LuCA subset, 32 cells per call
metadata_prediction_dict = ctx.predict_metadata_chunked(adata, chunk_size=32)

# save
luca_predictions_path = os.path.join(
    ROOT_PATH, "cellariumgpt_playground", "output",
    "5d57179e-17d8-416f-aa55-9c3dbc3c29fc__metadata_predictions.pt")
torch.save(metadata_prediction_dict, luca_predictions_path)

#### Load processed AnnData file and results

In [None]:
# reload the processed LuCA subset (all genes, with its stored UMAP)
adata_path = os.path.join(ROOT_PATH, "data", "luca", "5d57179e-17d8-416f-aa55-9c3dbc3c29fc__processed.h5ad")
adata = sc.read_h5ad(adata_path)

# The file is written by this notebook and is trusted. `weights_only` is passed
# explicitly because its default became True in PyTorch 2.6, and the stricter
# unpickler can reject non-tensor payloads (the predictions are presumably
# NumPy arrays -- they are consumed with np.argmax below).
metadata_prediction_dict = torch.load(
    os.path.join(ROOT_PATH, "cellariumgpt_playground", "output", "5d57179e-17d8-416f-aa55-9c3dbc3c29fc__metadata_predictions.pt"),
    weights_only=False)

In [None]:
# top-1 predicted cell type per cell: argmax over the class axis, then map
# each class index to its ontology label
cell_type_labels = ctx.metadata_ontology_infos['cell_type']['labels']
best_cell_type_indices = np.argmax(metadata_prediction_dict['cell_type'], axis=-1)
best_cell_type_labels = [cell_type_labels[class_idx] for class_idx in best_cell_type_indices]

adata.obs['cellariumgpt_cell_type'] = best_cell_type_labels

In [None]:
# Plot the UMAP embedding without any coloring.
sc.pl.umap(adata)

In [None]:
# UMAP colored by the top-1 cell type predicted by the model.
sc.pl.umap(adata, color='cellariumgpt_cell_type')

In [None]:
# UMAP colored by the dataset's own `cell_type` annotation, for comparison with the prediction.
sc.pl.umap(adata, color='cell_type')

In [None]:
# UMAP colored by the dataset's `cell_type_tumor` annotation.
sc.pl.umap(adata, color='cell_type_tumor')

In [None]:
# Per-cell disease scores: for each keyword, sum the predicted probability of
# every disease label whose name contains that keyword.
disease_labels = list(ctx.metadata_ontology_infos['disease']['labels'])
disease_probs = metadata_prediction_dict['disease']

disease_keywords = ['adenocarcinoma', 'COVID-19', 'cardiomyopathy']
for disease_keyword in disease_keywords:
    target_disease_indices = [
        idx for idx, disease_label in enumerate(disease_labels) if disease_keyword in disease_label]
    target_disease_labels = [disease_labels[idx] for idx in target_disease_indices]
    target_disease_score_n = disease_probs[:, target_disease_indices].sum(-1)
    adata.obs[f'cellariumgpt_{disease_keyword}_score'] = target_disease_score_n

# "normal" is an exact label: look its index up instead of assuming it is
# class 0, so a reordered label list fails loudly rather than silently
# scoring the wrong class.
target_disease_labels = 'normal'
target_disease_indices = [disease_labels.index(target_disease_labels)]
target_disease_score_n = disease_probs[:, target_disease_indices].sum(-1)
adata.obs['cellariumgpt_normal_score'] = target_disease_score_n

In [None]:
# UMAP colored by the predicted "normal" score (diverging colormap, semi-transparent points).
sc.pl.umap(adata, color='cellariumgpt_normal_score', s=40, cmap='RdBu_r', alpha=0.5)

In [None]:
# UMAP colored by the summed adenocarcinoma score (diverging colormap, semi-transparent points).
sc.pl.umap(adata, color='cellariumgpt_adenocarcinoma_score', s=40, cmap='RdBu_r', alpha=0.5)

### Process validation datasets

#### Make a manifest

In [None]:
from tqdm.notebook import tqdm

# Sample-level metadata of every validation extract, taken from its first cell.
tissue_list = []
disease_list = []
development_stage_list = []

for val_idx in tqdm(range(1, 111)):

    val_adata_path = os.path.join(ROOT_PATH, "data", "cellariumgpt_validation", f"extract_{val_idx}.h5ad")
    # Only three `obs` values are needed, so open the file in backed mode and
    # leave the expression matrix on disk instead of loading it 110 times.
    val_adata = sc.read_h5ad(val_adata_path, backed='r')
    try:
        tissue = val_adata.obs['tissue'].iloc[0]
        disease = val_adata.obs['disease'].iloc[0]
        development_stage = val_adata.obs['development_stage'].iloc[0]
    finally:
        # release the HDF5 handle even if a column is missing
        val_adata.file.close()

    tissue_list.append(tissue)
    disease_list.append(disease)
    development_stage_list.append(development_stage)

In [None]:
# One manifest row per validation extract, in the order the extracts were read.
manifest_columns = {
    'tissue': tissue_list,
    'disease': disease_list,
    'development_stage': development_stage_list,
}
validation_df = pd.DataFrame(manifest_columns)

# Written with the default integer index as the first CSV column.
manifest_csv_path = os.path.join(ROOT_PATH, "data", "cellariumgpt_validation", "manifest.csv")
validation_df.to_csv(manifest_csv_path)

In [None]:
# Reload the manifest, using the first CSV column as the index.
validation_df = pd.read_csv(
    os.path.join(ROOT_PATH, "data", "cellariumgpt_validation", "manifest.csv"),
    index_col=0,
)

# reset index to go from 1 to ... (so it matches the extract file numbering)
validation_df.index = validation_df.index + 1

# Show every row and column instead of pandas' truncated preview.
with pd.option_context('display.max_rows', None, 'display.max_columns', None):
    display(validation_df)


In [None]:
# Validation extracts whose disease is lung adenocarcinoma.
validation_df[validation_df["disease"] == "lung adenocarcinoma"]

In [None]:
# Validation extracts whose tissue is lung.
validation_df[validation_df["tissue"] == "lung"]

#### Process a given validation AnnData

In [None]:
from tqdm.notebook import tqdm

val_idx_list = [58, 66, 69, 53, 67, 92, 107, 108, 92, 93, 100, 104, 40, 52, 50, 79]
N_TOP_HVG = 5000


# Index 92 is listed twice; iterate over the unique indices (first-seen order)
# so no extract is preprocessed and rewritten twice. The list itself is left
# untouched because other cells display its length and index into it.
for val_idx in tqdm(list(dict.fromkeys(val_idx_list))):

    val_adata_path = os.path.join(ROOT_PATH, "data", "cellariumgpt_validation", f"extract_{val_idx}.h5ad")
    val_adata = sc.read_h5ad(val_adata_path)

    # keep the raw counts: X is normalized, log-transformed and scaled below
    val_adata.layers['counts'] = val_adata.X.copy()

    sc.pp.normalize_total(val_adata, target_sum=1e4)
    sc.pp.log1p(val_adata)
    sc.pp.highly_variable_genes(val_adata, n_top_genes=N_TOP_HVG)

    # explicit copy: the in-place steps below should operate on a real AnnData,
    # not on a view that gets copied implicitly (with a warning)
    val_adata = val_adata[:, val_adata.var['highly_variable']].copy()

    # standard embedding pipeline: scale -> PCA -> kNN graph -> UMAP
    sc.pp.scale(val_adata, max_value=10)
    sc.pp.pca(val_adata, n_comps=50)
    sc.pp.neighbors(val_adata, n_pcs=50, n_neighbors=30)
    sc.tl.umap(val_adata)

    val_adata.write_h5ad(os.path.join(ROOT_PATH, "data", "cellariumgpt_validation", f"extract_{val_idx}__processed.h5ad"))

#### Predict metadata

In [None]:
# Number of entries in the validation index list (duplicates included).
len(val_idx_list)

In [None]:
from tqdm.notebook import tqdm

val_idx_list = [58, 66, 69, 53, 67, 92, 107, 108, 92, 93, 100, 104, 40, 52, 50, 79]

# number of cells sent to the model per call
N_CELLS_PER_CALL = 192

# Index 92 is listed twice; predicting it twice would only overwrite the same
# output file, so iterate over the unique indices in first-seen order.
for val_idx in tqdm(list(dict.fromkeys(val_idx_list))):

    val_adata_path = os.path.join(ROOT_PATH, "data", "cellariumgpt_validation", f"extract_{val_idx}__processed.h5ad")
    val_adata = sc.read_h5ad(val_adata_path)

    # revert to integer counts
    val_adata.X = val_adata.layers['counts'].copy()

    # predict
    metadata_prediction_dict = ctx.predict_metadata_chunked(val_adata, chunk_size=N_CELLS_PER_CALL)

    # save
    torch.save(
        metadata_prediction_dict,
        os.path.join(ROOT_PATH, "cellariumgpt_playground", "output", f"extract_{val_idx}__metadata_predictions.pt"))

#### Visualize

In [None]:
val_idx_list = [58, 66, 69, 53, 67, 92, 107, 108, 92, 93, 100, 104, 40, 52, 50, 79]
val_idx = val_idx_list[6]

# load anndata
adata_path = os.path.join(ROOT_PATH, "data", "cellariumgpt_validation", f"extract_{val_idx}__processed.h5ad")
adata = sc.read_h5ad(adata_path)

# print sample metadata
print(adata.obs['disease'].iloc[0])
print(adata.obs['tissue'].iloc[0])
print(adata.obs['development_stage'].iloc[0])

# load predictions
# The file is written by this notebook and is trusted. `weights_only` is passed
# explicitly because its default became True in PyTorch 2.6, and the stricter
# unpickler can reject non-tensor payloads (the predictions are presumably
# NumPy arrays -- they are indexed with NumPy index arrays below).
metadata_prediction_dict = torch.load(
    os.path.join(ROOT_PATH, "cellariumgpt_playground", "output", f"extract_{val_idx}__metadata_predictions.pt"),
    weights_only=False)

# make best top_k predictions
top_k = 5
cell_row_indices = np.arange(len(adata))
# A tuple, not a set: set iteration order for strings varies between Python
# sessions, which would make the order of the new obs columns non-reproducible.
for key in ("cell_type", "disease", "tissue", "development_stage"):
    key_labels = ctx.metadata_ontology_infos[key]['labels']
    key_probs = metadata_prediction_dict[key]
    # class indices per cell, most probable first
    top_k_sort_order = np.argsort(key_probs, axis=-1)[:, ::-1]
    for k in range(top_k):
        kth_class_indices = top_k_sort_order[:, k]
        adata.obs[f"cellariumgpt_{key}_{k}_label"] = [key_labels[i] for i in kth_class_indices]
        adata.obs[f"cellariumgpt_{key}_{k}_prob"] = key_probs[cell_row_indices, kth_class_indices]

In [None]:
# Per-cell disease scores: for each keyword, sum the predicted probability of
# every disease label whose name contains that keyword.
disease_labels = list(ctx.metadata_ontology_infos['disease']['labels'])
disease_probs = metadata_prediction_dict['disease']

disease_keywords = ['adenocarcinoma', 'COVID-19', 'cardiomyopathy']
for disease_keyword in disease_keywords:
    target_disease_indices = [
        idx for idx, disease_label in enumerate(disease_labels) if disease_keyword in disease_label]
    target_disease_labels = [disease_labels[idx] for idx in target_disease_indices]
    target_disease_score_n = disease_probs[:, target_disease_indices].sum(-1)
    adata.obs[f'cellariumgpt_{disease_keyword}_score'] = target_disease_score_n

# "normal" is an exact label: look its index up instead of assuming it is
# class 0, so a reordered label list fails loudly rather than silently
# scoring the wrong class.
target_disease_labels = 'normal'
target_disease_indices = [disease_labels.index(target_disease_labels)]
target_disease_score_n = disease_probs[:, target_disease_indices].sum(-1)
adata.obs['cellariumgpt_normal_score'] = target_disease_score_n

In [None]:
import plotly.express as px
import colorcet as cc
import pandas as pd


def generate_interactive_plotly_for_metadata(
        adata: sc.AnnData,
        top_k: int,
        metadata_key: str,
        use_continuous: bool = False,
        value_key: str = None,
        vmin: float = None,
        vmax: float = None,
        width: int = 800,
        height: int = 600,
        markersize: int = 4,
    ):
    """Build an interactive Plotly UMAP scatter with top-k predictions on hover.

    Coordinates are read from ``adata.obsm['X_umap']``. For every
    ``k < top_k`` the obs columns ``cellariumgpt_{metadata_key}_{k}_label``
    and ``cellariumgpt_{metadata_key}_{k}_prob`` must exist; they are rendered
    as ``"label: prob"`` lines (3 decimals) in each point's hover title.

    Args:
        adata: AnnData holding the UMAP embedding and the prediction columns.
        top_k: Number of ranked predictions to list in the hover text.
        metadata_key: Metadata field the predictions refer to (e.g. ``'cell_type'``).
        use_continuous: Color by a numeric obs column instead of the top-1 label.
            Only takes effect when ``value_key`` is also given; otherwise the
            discrete coloring is used.
        value_key: Obs column providing the values for continuous coloring.
        vmin: Lower bound of the continuous color range (``None`` = automatic).
        vmax: Upper bound of the continuous color range (``None`` = automatic).
        width: Figure width passed to the Plotly layout.
        height: Figure height passed to the Plotly layout.
        markersize: Marker size applied to all traces.

    Returns:
        The Plotly figure.
    """
    obs = adata.obs

    # Prepare data for plotting
    umap_df = pd.DataFrame({
        'UMAP_1': adata.obsm['X_umap'][:, 0],
        'UMAP_2': adata.obsm['X_umap'][:, 1],
    })

    # Hover text: format one whole column per rank, then stitch the ranks
    # together per cell. This avoids a scalar `.iloc` lookup for every
    # (cell, rank) pair.
    per_rank_lines = []
    for k in range(top_k):
        rank_labels = obs[f"cellariumgpt_{metadata_key}_{k}_label"].to_numpy()
        rank_probs = obs[f"cellariumgpt_{metadata_key}_{k}_prob"].to_numpy()
        per_rank_lines.append(
            [f"{label}: {prob:.3f}" for label, prob in zip(rank_labels, rank_probs)])
    if per_rank_lines:
        umap_df['hover_text'] = ["<br>".join(lines) for lines in zip(*per_rank_lines)]
    else:
        # top_k == 0: every point gets an empty hover title
        umap_df['hover_text'] = ""

    # Create scatter plot
    if use_continuous and value_key:
        umap_df['value'] = obs[value_key].values
        fig = px.scatter(
            umap_df,
            x='UMAP_1',
            y='UMAP_2',
            color='value',
            color_continuous_scale='RdBu_r',
            hover_name='hover_text',
            title='UMAP Scatter Plot',
            range_color=[vmin, vmax],
        )
    else:
        top_label_key = f"cellariumgpt_{metadata_key}_0_label"

        # One color per distinct top-1 label, cycling through the glasbey palette
        labels = obs[top_label_key].unique()
        colormap = {label: cc.glasbey[i % len(cc.glasbey)] for i, label in enumerate(labels)}

        umap_df['label'] = obs[top_label_key].values
        fig = px.scatter(
            umap_df,
            x='UMAP_1',
            y='UMAP_2',
            color='label',
            color_discrete_map=colormap,
            hover_name='hover_text',
        )

    # Update layout
    fig.update_layout(
        plot_bgcolor='white',
        xaxis=dict(title='UMAP_1', showgrid=False),
        yaxis=dict(title='UMAP_2', showgrid=False),
        width=width,
        height=height,
    )

    # Update marker size
    fig.update_traces(marker=dict(size=markersize))

    return fig

In [None]:
# Interactive UMAP colored by the top-1 predicted cell type; hover lists the top-5 predictions.
generate_interactive_plotly_for_metadata(
    adata,
    top_k=5,
    metadata_key='cell_type',
    width=700,
    height=500,
    markersize=3)


In [None]:
# Interactive UMAP colored by the summed COVID-19 score; hover still lists the
# top-5 cell-type predictions.
generate_interactive_plotly_for_metadata(
    adata,
    top_k=5,
    metadata_key='cell_type',
    use_continuous=True,
    value_key='cellariumgpt_COVID-19_score',
    width=500,
    height=500,
    markersize=3)


In [None]:
# UMAP colored by the top-1 predicted cell type.
sc.pl.umap(adata, color="cellariumgpt_cell_type_0_label")

In [None]:
# UMAP colored by the dataset's own `cell_type` annotation, for comparison.
sc.pl.umap(adata, color="cell_type")

In [None]:
# UMAP colored by the top-1 predicted disease; each point's opacity is its top-1 probability.
sc.pl.umap(adata, color="cellariumgpt_disease_0_label", alpha=adata.obs["cellariumgpt_disease_0_prob"].values)

In [None]:
# UMAP colored by the summed cardiomyopathy score, with the color scale capped at 0.3.
sc.pl.umap(adata, color="cellariumgpt_cardiomyopathy_score", vmax=0.3)

In [None]:
# Same plot as two cells above: top-1 predicted disease, opacity = top-1 probability.
sc.pl.umap(adata, color="cellariumgpt_disease_0_label", alpha=adata.obs["cellariumgpt_disease_0_prob"].values)

In [None]:
# UMAP colored by the dataset's own `disease` annotation.
sc.pl.umap(adata, color="disease")

In [None]:
# UMAP colored by the dataset's own `tissue` annotation.
sc.pl.umap(adata, color="tissue")

In [None]:
# UMAP colored by the top-1 predicted tissue; each point's opacity is its top-1 probability.
sc.pl.umap(adata, color="cellariumgpt_tissue_0_label", alpha=adata.obs["cellariumgpt_tissue_0_prob"].values)