# Test run of the algorithm

Yang Xu

Stephen Fleming

2022.12.01

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# import our module

from VariationalCPA_adv_attent_v3 import CellCap as Module

In [None]:
import scvi
import scanpy as sc
import anndata
import numpy as np
import pandas as pd
import umap
import os
import gc

import matplotlib.pyplot as plt
%matplotlib inline

In [None]:
# directory where the input .h5ad dataset is located
# (a relative path, so it is resolved against the current working directory)

DATA_DIR = 'data'

# Data

In [None]:
# load the annotated AnnData object and display its summary
h5ad_path = os.path.join(DATA_DIR, 'sc_levy_annotated_tiny.h5ad')
adata = anndata.read_h5ad(h5ad_path)
adata

In [None]:
# number of cells for each donor x perturbation combination
pd.crosstab(index=adata.obs['donor'], columns=adata.obs['perturbation'])

## Limit to subset of genes

In [None]:
# total UMI count per gene (column sums of X); asarray + squeeze turns a
# (1, n_genes) matrix result (e.g. from a sparse X) into a flat vector
gene_totals = np.asarray(adata.X.sum(axis=0))
adata.var['n_umi'] = gene_totals.squeeze()

In [None]:
# rank plot of per-gene total counts, largest first, on a log y-axis
ranked_totals = np.sort(adata.var['n_umi'])[::-1]
plt.semilogy(ranked_totals)
plt.xlabel('Gene (sorted)')
plt.ylabel('Total UMI counts')
plt.show()

In [None]:
# keep only the 10,000 most variable genes (Seurat v3 flavor), subsetting in place
sc.pp.highly_variable_genes(
    adata,
    flavor='seurat_v3',
    n_top_genes=10000,
    inplace=True,
    subset=True,
)

In [None]:
# display the AnnData summary after highly-variable-gene subsetting
adata

In [None]:
# same rank plot, now restricted to the highly variable genes
ranked_totals = np.sort(adata.var['n_umi'])[::-1]
plt.semilogy(ranked_totals)
plt.xlabel('Gene (sorted)')
plt.ylabel('Total UMI counts')
plt.show()

In [None]:
# drop genes with 2 or fewer total UMIs, materializing a new AnnData (not a view)
gene_mask = adata.var['n_umi'] > 2
adata = adata[:, gene_mask].copy()
adata

In [None]:
# rank plot after removing genes with n_umi <= 2
ranked_totals = np.sort(adata.var['n_umi'])[::-1]
plt.semilogy(ranked_totals)
plt.xlabel('Gene (sorted)')
plt.ylabel('Total UMI counts')
plt.show()

# Run tool

## Set up model

In [None]:
from typing import Tuple, List


def get_one_hot(series: pd.Series) -> Tuple[np.ndarray, List[str]]:
    """Given a pandas Series, return a one-hot encoding of its values.

    Columns follow the order of the Series' categories. Every category gets
    a column, including categories that no element of ``series`` takes, so
    the width can exceed ``series.nunique()``. Missing values (NaN) are
    encoded as an all-zero row.

    Args:
        series: Pandas Series, ideally of categorical dtype. A
            non-categorical Series is first converted with
            ``astype('category')`` (categories are then the sorted unique
            non-null values).

    Returns:
        Tuple ``(one_hot, categories)``:
            one_hot: Numpy array that is a one-hot-encoded matrix,
                shape is [len(series), number of categories]; its dtype is
                whatever ``pd.get_dummies`` produces (bool in pandas >= 2.0).
            categories: list of category labels, one per column of
                ``one_hot``, in column order.

    """
    if not isinstance(series.dtype, pd.CategoricalDtype):
        # `.cat` raises AttributeError on a non-categorical Series; converting
        # first also guarantees the returned names line up with the columns.
        series = series.astype('category')
    return pd.get_dummies(series).to_numpy(), series.cat.categories.tolist()

In [None]:
# example: one-hot encode the perturbation labels and display the result
get_one_hot(series=adata.obs['perturbation'])

In [None]:
# Yang's definitions, each derived from the 'perturbation' label:
#   drug    -> NaN for 'control', 'DMSO' and 'PBS' cells, otherwise the label
#   target  -> NaN for 'control' cells only, otherwise the label
#   control -> True for 'control' cells, False otherwise


def _drug_of(label):
    return np.nan if label in ['control', 'DMSO', 'PBS'] else label


def _target_of(label):
    return np.nan if label in ['control'] else label


def _is_control(label):
    return bool(label == 'control')


adata.obs['drug'] = adata.obs['perturbation'].apply(_drug_of).astype('category')
adata.obs['target'] = adata.obs['perturbation'].apply(_target_of).astype('category')
adata.obs['control'] = adata.obs['perturbation'].apply(_is_control).astype('category')

In [None]:
# one-hot encode each covariate and store it as adata.obsm['X_<key>']
for obs_key in ('drug', 'target', 'donor', 'control'):
    encoded, _category_names = get_one_hot(adata.obs[obs_key])
    adata.obsm['X_' + obs_key] = encoded

In [None]:
# keep a copy of the count matrix in the 'counts' layer
counts_matrix = adata.X.copy()
adata.layers['counts'] = counts_matrix

In [None]:
# display the AnnData summary, now including the drug/target/control obs
# columns, the X_* obsm one-hot matrices and the 'counts' layer
adata

In [None]:
# number of transcriptional response programs (passed to the model as n_prog)
n_prog = 5

