# Noise prompting to explore the Cellarium GPT data manifold: gene sets

Stephen Fleming

2024.05.16

The idea here is to perturb the inputs and explore how the output embeddings move in response. The transformer stack is designed to map an input to an output somewhere on the data manifold; in that sense, it has learned the data manifold. Do the outputs move in coordinated ways, i.e., is there coordination among genes?

This is like applying small perturbations around fixed points on the manifold to see what the linear, on-manifold response looks like.

In [None]:
from cellarium.ml.downstream.cellarium_utils import get_pretrained_model_as_pipeline, harmonize_anndata_with_model
from cellarium.ml.downstream.gene_set_utils import GeneSetRecords
from cellarium.ml.downstream.noise_prompting import noise_prompt_gene_set_collection

In [None]:
import scanpy as sc
import anndata
import torch
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Render matplotlib figures inline in the notebook.
%matplotlib inline

# Global scanpy figure styling for all plots below: 14-pt fonts, and
# vector_friendly=True (per scanpy docs, scatter points are rasterized even
# when a figure is exported to a vector format such as PDF/SVG).
sc.set_figure_params(fontsize=14, vector_friendly=True)

# Instantiate pretrained CellariumGPT model

In [None]:
# Load the pretrained CellariumGPT pipeline. It goes on the GPU when torch
# can see one, and on the CPU otherwise.
pipeline = get_pretrained_model_as_pipeline(
    device=("cuda" if torch.cuda.is_available() else "cpu"),
)

# Data

In [None]:
# Load the dataset (machine-specific path) and harmonize it with the model
# via harmonize_anndata_with_model.
adata = harmonize_anndata_with_model(
    adata=anndata.read_h5ad("/home/sfleming/geneformer/100k.h5ad"),
    pipeline=pipeline,
)
# Snapshot X (presumably raw counts, TODO confirm) as the 'count' layer so it
# can be restored after later normalization overwrites X.
adata.layers['count'] = adata.X.copy()
adata

In [None]:
# Boolean mask over cells to focus on. A cell is kept when all three hold:
#   - assay is 10x 3' v3,
#   - total_mrna_umis is above 5000,
#   - suspension_type is 'cell'.
cell_logic = (
    adata.obs['assay'].eq("10x 3' v3")
    & adata.obs['total_mrna_umis'].gt(5000)
    & adata.obs['suspension_type'].eq('cell')
)

## Cell type

In [None]:
# The 20 most frequent cell-type labels among cells passing `cell_logic`.
adata.obs.loc[cell_logic, 'cell_type'].value_counts().head(20)

In [None]:
# Filtered cells whose cell-type label contains "monocyte" (case-sensitive).
# na=False makes a missing label an explicit non-match. This keeps the mask
# strictly boolean instead of relying on how pandas treats NaN inside `&`.
adata_subset = adata[
    cell_logic & adata.obs['cell_type'].str.contains('monocyte', na=False)
].copy()
adata_subset

## Choosing a single cell

We use only one exemplar cell for this analysis, so it should be chosen carefully.

In [None]:
# Select 2000 highly variable genes with the seurat_v3 flavor, computed on the
# 'count' layer (presumably raw counts).
sc.pp.highly_variable_genes(adata_subset, layer='count', n_top_genes=2000, flavor='seurat_v3')

# Reset X to the counts, then:
#   - normalize per-cell totals,
#   - log1p-transform,
#   - scale each gene (values clipped at 10),
#   - run PCA restricted to the HVGs.
# NOTE(review): newer scanpy deprecates use_highly_variable in favor of mask_var.
adata_subset.X = adata_subset.layers['count'].copy()
sc.pp.normalize_total(adata_subset)
sc.pp.log1p(adata_subset)
sc.pp.scale(adata_subset, max_value=10.0)
sc.tl.pca(adata_subset, use_highly_variable=True)
# Drop PC1 and store the remaining components as their own representation.
# NOTE(review): presumably PC1 tracks a nuisance axis (e.g. depth). Confirm.
adata_subset.obsm['X_pca_minus1'] = adata_subset.obsm['X_pca'][:, 1:].copy()

# 15-NN graph and UMAP built from the leading 3 columns of X_pca_minus1
# (i.e. original PCs 2-4).
sc.pp.neighbors(adata_subset, use_rep='X_pca_minus1', n_pcs=3, n_neighbors=15)
sc.tl.umap(adata_subset)

In [None]:
# UMAP of the monocyte subset, colored by `dataset_id`.
sc.pl.umap(adata_subset, color='dataset_id')

In [None]:
# UMAP colored by cell type, with a tick at every whole-number coordinate.
# This lets a small box of cells be read off the plot and selected in the
# next cell.
sc.pl.embedding(adata_subset, basis='umap', show=False, color='cell_type')
umap_xy = adata_subset.obsm['X_umap']
# Snap the tick ranges to integers. np.arange(min, max) would start at a
# fractional value (e.g. 2.37, 3.37, ...), so the ticks would not line up
# with the integer bounds used to pick a cell.
plt.xticks(np.arange(np.floor(umap_xy[:, 0].min()), np.ceil(umap_xy[:, 0].max()) + 1),
           rotation=90)
plt.yticks(np.arange(np.floor(umap_xy[:, 1].min()), np.ceil(umap_xy[:, 1].max()) + 1))
plt.show()

