# Run scLAMBDA on the multi-cell-line essential-gene perturbation datasets

In [1]:
import pandas as pd
import anndata as ad
import numpy as np
import scanpy as sc
import sclambda

import matplotlib.pyplot as plt

## Load Perturb-seq data, cell-line context embeddings and gene embeddings

Preprocessed data are available at https://drive.google.com/drive/folders/1kpQOZ0OWnzngmYGGjWyDd_KHog94f6w4?usp=sharing.

In [None]:
adata_processed = sc.read_h5ad('./multi_cell_line_ES.h5ad')
# Rename to the column names used throughout this notebook (and presumably
# expected by scLAMBDA): `condition` = perturbed gene, `cell_type` = cell line.
adata_processed.obs = adata_processed.obs.rename(columns={'gene': 'condition', 'celltype': 'cell_type'})
# Combined '<cell_type>_---_<condition>' key; the same '_---_' separator is used
# later to index model outputs (e.g. 'jurkat_---_' + gene). Vectorized rather
# than a per-cell Python loop.
adata_processed.obs['cell_type+condition'] = (
    adata_processed.obs['cell_type'].astype(str)
    + '_---_'
    + adata_processed.obs['condition'].astype(str)
).to_numpy()
# Densify the expression matrix. The file may store X sparse or dense; calling
# .toarray() unconditionally would fail on an already-dense ndarray.
if hasattr(adata_processed.X, 'toarray'):
    adata_processed.X = adata_processed.X.toarray()
else:
    adata_processed.X = np.asarray(adata_processed.X)

In [None]:
# Cell-line context embeddings (GPT-3.5 / text-embedding-3-large, judging by the
# filename). The .npy holds a pickled Python object, hence allow_pickle=True and
# .item() to unwrap it.
# NOTE(review): allow_pickle=True can execute arbitrary code on load — only use
# files from a trusted source.
cont_emb = np.load('/gpfs/gibbs/pi/zhao/gw399/project/perturbation/genept/GPT_3_5_cell_line_embeddings_3-large.npy', allow_pickle=True).item()

# Re-key by the lowercase cell-line names that appear in obs['cell_type']
# (e.g. 'jurkat'), keeping the same insertion order.
cont_emb_ct = {
    cell_line: cont_emb[emb_key]
    for cell_line, emb_key in (
        ('rpe1', 'RPE1'),
        ('k562', 'K562'),
        ('hepg2', 'HepG2'),
        ('jurkat', 'Jurkat'),
    )
}

In [None]:
# Per-gene GenePT embeddings stored as a pickled object inside a size-1 .npy
# array (presumably a dict keyed by gene symbol — TODO confirm).
# NOTE(review): allow_pickle=True can execute arbitrary code on load — only use
# files from a trusted source.
gene_embeddings_genept = np.load(
    '/gpfs/gibbs/pi/zhao/gw399/project/perturbation/genept/GPT_3_5_gene_embeddings_multi_cell_line_3-large.npy',
    allow_pickle=True,
)
# Unwrap the single stored object.
gene_embeddings = gene_embeddings_genept.item()

In this tutorial, the model is trained on the RPE1, K562 and HepG2 cell lines and evaluated on the held-out Jurkat cell line.

In [None]:
# Hold out the Jurkat cell line to check generalization to an unseen line:
# Jurkat cells -> 'test', every other cell line -> 'train'.
# Built in a single assignment. The previous chained-indexing form
# (`obs['split'][mask] = 'test'`) triggers SettingWithCopy/ChainedAssignment
# warnings and, under pandas copy-on-write, silently leaves the column unchanged
# (which would put Jurkat into the training set).
adata_processed.obs['split'] = np.where(
    adata_processed.obs['cell_type'].values == 'jurkat', 'test', 'train'
)

In [None]:
# Training: build the context-aware scLAMBDA model from the Perturb-seq data,
# the gene embeddings and the cell-line context embeddings (single-gene
# perturbations only), then train it. `model_path` is the same directory the
# evaluation cell loads from, so it is presumably where trained weights go.
model = sclambda.model.Model_context(
    adata_processed,
    gene_embeddings,
    cont_emb_ct,
    multi_gene=False,
    model_path="./models_multi_cellline_for_jurkat_eval",
)
model.train()

