# Jacobian of the Cellarium GPT data manifold

Stephen Fleming

2024.06.25

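The idea here is to perturb the inputs and explore how the output embeddings move in response. The transformer stack is designed to take an input and produce an output somewhere on the data manifold, so the Jacobian (the derivatives of the outputs with respect to the inputs) describes how that output moves locally.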

This is another way to approach the question that noise prompting addresses: what does the model know about the data manifold?

In [None]:
from cellarium.ml.downstream.cellarium_utils import get_pretrained_model_as_pipeline, harmonize_anndata_with_model
from cellarium.ml.core import CellariumPipeline
from cellarium.ml.downstream.gene_set_utils import GeneSetRecords
from cellarium.ml.downstream.noise_prompting import compute_jacobian

In [None]:
import scanpy as sc
import anndata
import torch
import umap
from sklearn.decomposition import PCA
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
# Render matplotlib figures inline in the notebook.
%matplotlib inline

# Global scanpy figure defaults: 14pt fonts; per the scanpy docs,
# `vector_friendly=True` rasterizes scatter plots when exporting to vector formats.
sc.set_figure_params(fontsize=14, vector_friendly=True)

In [None]:
# Reload edited Python modules (e.g. cellarium.ml) before each cell executes.
%load_ext autoreload
%autoreload 2

# Instantiate pretrained CellariumGPT model

In [None]:
# Load the pretrained CellariumGPT pipeline, placing it on the GPU when one is
# available and on the CPU otherwise.
pipeline = get_pretrained_model_as_pipeline(
    device=("cuda" if torch.cuda.is_available() else "cpu"),
)

# Data

In [None]:
# Load the AnnData for one pre-selected cell (a later cell calls `.item()` on
# its cell type, which requires exactly one observation). To explore a
# different cell, point `cell_file` at one of the commented-out alternatives.
data_dir = "/home/sfleming/cellarium-ml/notebooks/cell_selection"
# cell_file = "random_cells/cell__capillary_endothelial_cell.h5ad"
# cell_file = "random_cells/cell__inhibitory_interneuron.h5ad"
cell_file = "random_cells/cell__erythrocyte.h5ad"
# cell_file = "manually_chosen_cardiac_muscle_cell.h5ad"
adata = anndata.read_h5ad(f"{data_dir}/{cell_file}")
adata

In [None]:
# Harmonize the loaded AnnData with the pretrained model, then carry over the
# gene-inclusion flag and the 'count' layer used by the Jacobian computation.
adata_cell = harmonize_anndata_with_model(adata, pipeline)
# Assigning a pandas Series aligns on the var index (genes missing from
# `adata.var` would become NaN here).
adata_cell.var['gpt_include'] = adata.var['gpt_include']
# NOTE(review): this layer copy is positional, not index-aligned. It assumes
# harmonize_anndata_with_model keeps the same genes in the same order as
# `adata`; confirm, otherwise counts would be paired with the wrong genes.
adata_cell.layers['count'] = adata.layers['count']
adata_cell

In [None]:
# Display the annotated cell type; `.item()` raises unless there is exactly
# one cell (one obs row).
adata_cell.obs['cell_type'].item()

In [None]:
# Number of genes flagged by `gpt_include` (the genes passed to the Jacobian).
adata_cell.var['gpt_include'].sum()

# Jacobian

In [None]:
# Same count as in the Data section, repeated here as a reminder: the
# Jacobian's rows are later mapped one-to-one onto these `gpt_include` genes.
adata_cell.var['gpt_include'].sum()

In [None]:
# Compute the Jacobian for this cell with `compute_jacobian`, restricted to the
# genes flagged by `gpt_include`, reading values from the 'count' layer and
# labelling genes by `gene_name`. `summarize='mean'` is passed straight
# through; its exact meaning is defined by compute_jacobian.
jacobian_df = compute_jacobian(
    adata_cell,
    pipeline=pipeline,
    layer='count',
    var_key_include_genes='gpt_include',
    var_key_gene_name='gene_name',
    summarize='mean',
)

In [None]:
# Display the Jacobian; both rows and columns are looked up by gene name in
# the cells below.
jacobian_df

## Explore the Jacobian

In [None]:
# Per-gene mean of the 'count' layer across cells, flattened to one value per
# gene, then show the ten most highly expressed genes.
adata_cell.var['expr'] = np.asarray(
    adata_cell.layers['count'].mean(axis=0)
).squeeze()
adata_cell.var.sort_values(by='expr', ascending=False).head(10)

In [None]:
# Plot the main diagonal of the Jacobian (the (i, i) entry for each gene).
plt.plot(np.diag(jacobian_df.to_numpy()))

In [None]:
# Symmetric color limits at the 99th percentile of |J|, so a handful of
# extreme entries do not wash out the rest of the heatmap.
colorbar_max = np.percentile(jacobian_df.abs(), 99)

plt.figure(figsize=(20, 20))
plt.imshow(
    jacobian_df.to_numpy(),
    aspect='auto',
    cmap='PiYG',
    vmin=-colorbar_max,
    vmax=colorbar_max,
)
plt.colorbar()
plt.grid(False)
plt.show()

In [None]:
# Sum of |J| along each Jacobian row, stored per gene; genes not included in
# the Jacobian keep the default of 0.
# NOTE(review): the values are written positionally, which assumes the
# Jacobian's row order matches the order of the `gpt_include` genes in var.
adata_cell.var['jac_abs_sum'] = 0.0
adata_cell.var.loc[adata_cell.var['gpt_include'], 'jac_abs_sum'] = (
    jacobian_df.abs().sum(axis=1).to_numpy()
)

In [None]:
# Most highly expressed genes again, now including the `jac_abs_sum` column.
adata_cell.var.sort_values(by='expr', ascending=False).head(10)

In [None]:
# Twenty genes with the largest summed |Jacobian| row entries.
adata_cell.var.sort_values(by='jac_abs_sum', ascending=False).head(20)

## Gene lookups

In [None]:
gene_of_interest = 'HBB'

# Ten largest-magnitude entries in this gene's Jacobian column.
(
    jacobian_df[gene_of_interest]
    .abs()
    .sort_values(ascending=False)
    .head(10)
)

In [None]:
gene_of_interest = 'MT-CO1'

# Ten largest-magnitude entries in this gene's Jacobian column.
(
    jacobian_df[gene_of_interest]
    .abs()
    .sort_values(ascending=False)
    .head(10)
)

In [None]:
# Diagonal entry of the Jacobian for the gene of interest (row and column
# both labelled by that gene).
jacobian_df.loc[gene_of_interest, gene_of_interest]

## Manifold dimension

Rank of the Jacobian

In [None]:
# Numerical rank of the Jacobian (torch's default tolerance).
# Run on the GPU when available, but fall back to the CPU so this cell does not
# crash on machines without CUDA (mirroring how the pipeline's device is chosen).
# Converted to a Python int so it can be passed as `q` to torch.pca_lowrank.
rank = int(
    torch.linalg.matrix_rank(
        torch.from_numpy(jacobian_df.to_numpy()).to(
            'cuda' if torch.cuda.is_available() else 'cpu'
        )
    )
)
rank

In [None]:
# Compare the numerical rank above with the Jacobian's dimensions.
jacobian_df.shape

Hmm, the rank seems surprisingly large compared with the Jacobian's dimensions.

## Eigenvectors

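The eigenvectors of the Jacobian are what we are getting at when we do noise prompting. Below, they are approximated by the principal components of the Jacobian, computed with `torch.pca_lowrank`.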

In [None]:
# Principal components of the Jacobian. pca_lowrank treats the rows of its
# input as observations, so passing J^T makes each column of J an observation;
# it centers the data by default and returns U, S, V with A ~ U diag(S) V^T.
# Use the GPU when available, otherwise the CPU (a bare `.cuda()` would crash
# on machines without CUDA).
jacobian_t = torch.from_numpy(jacobian_df.transpose().to_numpy()).to(
    'cuda' if torch.cuda.is_available() else 'cpu'
)
# `q` must be a Python int; `rank` may be a 0-d tensor from matrix_rank.
U, S, V = torch.pca_lowrank(jacobian_t, q=int(rank))
# Variance along each principal direction: S^2 / (n_observations - 1), where
# the observations are the rows of the matrix given to pca_lowrank.
eigenvalues = S.square() / (jacobian_t.shape[0] - 1)

In [None]:
# Fraction of the total variance captured by each principal component, as a
# NumPy array on the host.
frac_variance_explained = eigenvalues.div(eigenvalues.sum()).cpu().numpy()

In [None]:
# Spectrum of the Jacobian's principal components: explained-variance ratio
# per PC on linear and log scales, plus its cumulative sum.
plt.figure(figsize=(12, 4))

plt.subplot(1, 3, 1)
plt.plot(frac_variance_explained)
plt.title('PCA of Jacobian')
plt.ylabel('Explained variance ratio')
plt.xlabel('PC')

plt.subplot(1, 3, 2)
plt.plot(frac_variance_explained)
plt.yscale('log')
plt.title('log')
plt.xlabel('PC')

plt.subplot(1, 3, 3)
plt.plot(np.cumsum(frac_variance_explained))
plt.title('cumulative')
plt.ylabel('Cumulative explained variance ratio')
plt.xlabel('PC')

plt.tight_layout()
plt.show()

In [None]:
# Shape of V from torch.pca_lowrank(J^T, q=rank): per the pca_lowrank docs,
# (number of columns of J^T, i.e. rows of J, by q).
V.shape

In [None]:
# Copy the first five principal directions (columns of V) into
# adata_cell.var for the genes included in the Jacobian; all other genes get 0.
# Both the signed loading and its absolute value are stored, so genes can be
# ranked by magnitude.
# NOTE: despite the column names, these are principal-direction loadings, not
# eigenvalues; the names are kept because later cells reference them.
# NOTE(review): values are written positionally, assuming V's rows follow the
# order of the `gpt_include` genes in var.
V_np = V.cpu().numpy()  # single device-to-host copy instead of one per assignment
include_mask = adata_cell.var['gpt_include']
for i in range(5):
    loading = V_np[:, i]
    adata_cell.var[f'jac_eigenvalue_{i}'] = 0.0
    adata_cell.var[f'jac_eigenvalue_abs_{i}'] = 0.0
    adata_cell.var.loc[include_mask, f'jac_eigenvalue_{i}'] = loading
    adata_cell.var.loc[include_mask, f'jac_eigenvalue_abs_{i}'] = np.abs(loading)

In [None]:
# Ten genes with the largest-magnitude loading on the first principal
# direction (column 0 of V).
adata_cell.var.sort_values(by='jac_eigenvalue_abs_0', ascending=False).head(10)

In [None]:
def show_programs(df, column_names, *, n_top=10):
    """Print the top-ranked genes for each program column.

    For every column in ``column_names``, sort ``df`` by that column in
    descending order and print the first ``n_top`` rows. Each printed table
    shows ``gene_name``, ``expr``, the ranking column, and its signed
    counterpart (the same name with ``'abs_'`` removed, e.g.
    ``jac_eigenvalue_abs_0`` -> ``jac_eigenvalue_0``).

    Args:
        df: Per-gene table containing ``'gene_name'``, ``'expr'``, every column
            in ``column_names``, and each one's signed counterpart.
        column_names: Names of the columns to rank by, in print order.
        n_top: Number of rows printed per column (default 10).

    Returns:
        The input ``df``, unmodified.
    """
    for column in column_names:
        # Deduplicate while keeping order: if ``column`` has no 'abs_' part,
        # its "signed counterpart" is itself, and selecting the same label
        # twice would make sort_values fail on the duplicated column.
        shown = list(dict.fromkeys(
            ['gene_name', 'expr', column, column.replace('abs_', '')]
        ))
        top = (
            df[shown]
            .sort_values(by=column, ascending=False)
            .head(n_top)
        )
        print(f'\n{column}')
        print(top)

    return df

In [None]:
# Print the top genes for each of the five leading principal directions.
df = show_programs(
    adata_cell.var,
    column_names=['jac_eigenvalue_abs_' + str(i) for i in range(5)],
)