# Simulations notebook

## Simulated data

### Poisson-logNormal $Z_i \sim N(\mu_i, \sigma_i^2)$

Set the latent module activities $Z_k$, $k = 1, \dots, K$ (here $K = 20$), with $Z_k \sim N(\mu_k, \sigma^2)$, where $\mu_k \sim U(0, 1)$ and $\sigma = 0.1$.

In [4]:
import numpy as np
import pandas as pd
import scipy as sp
import matplotlib.pyplot as plt
import scanpy as sc

In [5]:
# Record the NumPy version this notebook was run with.
print("{}".format(np.__version__))

1.23.5


In [6]:
# Global simulation sizes.
K = 20        # number of latent modules
n = 20000     # number of cells
N = G = 1000  # number of genes
# Module activities: Z_k ~ N(mu_k, 0.1^2) with mu_k ~ U(0, 1), seeded for reproducibility.
np.random.seed(123)
mu = np.random.uniform(low=0, high=1, size=K)
Z = np.random.normal(loc=mu, scale=0.1)

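Linear model for each cell $i$: $X_i = BZ + \epsilon_i$. Here $Z_1, \dots, Z_K$ are the $K$ "modules" and $B \in \{0,1\}^{G \times K}$ is the binary loading matrix, so each gene $g$ satisfies

$X_{ig} = \sum_k B_{gk} Z_k + \epsilon_{ig}$, with $\epsilon_i \sim N(0, \sigma_0^2 I_{G})$.

The same $Z$ is shared by all cells, so cells differ only through the noise $\epsilon_i$ and the Poisson sampling. For count data we apply an elementwise nonlinearity $f$ and a Poisson log-normal observation model: $X_i \sim \mathcal{P}\left(e^{f(BZ) + \epsilon_i}\right)$. In this model $\sigma_0$ is 30% of the standard deviation of $f(BZ)$.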

Here $N = G = 1000$

In [7]:
# N: number of genes; p: probability of a 0 entry in the randbin matrices below.
N, p = 1000, 0.3
def randbin(m, n, p, seed=123):
    """Return an ``(m, n)`` random integer matrix of 0s and 1s.

    ``p`` is the probability of a **0** entry; a 1 is drawn with probability
    ``1 - p``. NumPy's global RNG is reseeded with ``seed`` before sampling,
    so the draw is reproducible and every later ``np.random`` call continues
    from this seed.
    """
    probabilities = [p, 1 - p]  # P(0), P(1)
    np.random.seed(seed)
    return np.random.choice([0, 1], p=probabilities, size=(m, n))

def non_linear_transform(X, f=None):
    """Apply the named elementwise nonlinearity to ``X``.

    Recognised names:
      - ``'exp'``  -> ``exp(-X)`` (note the sign flip)
      - ``'tanh'`` -> ``tanh(X)``
      - ``'roll'`` -> ``X * cos(X)``
    Any other value of ``f`` (including the default ``None``) falls back to
    ``cos(X)``.
    """
    transforms = (
        ('exp', lambda v: np.exp(-v)),
        ('tanh', np.tanh),
        ('roll', lambda v: v * np.cos(v)),
    )
    for name, fn in transforms:
        if f == name:
            return fn(X)
    return np.cos(X)
f = 'roll'  # nonlinearity used for the baseline data
B = randbin(K, N, p).T  # (genes x modules) binary loading matrix
# NOTE(review): sigma0 here is 0.3 * variance (poisson_lognormal uses 0.3 * std),
# and eps is not used by any visible cell because poisson_lognormal adds its own
# noise. The draw still advances the global RNG that the unseeded
# poisson_lognormal calls below continue from.
sigma0 = 0.3 * np.var(non_linear_transform(B @ Z, f=f))
eps = np.random.normal(loc=0, scale=sigma0, size=N)
X_nl = non_linear_transform(B @ Z, f=f)

In [8]:
def poisson_lognormal(Z, seed=None):
    """Draw Poisson log-normal counts around the log-rate ``Z``.

    Each entry of ``Z`` gets independent Gaussian noise whose standard
    deviation is 30% of the standard deviation of ``Z``; the exponentiated
    result is used as the Poisson rate.

    Parameters
    ----------
    Z : array-like
        Log-scale rates (in this notebook, e.g. ``X_nl``). Any shape is
        accepted; the noise has the same shape as ``Z``.
    seed : int, optional
        If given, reseeds NumPy's *global* RNG before sampling, so results are
        reproducible but the global random stream is reset as a side effect.

    Returns
    -------
    numpy.ndarray
        Integer Poisson counts with the same shape as ``Z``.
    """
    if seed is not None:
        np.random.seed(seed)
    Z = np.asarray(Z)
    sigma0 = np.sqrt(np.var(Z)) * 0.3
    # Noise matches Z's full shape; for 1-D Z this is the same draw as
    # size=Z.shape[0], and it no longer breaks broadcasting for 2-D input.
    eps = np.random.normal(0, sigma0, size=Z.shape)
    lam = np.exp(Z + eps)
    return np.random.poisson(lam)

# One independent Poisson log-normal draw per cell (n cells), all around X_nl.
X = np.stack([poisson_lognormal(X_nl) for _ in range(n)])

In [None]:
import umap
from sklearn.preprocessing import StandardScaler

In [None]:
# UMAP embedding of the per-gene standardised baseline counts.
scaled_X = StandardScaler().fit_transform(X)
reducer = umap.UMAP()
embedding = reducer.fit_transform(scaled_X)
embedding.shape

In [None]:
np.shape(B)  # (genes, modules)

In [None]:
keys = [f'module_{i}' for i in range(K)]
# Map each module name to its column of B (its 0/1 loading over genes).
B_dict = {key: column for key, column in zip(keys, B.T)}

In [None]:
# Scatter of the two UMAP coordinates of the baseline cells.
plt.scatter(*embedding[:, :2].T, alpha=0.5)
plt.gca().set_aspect('equal', 'datalim')
plt.title('UMAP projection of synthetic data', fontsize=12);

#### add interventions

Add an intervention vector $d \in \{0, 1\}^G$ to the linear predictor: $X_i(d) = \text{Poisson-lognormal}\left(\operatorname{diag}(d)\, B Z + \epsilon_i\right)$, where $d_g = 0$ if gene $g$ is blocked (knocked out) and $d_g = 1$ otherwise.

Recall $B \in \mathbb{R}^{N\times K}$ with $N$ genes and $K$ modules. So intervention blocking gene $g$ should set row $g$ to zero in $B$ ($0_{1 \times K}$ vector).

In [None]:
X.max(), X.min(), X.mean()  # range and mean of the baseline counts

In [None]:
import seaborn as sns
sns.heatmap(np.log(X + 1))  # log(1 + count) to compress the count range

Define knocked out genes.

In [None]:
# Knock out a handful of genes: d_g = 0 blocks gene g, d_g = 1 leaves it active.
d = np.ones(N)
d[[1, 3, 10, 15, 20]] = 0  # knocked-out genes: 1, 3, 10, 15 and 20
# Zero the rows of B that belong to knocked-out genes (i.e. diag(d) @ B).
B_d = np.multiply(B.T, d).T

In [None]:
np.shape(B_d @ Z)  # one linear predictor per gene

In [None]:
# Knock-out counts, generated exactly like the baseline X (same nonlinearity f,
# then a Poisson log-normal draw per cell), so that X and X_d differ only by
# the knocked-out genes. The log-rate is the same for every cell, so compute it once.
X_nl_d = non_linear_transform(B_d @ Z, f=f)
X_d = np.array([poisson_lognormal(X_nl_d) for _ in range(n)])

In [None]:
# UMAP embedding of the per-gene standardised knock-out counts.
scaled_X_d = StandardScaler().fit_transform(X_d)
reducer = umap.UMAP()
embedding_d = reducer.fit_transform(scaled_X_d)
embedding_d.shape

In [None]:
# Left: log counts with knock-outs; right: first 50 gene rows of the masked B.
# The title is built from d so it always lists every knocked-out gene.
knocked_out = np.flatnonzero(d == 0)
plt.figure(figsize=(10, 5))
plt.subplot(222).set_title('B matrix with knock outs ' + ', '.join(map(str, knocked_out)))
plt.imshow(B_d[:50, :].T)
plt.subplot(221).set_title('X with knocked out genes')
sns.heatmap(np.log(1 + X_d))


In [None]:
(X[:, 3], X_d[:, 3])  # gene 3 is one of the genes zeroed in d

In [None]:
# Counts of two genes: baseline X (red) vs knock-out X_d (blue).
g1, g2 = 0, 2
plt.figure(figsize=(8, 5))
for _data, _color, _alpha, _label in ((X, 'red', 0.3, 'observation baseline'),
                                      (X_d, 'blue', 0.5, 'with knock outs')):
    plt.scatter(_data[:, g1], _data[:, g2], color=_color, alpha=_alpha, label=_label)
plt.xlabel(f'gene {g1}')
plt.ylabel(f'gene {g2}')
plt.legend()
plt.title('observation space between two genes for unperturbed versus with knock outs')

In [None]:
# Counts of two genes: baseline X (red) vs knock-out X_d (blue).
g1, g2 = 43, 26
plt.figure(figsize=(8, 5))
for _data, _color, _alpha, _label in ((X, 'red', 0.3, 'observation baseline'),
                                      (X_d, 'blue', 0.5, 'with knock outs')):
    plt.scatter(_data[:, g1], _data[:, g2], color=_color, alpha=_alpha, label=_label)
plt.xlabel(f'gene {g1}')
plt.ylabel(f'gene {g2}')
plt.legend()
plt.title('observation space between two genes for unperturbed versus with knock outs')

Define the perturbations. Assume 10 different single perturbations, each chosen at random.
Then add 10 combined perturbations, each built from a pair of the single perturbations, for 20 perturbations in total.

In [None]:
# 20 perturbation masks over the N genes (1 = gene untouched, 0 = knocked out).
# Rows 0-9 are random single perturbations; rows 10-19 are filled below with
# pairwise combinations of them.
M_pert = np.zeros((20, N), dtype=int)
M_pert[:10, :] = randbin(10, N, p, seed=999)

M_pert.shape

In [None]:
# pert_10..pert_14 = pert_i + pert_{i+1}. A combination knocks out a gene (0)
# when either perturbation does, so the masks combine with AND (OR would keep
# only the genes knocked out by both perturbations).
for i in range(5):
    M_pert[10 + i, :] = np.bitwise_and(M_pert[i, :], M_pert[i + 1, :])

In [None]:
# pert_15..pert_19 = pert_i + pert_{9-i}. A combination knocks out a gene (0)
# when either perturbation does, so the masks combine with AND (OR would keep
# only the genes knocked out by both perturbations).
for i in range(5):
    M_pert[15 + i, :] = np.bitwise_and(M_pert[i, :], M_pert[9 - i, :])

Note that perturbations 0 to 9 are single perturbations. Perturbations 10 to 14 combine adjacent perturbations: pert_0+pert_1, pert_1+pert_2, ..., pert_4+pert_5. Perturbations 15 to 19 combine pert_0+pert_9, pert_1+pert_8, pert_2+pert_7, pert_3+pert_6 and pert_4+pert_5. Consequently pert_14 and pert_19 are the same combination (pert_4+pert_5).

Extend this to 20,000 observations: assign one of the 20 perturbations to each cell, sampled with the non-uniform frequencies below.


In [None]:
np.random.seed(999)
# Sampling frequency of each perturbation ID 0..19 (the weights sum to 1).
pert_probs = [0.2, 0.05, 0.05, 0.12, 0.08, 0.025, 0.0125, 0.0125, 0.075, 0.025,
              0.03, 0.04, 0.02, 0.01, 0.005, 0.005, 0.04, 0.1, 0.03, 0.07]
pert_keys = np.random.choice(np.arange(20), 20000, p=pert_probs)
fig, ax = plt.subplots(figsize=(7, 3))
ax.hist(pert_keys, bins=20)
ax.set_title("Frequency of perturbation types")
ax.set_xlabel("Perturbation ID 0 to 19")
ax.set_ylabel('Frequency')

In [None]:
# One mask per cell: row i of D_pert is M_pert[pert_keys[i]] (vectorised gather).
D_pert = M_pert[pert_keys, :]

In [None]:
sns.heatmap(data=D_pert)  # cells x genes perturbation masks

In [None]:
from tqdm import tqdm

In [None]:
# Simulate one cell per row of D_pert.
# Masking gene g (d_g = 0) zeroes row g of B. That equals scaling entry g of
# B @ Z by d_g, so the product is computed once outside the loop.
# The same nonlinearity f as the baseline X is applied before the Poisson
# log-normal draw, so perturbed and baseline data share one generative model.
BZ_base = B @ Z
X_d = np.empty(D_pert.shape)
for i in tqdm(range(D_pert.shape[0])):
    d = D_pert[i, :]
    X_d[i, :] = poisson_lognormal(non_linear_transform(d * BZ_base, f=f))

In [None]:
# Peek at the first 20 cells x 20 genes of the simulated counts.
plt.figure(figsize=(3, 2))
sns.heatmap(data=X_d[:20, :20])

#### Transform to AnnData and pass through CPA

In [None]:
import anndata as ad

In [None]:
# `import scipy as sp` alone is not guaranteed to load the sparse submodule,
# so import it explicitly before using sp.sparse.
import scipy.sparse
X_d = sp.sparse.csr_matrix(X_d)

In [None]:
adata = ad.AnnData(X=X_d)  # cells x genes count matrix

In [None]:
# Readable cell / gene identifiers.
adata.obs_names = [f"Cell_{i}" for i in range(adata.n_obs)]
adata.var_names = [f"Gene_{j}" for j in range(adata.n_vars)]
print(adata.obs_names[:10])

In [None]:
# perturbation keys dictionary
# Composition of the combined perturbations pert_10..pert_19:
# pert_{10+i} = pert_i + pert_{i+1} and pert_{15+i} = pert_i + pert_{9-i}.
# NOTE(review): this mapping is not used by any visible cell; perturbation_ID
# keeps labels like 'pert_10'. Confirm whether it was meant to relabel
# the combinations for CPA.
pert_keys10_19 = {
    **{f'pert_{10 + i}': f'pert_{i}+pert_{i + 1}' for i in range(5)},
    **{f'pert_{15 + i}': f'pert_{i}+pert_{9 - i}' for i in range(5)},
}

In [None]:
# Categorical perturbation label per cell, e.g. 'pert_3'. Combined perturbations
# also get a single label ('pert_10', ...), not a '+'-joined one.
adata.obs["perturbation_ID"] = pd.Categorical([f'pert_{k}' for k in pert_keys])
adata.obs.head()

In [None]:
# One dosage of 1.0 per '+'-separated component of the perturbation label
# (e.g. 'a+b' -> '1.0+1.0'); every label built above has one component.
adata.obs['dosage_id'] = np.array(
    ['+'.join(['1.0'] * (label.count('+') + 1))
     for label in adata.obs['perturbation_ID'].astype(str)],
    dtype=object,
)
adata.obs.head()

#### Try on CPA

In [None]:
import cpa

In [None]:
# Register the AnnData fields CPA should use: the perturbation label column,
# the control label, the per-cell dosage column, and flag X (the simulated
# Poisson counts) as count data.
# NOTE(review): 'pert_0' is itself a random knock-out mask (row 0 of M_pert
# comes from randbin), not an unperturbed state -- confirm this is the
# intended control group.
cpa.CPA.setup_anndata(adata, 
                      perturbation_key='perturbation_ID',
                      control_group='pert_0',
                      dosage_key='dosage_id',
                     # categorical_covariate_keys=['cell_type'],
                      is_count_data=True
                     # deg_uns_key='rank_genes_groups_cov',
                     # deg_uns_cat_key='cov_cond',
                     # max_comb_len=2,
                     )

In [None]:
# Architecture hyper-parameters, unpacked into cpa.CPA(...) below.
model_params = {
    "n_latent": 32,
    "recon_loss": "nb",
    "doser_type": "linear",
    "n_hidden_encoder": 256,
    "n_layers_encoder": 4,
    "n_hidden_decoder": 256,
    "n_layers_decoder": 2,
    "use_batch_norm_encoder": True,
    "use_layer_norm_encoder": False,
    "use_batch_norm_decoder": False,
    "use_layer_norm_decoder": False,
    "dropout_rate_encoder": 0.2,
    "dropout_rate_decoder": 0.0,
    "variational": False,
    "seed": 8206,
}

# Optimisation hyper-parameters, passed as plan_kwargs to model.train(...) below.
trainer_params = {
    "n_epochs_kl_warmup": None,
    "n_epochs_adv_warmup": 50,
    "n_epochs_mixup_warmup": 10,
    "n_epochs_pretrain_ae": 10,
    "mixup_alpha": 0.1,
    "lr": 0.0001,
    "wd": 3.2170178270865573e-06,
    "adv_steps": 3,
    "reg_adv": 10.0,
    "pen_adv": 20.0,
    "adv_lr": 0.0001,
    "adv_wd": 7.051355554517135e-06,
    "n_layers_adv": 2,
    "n_hidden_adv": 128,
    "use_batch_norm_adv": True,
    "use_layer_norm_adv": False,
    "dropout_rate_adv": 0.3,
    "step_size_lr": 25,
    "do_clip_grad": False,
    "adv_loss": "cce",
    # NOTE(review): with do_clip_grad False this value is presumably ignored -- confirm.
    "gradient_clip_value": 5.0,
}

In [None]:
# Random 85/15 train/validation split; all cells of pert_13 and pert_17 are
# then held out as the out-of-distribution ('ood') set.
adata.obs['split'] = np.random.choice(['train', 'valid'], size=adata.n_obs, p=[0.85, 0.15])
is_ood = adata.obs['perturbation_ID'].isin(['pert_17', 'pert_13'])
adata.obs.loc[is_ood, 'split'] = 'ood'

In [None]:
adata.obs.loc[:, 'split'].value_counts()  # cells per split

In [None]:
# Build the CPA model: 'train'/'valid' rows of the 'split' column are used for
# fitting and validation; the held-out 'ood' perturbations form the test split.
model = cpa.CPA(adata=adata, 
                split_key='split',
                train_split='train',
                valid_split='valid',
                test_split='ood',
                **model_params,
               )

In [None]:
# NOTE(review): `conda env config vars set` stores the variable in the conda
# environment's config. It is presumably only picked up after re-activating
# the environment, not by this running kernel (torch is already loaded via cpa).
# Confirm, or set os.environ before importing cpa.
!conda env config vars set PYTORCH_ENABLE_MPS_FALLBACK=1

In [None]:
# Train with the optimisation settings in trainer_params, validating every 5
# epochs with early-stopping patience 5.
# NOTE(review): save_path='' presumably writes relative to the current working
# directory (printed in the next cell) -- confirm.
model.train(max_epochs=2000,
            use_gpu=True,
            batch_size=2048,
            plan_kwargs=trainer_params,
            early_stopping_patience=5,
            check_val_every_n_epoch=5,
            save_path='',
           )

In [None]:
import os
# Show the working directory (where save_path='' above presumably resolves).
os.getcwd()

In [None]:
# Plot the recorded training history of the fitted model.
cpa.pl.plot_history(model)

In [None]:
# Latent representations for all cells, computed in batches of 2048; the result
# is indexed by representation name (e.g. 'latent_basal', used below).
latent_outputs = model.get_latent_representation(adata, batch_size=2048)

In [None]:
# List the available latent representations.
latent_outputs.keys()

In [None]:
# Neighbour graph and UMAP on the basal latent space (in place on that AnnData).
latent_basal = latent_outputs['latent_basal']
sc.pp.neighbors(latent_basal)
sc.tl.umap(latent_basal)

In [None]:
# Colour the basal latent UMAP by perturbation label. `perturbation_ID` is the
# obs column built earlier in this notebook; neither 'cond_harm' nor a
# `groups` variable exists here.
# NOTE(review): assumes the latent AnnData carries adata.obs columns -- confirm.
sc.pl.umap(latent_outputs['latent_basal'],
           color='perturbation_ID',
           palette=sc.pl.palettes.godsnot_102,
           frameon=False)

In [None]:
# Colour the basal latent UMAP by train/valid/ood split; this notebook has no
# 'pathway' column, whereas 'split' is assigned above.
# NOTE(review): assumes the latent AnnData carries adata.obs columns -- confirm.
sc.pl.umap(latent_outputs['latent_basal'],
           color='split',
           palette=sc.pl.palettes.godsnot_102,
           frameon=False)

### Poisson-Gamma mixture $Z_i \sim \Gamma(a_i,b_i)$