In [None]:
# Evaluation on the Jurkat cell line: rebuild the model with the same
# configuration as the training cell, then call load_pretrain() (presumably
# restoring the weights saved under `model_path` — confirm against sclambda).
model = sclambda.model.Model_context(
    adata_processed,
    gene_embeddings,
    cont_emb_ct,
    multi_gene=False,
    model_path="./models_multi_cellline_for_jurkat_eval",
)
model.load_pretrain()

# Sorted, de-duplicated perturbation conditions observed in Jurkat cells
# (may include 'ctrl', which the evaluation loop skips).
pert_test = np.unique(
    adata_processed.obs['condition'].values[adata_processed.obs['cell_type'].values == 'jurkat']
)

# For computing W2 distance
import ot
from scipy.spatial.distance import cdist

def w2_distance(samples_1, samples_2):
    """Wasserstein-2 distance between two empirical point clouds.

    Every row of ``samples_1`` and ``samples_2`` is one point carrying uniform
    mass. The optimal-transport cost under the squared-Euclidean ground cost is
    computed with POT's ``ot.emd2``, and its square root is returned.

    Args:
        samples_1: 2-D array of shape (n1, d).
        samples_2: 2-D array of shape (n2, d), same d as ``samples_1``.

    Returns:
        The W2 distance, i.e. sqrt of the optimal transport cost.
    """
    n_1, n_2 = samples_1.shape[0], samples_2.shape[0]
    mass_1 = np.full(n_1, 1.0 / n_1)
    mass_2 = np.full(n_2, 1.0 / n_2)
    ground_cost = cdist(samples_1, samples_2, 'sqeuclidean')
    return np.sqrt(ot.emd2(mass_1, mass_2, ground_cost))

np.random.seed(0)  # reproducible cell subsampling for the W2 evaluation
# Held-out Jurkat cells only; every selection below is already within Jurkat.
adata = adata_processed[adata_processed.obs['cell_type'].values == 'jurkat']


def _dense(x):
    """Return ``x`` as a dense ndarray, whether it is a sparse matrix or already dense."""
    return x.toarray() if hasattr(x, 'toarray') else np.asarray(x)


# Only evaluate on genes with non-zero variance across Jurkat cells
# (e.g. genes not measured in the Jurkat data are excluded).
gene_weight = np.std(_dense(adata.X), axis=0) > 0
n_eval = 100  # cells subsampled per condition for faster W2 evaluation
n_generate = 500  # cells generated by the model per condition

# Loop-invariant: flattened Jurkat control mean, the baseline for predicted shifts.
ctrl_mean_jurkat = np.asarray(model.ctrl_mean['jurkat']).reshape(-1)

corr_sclambda = []
w2_sclambda = []

for i in pert_test:
    print(i)
    if i == 'ctrl':
        continue  # no perturbation effect to score; also skips a wasted generate() call
    key = 'jurkat_---_' + i
    res = model.generate('jurkat', i, return_type='cells', n_cells=n_generate)

    # Observed cells for this perturbation. X may already be dense (the loading
    # cell densifies it), so an unconditional .toarray() would raise AttributeError.
    pt_cells = _dense(adata[adata.obs['condition'].values == i].X)

    # Subsample both sets to n_eval cells to keep exact OT cheap. Sampling is
    # with replacement, so conditions with fewer than n_eval cells still work.
    id_pt = np.random.choice(pt_cells.shape[0], n_eval)
    x_hat = res[key]
    id_sclambda = np.random.choice(x_hat.shape[0], n_eval)
    wd_test = w2_distance(x_hat[id_sclambda][:, gene_weight],
                          pt_cells[id_pt][:, gene_weight])

    # Pearson correlation between the predicted mean shift from control and
    # model.pert_delta for this condition, over the evaluated genes.
    x_hat_mean = np.asarray(np.mean(x_hat, axis=0)).reshape(-1)
    corr_test = np.corrcoef((x_hat_mean - ctrl_mean_jurkat)[gene_weight],
                            model.pert_delta[key][gene_weight])
    w2_sclambda.append(wd_test)
    corr_sclambda.append(corr_test[0, 1])

In [None]:
# Summary over evaluated Jurkat perturbations: mean Pearson correlation between
# predicted mean shifts and model.pert_delta, and mean W2 distance between
# generated and observed cells.
mean_corr_sclambda = np.mean(corr_sclambda)
mean_w2_sclambda = np.mean(w2_sclambda)
print(mean_corr_sclambda, mean_w2_sclambda)