In [None]:
# Use the UMAP to pick a "typical looking" cell. Take the first cell (in
# adata_subset order) whose coordinates fall strictly inside a hand-chosen box.
umap_xy = adata_subset.obsm['X_umap']
x_range, y_range = (8, 9), (3, 4)  # open intervals in UMAP coordinates
in_box = (
    (umap_xy[:, 0] > x_range[0])
    & (umap_xy[:, 0] < x_range[1])
    & (umap_xy[:, 1] > y_range[0])
    & (umap_xy[:, 1] < y_range[1])
)
if not in_box.any():
    # The box is hard-coded, so it can come up empty if the embedding changes.
    # Fail with a clear message rather than an opaque IndexError.
    raise ValueError(
        f"No cells in UMAP box x in {x_range}, y in {y_range}; "
        "re-pick the box from the plot above."
    )
adata_cell = adata_subset[in_box][0].copy()
adata_cell

In [None]:
# Cell-type label of the chosen exemplar. Series.item() raises unless exactly
# one cell was selected.
adata_cell.obs['cell_type'].item()

# Gene sets

In [None]:
# Gene-set records loaded from a local JSON file (machine-specific path).
# Judging by the filename, this is MSigDB v2023.2, human.
msigdb = GeneSetRecords('/home/sfleming/geneformer/msigdb.v2023.2.Hs.json')
msigdb

In [None]:
# Number of gene sets in collection 'C2:CP:BIOCARTA'.
len(msigdb.get_gene_set_names(collection='C2:CP:BIOCARTA'))

In [None]:
# Number of gene sets in collection 'C2:CP:REACTOME'.
len(msigdb.get_gene_set_names(collection='C2:CP:REACTOME'))

In [None]:
# Number of gene sets in collection 'C5:GO:BP' (the collection used below).
len(msigdb.get_gene_set_names(collection='C5:GO:BP'))

# Noise prompting

## Genes to include in model

In [None]:
# Per-gene statistics over *all* cells in `adata` that share the exemplar's
# cell-type label. This is not restricted to cells passing `cell_logic`.
# adata.X is the matrix snapshotted as layers['count'] (presumably raw counts).
exemplar_cell_type = adata_cell.obs['cell_type'].item()
# Slice the expression matrix once and reuse it for both statistics.
same_type_x = adata[adata.obs['cell_type'] == exemplar_cell_type].X
adata_cell.var['mean'] = np.array(same_type_x.mean(axis=0)).squeeze()
adata_cell.var['frac_nonzero'] = np.array((same_type_x > 0).mean(axis=0)).squeeze()

# Flag genes detected (count > 0) in more than this fraction of same-type cells.
min_frac_nonzero = 0.25
adata_cell.var['gpt_include'] = adata_cell.var['frac_nonzero'] > min_frac_nonzero
adata_cell.var['gpt_include'].sum()

## Compute

In [None]:
# Reload edited Python modules (e.g. the cellarium.ml.downstream helpers
# imported at the top) automatically before each cell runs.
%load_ext autoreload
%autoreload 2

In [None]:
# adata_cell was sliced from adata_subset, whose X holds scaled log-normalized
# values. Restore the count matrix before passing the cell to the model.
adata_cell.X = adata_cell.layers['count'].copy()

In [None]:
# Gene-set collection to test. It is also used in the output CSV filename below.
collection = 'C5:GO:BP'

# Run noise prompting on the single exemplar cell for the gene sets in
# `collection`.
# NOTE(review): the argument comments below are inferred from parameter names.
# Confirm them against noise_prompt_gene_set_collection's docstring.
df = noise_prompt_gene_set_collection(
    adata_cell,
    pipeline=pipeline,
    msigdb=msigdb,
    collection=collection,
    fraction_of_set_to_perturb=0.5,  # presumably: share of each gene set that gets perturbed
    n_random_splits=10,  # presumably: random perturbed/unperturbed splits per gene set
    n_perturbations=100,  # presumably: noise draws per split
    perturbation_scale=1.0,
    min_gene_set_length=10,  # presumably: gene sets outside [10, 200] genes are skipped
    max_gene_set_length=200,
    n_pcs_in_output=5,  # presumably: output PCs tested for enrichment (the 'pc' column)
    gsea_n_perm=1000,  # presumably: permutations for the GSEA p-values
    seed=0,
    n_pcs=50,
    n_ics=10,
)

In [None]:
# Create the output directory (including parents; no error if it already
# exists). Doing this in Python rather than with a shell escape means the cell
# does not depend on a POSIX `mkdir` being available.
from pathlib import Path

Path('/home/sfleming/cellarium-ml/notebooks/outputs').mkdir(parents=True, exist_ok=True)

In [None]:
# Save the noise-prompting results for this collection.
# NOTE(review): the 'cd14+cd16-classicalmonocyte' label is hard-coded, not
# derived from adata_cell.obs['cell_type']. Update it if a different cell is
# chosen.
df.to_csv(f'/home/sfleming/cellarium-ml/notebooks/outputs/cd14+cd16-classicalmonocyte_{collection}.csv')

In [None]:
# Average the statistics over all rows that share a (gene set, output PC)
# pair. Presumably there is one row per random split.
df_grouped = (
    df.groupby(['gene_set_name', 'pc'])
    [['pval_perturbed', 'es_perturbed', 'pval_unperturbed', 'es_unperturbed', 'pc_frac_variance_explained']]
    .mean()
)
df_grouped

In [None]:
# Keep (gene set, PC) pairs whose mean pval_unperturbed is below 0.01, with
# the group keys moved back into columns.
# NOTE(review): no multiple-testing correction is applied across gene sets/PCs.
ddf = df_grouped.query('pval_unperturbed < 0.01').reset_index()
ddf

In [None]:
# Significant hits, largest es_unperturbed first.
ddf.sort_values('es_unperturbed', ascending=False)

In [None]:
# Significant rows whose gene-set name starts with 'random_controls'.
# These presumably serve as a null reference for the hits above.
ddf.loc[ddf['gene_set_name'].str.startswith('random_controls')]