In [None]:
# register the AnnData fields with the model class: the 'counts' layer, the
# control labels, the perturbation column and the X_* one-hot obsm matrices
Module.setup_anndata(
    adata,
    layer='counts',
    labels_key='control',
    pert_key='perturbation',
    donor_key='X_donor',
    target_key='X_target',
    cond_key='X_drug',
    cont_key='X_control',
)

In [None]:
latent_dim = 20
hidden_layers = 3

# Size each covariate from the width of the one-hot matrix stored in
# adata.obsm (registered via setup_anndata's *_key arguments, presumably what
# the model consumes) rather than from obs[...].nunique(): pd.get_dummies
# emits a column for every category of a categorical column, including
# categories with no cells, so nunique() can undercount the input width.
cpa = Module(
    adata,
    n_latent=latent_dim,
    n_layers=hidden_layers,
    n_drug=adata.obsm['X_drug'].shape[1],
    n_control=adata.obsm['X_control'].shape[1],
    n_target=adata.obsm['X_target'].shape[1],
    n_donor=adata.obsm['X_donor'].shape[1],
    n_prog=n_prog,
)

## Train

In [None]:
# fit the model: up to 100 epochs with minibatches of 512 cells
cpa.train(batch_size=512, max_epochs=100)

In [None]:
# training and validation loss curves recorded in the model history
loss_curves = {
    'Training set': cpa.history['train_loss_epoch']['train_loss_epoch'],
    'Validation set': cpa.history['validation_loss']['validation_loss'],
}
for curve_label, curve in loss_curves.items():
    plt.plot(curve, label=curve_label)
plt.ylabel('Loss')
plt.xlabel('Epoch')
plt.legend()
plt.show()

# Exploration

## Latent space

### Computation

In [None]:
# latent embedding from the trained model (stored below as the 'basal' embedding)
z = cpa.get_latent_embedding(adata)

In [None]:
# keep the latent embedding alongside the data for the plots below
adata.obsm['X_basal'] = z

In [None]:
# perturbation embedding plus a second output stored as 'X_attn' below
# (presumably attention values over the response programs -- confirm in CellCap)
z_pert, z_attn = cpa.get_pert_embedding(adata)

In [None]:
# attach the perturbation embedding and attention output to adata.obsm
for obsm_key, embedding in (('X_pert', z_pert), ('X_attn', z_attn)):
    adata.obsm[obsm_key] = embedding

In [None]:
# display the AnnData summary with the X_basal / X_pert / X_attn embeddings
adata

In [None]:
# report the dimensions of each stored embedding
for name in ('basal', 'pert', 'attn'):
    key = f'X_{name}'
    print(f'Shape of "{key}" is {adata.obsm[key].shape}')

### Metrics

Distribution of the individual vector component values (linear and log-scaled histograms).

In [None]:
# for each embedding: histogram of all component values, linear (left) and
# log-count (right)
for k in ['basal', 'pert', 'attn']:
    plt.figure(figsize=(12, 2))
    for panel, use_log in enumerate((False, True), start=1):
        plt.subplot(1, 2, panel)
        plt.hist(adata.obsm[f'X_{k}'].flatten(), bins=100, log=use_log)
        plt.xlabel('Vector component value')
        plt.ylabel('Number of components \n(N times Q total)')
        plt.title(f'Components of {k} vectors')
    plt.tight_layout()
    plt.show()

Distribution of vector lengths (Euclidean norms) for the basal and perturbation embeddings.

In [None]:
# Euclidean norm of each cell's basal and perturbation vector
basal_lengths = np.linalg.norm(adata.obsm['X_basal'], axis=1)
pert_lengths = np.linalg.norm(adata.obsm['X_pert'], axis=1)

# shared bin edges so the two figures are directly comparable
bins = np.linspace(0, max(basal_lengths.max(), pert_lengths.max()), 50)

for title, vector_lengths in (('basal', basal_lengths), ('pert', pert_lengths)):
    plt.figure(figsize=(12, 2))
    for panel, log_scale in enumerate((False, True), start=1):
        plt.subplot(1, 2, panel)
        plt.hist(vector_lengths, bins=bins, log=log_scale)
        plt.xlabel('Vector length')
        plt.ylabel('Number of vectors \n(N total)')
        plt.title(title)
    plt.tight_layout()
    plt.show()

### Visualization

In [None]:
# 2-D UMAP of each embedding, stored as adata.obsm['X_<k>_umap']
umap_params = dict(n_neighbors=10, min_dist=0.1, n_components=2, metric='euclidean')

for k in ['basal', 'pert']:
    print(f'Constructing UMAP for "{k}"...')
    reducer = umap.UMAP(**umap_params)
    adata.obsm[f'X_{k}_umap'] = reducer.fit_transform(adata.obsm[f'X_{k}'])

print('done.')

In [None]:
# per embedding: uncolored, colored by donor, colored by perturbation
panel_kwargs = [
    dict(na_color='k'),
    dict(color='donor'),
    dict(color='perturbation'),
]

for k in ['basal', 'pert']:
    plt.figure(figsize=(12, 3))
    for panel, extra in enumerate(panel_kwargs, start=1):
        plt.subplot(1, 3, panel)
        sc.pl.embedding(adata, basis=f'{k}_umap', show=False, ax=plt.gca(), **extra)
        plt.title(k)
    plt.tight_layout()
    plt.show()