# Run data from simulation 1 and make validation plots

Stephen Fleming

2023.06.27

A small test run on a tiny simulated dataset.

In [None]:
import scanpy as sc
import numpy as np
import pandas as pd
import scvi

import seaborn as sns
import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
from cellcap.scvi_module import CellCap
from cellcap.validation.plot import plot_adversarial_classifier_roc, plot_ard_parameters, plot_program_usage

In [None]:
# Global scanpy figure settings: scanpy-style defaults with vector-friendly output.
sc.set_figure_params(vector_friendly=True, scanpy=True)

In [None]:
import torch
# Display whether PyTorch can see a CUDA device.
torch.cuda.is_available()

# Data

In [None]:
# Load the simulated dataset and keep a copy of adata.X as the 'counts' layer
# (CellCap is set up with layer='counts' further down).
# NOTE(review): assumes adata.X holds raw counts (the model uses an 'nb'
# likelihood) — confirm for this simulation file.
adata = sc.read_h5ad('data/simulation_data1.h5ad')
adata.layers['counts'] = adata.X.copy()
adata

In [None]:
# Preview the first rows of the per-cell metadata.
adata.obs.head()

In [None]:
# obs column holding each cell's perturbation label; control cells are
# labelled 'Control' (matched case-insensitively when building X_target).
perturbation_key = 'Condition'

## Exploration

In [None]:
# Cell counts per perturbation x batch; margins=True adds row/column totals.
pd.crosstab(adata.obs[perturbation_key], adata.obs['batch'], dropna=False, margins=True)

In [None]:
# Cell counts per perturbation x simulated 'State'; margins=True adds totals.
pd.crosstab(adata.obs[perturbation_key], adata.obs['State'], dropna=False, margins=True)

# Model

## Setup

In [None]:
# limit to control cells for testing purposes
# NOTE(review): after this filter every cell is labelled 'Control', so the
# X_target one-hot matrix built in the next cell has zero columns (no
# perturbations are left to model).

adata = adata[adata.obs[perturbation_key] == 'Control'].copy()
adata

In [None]:
# for now we are using adata.obsm slots
# Build one-hot design matrices in adata.obsm:
#   'X_donor'  - one column per value of obs['State']
#   'X_target' - one column per non-control perturbation; control cells (any
#                capitalisation) are mapped to NaN and get an all-zero row.

# Require at least one control cell (case-insensitive match).
assert 'control' in adata.obs[perturbation_key].str.lower().values, \
    f'adata.obs["{perturbation_key}"] does not contain "control" or "Control" '

adata.obsm['X_donor'] = pd.get_dummies(adata.obs['State']).to_numpy().astype(float)
adata.obsm['X_target'] = pd.get_dummies(
    (adata.obs[perturbation_key]
     .str.lower()
     # controls -> NaN, so get_dummies creates no column for them
     .replace(to_replace='control', value=np.nan))
).to_numpy().astype(float)

adata

In [None]:
# Number of cells in each (non-control) perturbation column of X_target.
adata.obsm['X_target'].sum(axis=0)

In [None]:
# Number of cells per perturbation label.
adata.obs[perturbation_key].value_counts()

In [None]:
# Register the 'counts' layer and the one-hot perturbation / donor matrices
# stored in adata.obsm with CellCap.
CellCap.setup_anndata(
    adata,
    donor_key='X_donor',
    target_key='X_target',
    layer='counts',
)

In [None]:
# Number of response programs the model learns.
n_response_programs = 15

# Seed before constructing the model so that parameter initialization is
# reproducible too (the seed is set again right before training).
scvi.settings.seed = 0

cellcap = CellCap(
    adata,
    n_latent=20,
    n_layers=3,
    # Take the perturbation / donor counts from the widths of the one-hot
    # matrices registered with setup_anndata, so the model dimensions always
    # match the data. Counting with .nunique() can disagree, e.g. when control
    # labels differ only in case, or when a categorical column has unused
    # levels (get_dummies still emits a column for those).
    n_drug=adata.obsm['X_target'].shape[1],
    n_donor=adata.obsm['X_donor'].shape[1],
    gene_likelihood='nb',
    n_prog=n_response_programs,
)

## Train

In [None]:
# (Re)seed scvi-tools' random number generators right before training.
scvi.settings.seed = 0

In [None]:
import torch

# Train with early stopping. Request a GPU only when PyTorch can see one:
# hard-coding use_gpu=True breaks this notebook on CPU-only machines.
cellcap.train(
    max_epochs=1000,
    batch_size=256,
    use_gpu=torch.cuda.is_available(),
    early_stopping=True,
)

In [None]:
# List the metrics recorded in the training history.
cellcap.history.keys()

In [None]:
# Plot every logged 'b_q_*_train' history series (Laplace scale parameters)
# on one axis, with the legend placed outside the plot on the right.
pd.concat([
    cellcap.history[name]
    for name in cellcap.history
    if name.startswith('b_q_') and name.endswith('_train')
]).plot()
plt.ylim([0, 1])
plt.ylabel('Laplace scale parameter')
plt.legend(loc='center left', bbox_to_anchor=(1, 0.5))
plt.show()

In [None]:
# Plot whichever of the training / validation / adversarial-validation loss
# curves were logged. The y-axis is capped at 1.25x the 90th percentile of the
# training loss, presumably so large early-epoch values do not dominate.
pd.concat([
    cellcap.history[name]
    for name in ('train_loss_epoch', 'validation_loss', 'adv_loss_validation')
    if name in cellcap.history
]).plot()
plt.ylabel('loss')
plt.ylim([0, 1.25 * np.quantile(cellcap.history['train_loss_epoch'], q=0.9)])
plt.show()

# Output exploration

Get the latent embedding and visualize it.

In [None]:
import gc
# Explicitly run Python's garbage collector (the cell displays the number of
# unreachable objects it found).
gc.collect()

## Program usage

In [None]:
# How much does each perturbation use each response program?
# plot_program_usage draws the figure; its return value is kept as df_usage.
df_usage = plot_program_usage(
    adata=adata,
    perturbation_key=perturbation_key,
    cellcap=cellcap,
)

In [None]:
# Which response programs does ARD "turn off"?
# plot_ard_parameters draws the figure; its return value is kept as df_ard.
df_ard = plot_ard_parameters(
    adata=adata,
    perturbation_key=perturbation_key,
    cellcap=cellcap,
)

In [None]:
# Display the value returned by plot_ard_parameters.
df_ard

## Basal cell state

In [None]:
# Basal-state latent embedding with one row per cell (it is stored in
# adata.obsm next, which requires n_obs rows); display its shape.
z = cellcap.get_latent_embedding(adata)
z.shape

In [None]:
# Store the embedding so scanpy can use it via use_rep='X_basal'.
adata.obsm['X_basal'] = z

### Classifier accuracy

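We want a classifier that predicts the perturbation from the basal state to perform very poorly: a good basal state should carry little information about which perturbation a cell received.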

In [None]:
# ROC plot for the adversarial perturbation classifier.
plot_adversarial_classifier_roc(
    perturbation_key=perturbation_key,
    adata=adata,
)

### UMAP

In [None]:
# Build a 15-nearest-neighbour graph on the basal embedding using cosine
# distance, then compute a UMAP layout from it.
sc.pp.neighbors(
    adata,
    use_rep='X_basal',
    metric='cosine',
    n_neighbors=15,
    method='umap',
    random_state=0,
)
sc.tl.umap(adata, min_dist=0.15)

In [None]:
# UMAP of the basal embedding coloured by perturbation label; uses
# perturbation_key rather than a hard-coded column name, like the other cells.
sc.pl.umap(adata, color=perturbation_key, title='', legend_fontsize=7.5)

In [None]:
# UMAP of the basal embedding coloured by the 'Pseudotime' obs column.
sc.pl.umap(adata, color='Pseudotime', legend_fontsize=7.5)

## Attention maps