In [1]:
import time
import umap

import torch
import collections
import numpy as np
import pandas as pd
import scanpy as sc

from scipy.sparse import csr_matrix

import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
from sklearn import linear_model
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression

In [3]:
import scvi
import VariationalCPA_adv_attent_v3 as CellCap

Global seed set to 0


[easydl] tensorflow not available!


Read data

In [4]:
## Read the expression data (AnnData, .h5ad)
adata = sc.read_h5ad('../data/scLevyAll.h5ad')
# Labels that mark untreated batches: in the downsampling cell, a cell whose
# second name field is one of these is assigned the condition 'Control'.
batch = ['a','b','c','d','e','f','g','h']

In [5]:
## Read the per-cell metadata table and align it with the cells in `adata`
metadata = pd.read_csv('../data/scLevyAll_metadata.txt', sep='\t', header=0, index_col=False)
# Row labels are built as "<PREFIX>_<CELL_BARCODE>" so they can be matched to adata.obs.index
index = [
    prefix + '_' + barcode
    for prefix, barcode in zip(metadata['PREFIX'], metadata['CELL_BARCODE'])
]
metadata.index = index
# Keep only the cells present in `adata`, in the same order, then use the table as obs
metadata = metadata.loc[adata.obs.index]
adata.obs = metadata

In [8]:
## Read the donor table; later cells look donors up by its 'Linking Donor ID' column
## and read 'Clinical Diagnosis', 'Sex' and 'Age' from it.
cohort_info = pd.read_csv('../data/McleanLevy_Dropulation_Cohort.csv',header=0,index_col=False)

In [9]:
# Force a garbage-collection pass after loading the data.
# NOTE(review): `gc` is imported here rather than in the import cell at the top.
import gc
gc.collect()

0

In [10]:
## Annotate cells from their names, downsample, and drop unwanted cell types
names = adata.obs.index
# Each name is split on "_": field 0 is used as the cell type, field 1 as the
# condition; field-1 values listed in `batch` are relabelled as 'Control'.
name_fields = [cell_name.split("_") for cell_name in names]
celltype = [fields[0] for fields in name_fields]
condition = ['Control' if fields[1] in batch else fields[1] for fields in name_fields]

adata.obs['cell_type'] = celltype
adata.obs['condition'] = condition
sc.pp.subsample(adata, fraction=0.3)
# Remove these cell types one after another ('Neuron' cells are kept)
for excluded_type in ('Astrocyte', 'iPSC', 'NPC'):
    adata = adata[adata.obs['cell_type'] != excluded_type]
gc.collect()
# Preserve the un-normalized matrix in a separate layer before X is transformed
adata.layers["counts"] = adata.X.copy()

A few common steps of data preprocessing

In [11]:
# Flag genes whose name starts with 'MT-' and compute QC metrics in place;
# the next cell filters on the resulting n_genes_by_counts and pct_counts_mt columns.
adata.var['mt'] = adata.var_names.str.startswith('MT-')  # annotate the group of mitochondrial genes as 'mt'
sc.pp.calculate_qc_metrics(adata, qc_vars=['mt'], percent_top=None, log1p=False, inplace=True)

In [13]:
# Keep cells with at most 4500 detected genes and at most 20% mitochondrial counts
passes_qc = (adata.obs.n_genes_by_counts <= 4500) & (adata.obs.pct_counts_mt <= 20)
adata = adata[passes_qc, :]

In [14]:
# Normalize each cell to a total of 1e4, then log1p-transform X in place
sc.pp.normalize_total(adata, target_sum=1e4)
sc.pp.log1p(adata)
#sc.pp.subsample(adata, fraction=0.3)
# Drop genes detected in fewer than 500 cells
sc.pp.filter_genes(adata, min_cells=500)
# Snapshot the normalized data (all remaining genes) before subsetting to variable genes
adata.raw = adata

# Select the 3000 most variable genes from the 'counts' layer and subset adata to them
sc.pp.highly_variable_genes(adata, flavor='seurat_v3', layer='counts', n_top_genes=3000, subset=True)

Add donor and clinical info to anndata

In [15]:
# Attach donor-level annotations (clinical diagnosis, sex, age) to every cell.
#
# A lookup table keyed by donor ID is built once and aligned to the cells with
# reindex. This replaces the per-cell filtering of `cohort_info` (one DataFrame
# scan per cell) and the integer-position access `adata.obs['DONOR'][i]`, which
# relies on positional [] indexing of a label-indexed Series (deprecated in pandas).
donor_table = (
    cohort_info
    .dropna(subset=['Linking Donor ID'])                        # a missing ID can never match a donor
    .drop_duplicates(subset='Linking Donor ID', keep='first')   # keep the first matching row per donor
    .set_index('Linking Donor ID')
)
# One row per cell, in cell order; donors absent from cohort_info become all-NaN rows
matched = donor_table.reindex(adata.obs['DONOR'].to_numpy())

cohort = pd.DataFrame({
    'Clinical': matched['Clinical Diagnosis'].to_numpy(),
    'Sex': matched['Sex'].to_numpy(),
    'Age': matched['Age'].to_numpy(),
})

adata.obs['Clinical'] = cohort['Clinical'].values
adata.obs['Sex'] = cohort['Sex'].values
adata.obs['Age'] = cohort['Age'].values

In [16]:
# Integer flag per cell: 1 where condition == 'Control', 0 everywhere else.
# Only 'Control' is flagged; the PBS and DMSO groups stay 0.
is_control = (adata.obs['condition'] == 'Control').to_numpy()
controls = np.where(is_control, 1, 0)
adata.obs['control'] = controls

In [17]:
# Perturbation annotation table; the 'Perturbation' and 'Vehicle' columns are used
# below to find the vehicle that belongs to each condition.
drugtype = pd.read_csv('../data/LevyDrug_class.csv',header=0,index_col=False)
drugtype

Unnamed: 0,Perturbation,Category,Vehicle
0,A1CYTO,Inflammatory response,PBS
1,ATOR,Cholesterol biosynthesis,DMSO
2,AZT,Oxidative stress,DMSO
3,C1Q,Inflammatory response,PBS
4,CLOZ,Antipsychotic,DMSO
5,DMSO,Vehicle control,DMSO
6,EFA,Cholesterol biosynthesis,DMSO
7,GluN2a,NMDA/excitability,PBS
8,GLUT,NMDA/excitability,PBS
9,H2O2,Oxidative stress,PBS


Prepare inputs for model

This includes:

a. perturbation info

b. label for control group

In [18]:
# Integer-encode the conditions and build a one-hot perturbation matrix
# (cells x conditions) in which untreated and vehicle columns are removed.
codes, uniques = pd.factorize(adata.obs['condition'])
uniques = list(uniques)
adata.obs['Condition'] = codes

drugY = np.zeros((len(codes), len(uniques)))
drugY[np.arange(len(codes)), codes] = 1          # one-hot: row i gets a 1 in column codes[i]
for non_drug in ('Control', 'DMSO', 'PBS'):
    drugY[:, uniques.index(non_drug)] = 0        # these are not drug perturbations
drugY = drugY[:, drugY.sum(axis=0) > 0]          # drop the columns that are now all zero

In [19]:
# Names of the remaining drugY columns: every condition except untreated/vehicle ones
drug_names = list(uniques)
for non_drug in ('Control', 'DMSO', 'PBS'):
    drug_names.remove(non_drug)

In [20]:
# Vehicle one-hot matrix (cells x vehicles): for every cell, flag the vehicle that
# `drugtype` lists for the cell's condition.
#
# The vehicle is resolved once per unique condition instead of filtering
# `drugtype` again for every cell (the lookup depends only on the condition).
vehicle_of = (
    drugtype
    .drop_duplicates(subset='Perturbation', keep='first')   # first matching row, as before
    .set_index('Perturbation')['Vehicle']
    .to_dict()
)
unknown = [u for u in uniques if u not in vehicle_of]
if unknown:
    raise KeyError("conditions missing from drugtype['Perturbation']: {}".format(unknown))
# For each unique condition, the column (position in `uniques`) of its vehicle
vehicle_col = np.array([uniques.index(vehicle_of[u]) for u in uniques], dtype=int)

contY = np.zeros((len(codes), len(uniques)))
contY[np.arange(len(codes)), vehicle_col[codes]] = 1
contY[:, uniques.index('Control')] = 0       # the 'Control' column is never flagged
contY = contY[:, contY.sum(axis=0) > 0]      # keep only columns that are used

In [21]:
# Combined label matrix (cells x conditions): each cell is flagged for its own
# condition AND for the vehicle that `drugtype` lists for that condition.
#
# The vehicle is resolved once per unique condition instead of filtering
# `drugtype` again for every cell (the lookup depends only on the condition).
vehicle_of = (
    drugtype
    .drop_duplicates(subset='Perturbation', keep='first')   # first matching row, as before
    .set_index('Perturbation')['Vehicle']
    .to_dict()
)
unknown = [u for u in uniques if u not in vehicle_of]
if unknown:
    raise KeyError("conditions missing from drugtype['Perturbation']: {}".format(unknown))
# For each unique condition, the column (position in `uniques`) of its vehicle
vehicle_col = np.array([uniques.index(vehicle_of[u]) for u in uniques], dtype=int)

cell_rows = np.arange(len(codes))
target_label = np.zeros((len(codes), len(uniques)))
target_label[cell_rows, codes] = 1                    # the cell's own condition
target_label[cell_rows, vehicle_col[codes]] = 1       # the matching vehicle
target_label[:, uniques.index('Control')] = 0         # the 'Control' column is never flagged
target_label = target_label[:, target_label.sum(axis=0) > 0]

In [23]:
# Per-cell vehicle name, looked up from `drugtype` via the cell's condition.
#
# The vehicle is resolved once per unique condition and then broadcast to the
# cells through `codes`, instead of filtering `drugtype` for every cell.
vehicle_of = (
    drugtype
    .drop_duplicates(subset='Perturbation', keep='first')   # first matching row, as before
    .set_index('Perturbation')['Vehicle']
    .to_dict()
)
unknown = [u for u in uniques if u not in vehicle_of]
if unknown:
    raise KeyError("conditions missing from drugtype['Perturbation']: {}".format(unknown))
vehicle_per_condition = np.array([vehicle_of[u] for u in uniques], dtype=object)

# Kept as a plain list: a later cell builds np.array(vehicles) from it
vehicles = vehicle_per_condition[codes].tolist()
adata.obs['Vehicle'] = vehicles

In [25]:
# Report the matrix shapes, then store the matrices in adata.obsm under the keys
# that are passed to setup_anndata below (cond_key, cont_key, target_key).
print(drugY.shape)
print(contY.shape)
print(target_label.shape)
adata.obsm['X_drug']=drugY
adata.obsm['X_cont']=contY
adata.obsm['X_target']=target_label

(130089, 13)
(130089, 2)
(130089, 15)


In [26]:
# One-hot donor matrix (cells x donors), stored as adata.obsm['X_donor']
donor_codes, donor_uniques = pd.factorize(adata.obs['DONOR'])
donor_uniques = list(donor_uniques)
donorY = np.zeros((len(donor_codes), len(donor_uniques)))
donorY[np.arange(len(donor_codes)), donor_codes] = 1   # row i gets a 1 in column donor_codes[i]
adata.obsm['X_donor'] = donorY

In [27]:
# (cells, donors) — the second value is the number of donor columns in donorY
print(donorY.shape)

(130089, 39)


In [28]:
# Number of cells per cell type after filtering.
# NOTE(review): `collections` is already imported in the first cell; this import is redundant.
import collections
collections.Counter(adata.obs['cell_type'])

Counter({'Neuron': 130089})

In [29]:
# Number of cells per condition
collections.Counter(adata.obs['condition'])

Counter({'INFa': 5255,
         'DMSO': 26914,
         'CLOZ': 5774,
         'HALO': 5343,
         'AZT': 4987,
         'INFy': 12695,
         'EFA': 3497,
         'Control': 25616,
         'TNFa': 5625,
         'ATOR': 3915,
         'ISRD': 6045,
         'GluN2a': 5843,
         'PBS': 8055,
         'H2O2': 1179,
         'GLUT': 7281,
         'SIM': 2065})

Set up the data for scvi-tools

In [None]:
# n_prog is passed to the model constructor; `attention` holds one label per
# program ('A1' ... 'A<n_prog>') and is used later as attention column names.
n_prog = 5
attention = ['A{}'.format(k) for k in range(1, n_prog + 1)]

In [30]:
# Register the AnnData fields the model reads.
# The model module is imported as `CellCap` (import VariationalCPA_adv_attent_v3 as CellCap);
# the bare name `VariationalCPA` is never bound, so it raised NameError on a fresh kernel.
CellCap.VariationalCPA.setup_anndata(adata, labels_key='control',
                                     pert_key='Condition',layer="counts",
                                     cond_key='X_drug',cont_key='X_cont',
                                     target_key='X_target',donor_key='X_donor')

[34mINFO    [0m Generating sequential column names                                                  
[34mINFO    [0m Generating sequential column names                                                  
[34mINFO    [0m Generating sequential column names                                                  
[34mINFO    [0m Generating sequential column names                                                  


In [37]:
# Build the model. The module is imported as `CellCap`; the bare name `VariationalCPA`
# is never bound and raised NameError on a fresh kernel.
# The matrix widths are read from the design matrices instead of being hard-coded
# (13 / 2 / 15 / 39 for the run recorded above), so they stay correct if the
# subsampling or filtering changes the set of conditions or donors.
cpa = CellCap.VariationalCPA(adata, n_latent=20, n_layers=3,
                             n_drug=drugY.shape[1],
                             n_control=contY.shape[1],
                             n_target=target_label.shape[1],
                             n_donor=donorY.shape[1],
                             n_prog=n_prog)

Train the model

In [None]:
# Train for at most 700 epochs with a batch size of 4096
cpa.train(max_epochs=700,batch_size=4096)

GPU available: True, used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 153/700:  22%|██▏       | 152/700 [06:42<23:57,  2.62s/it, loss=-5.1e+06, v_num=1] 

Get the latent vectors and visualize them

In [None]:
# Free unreferenced memory before computing embeddings
gc.collect()

In [None]:
# Latent embedding of all cells; stored below as 'X_basal' and 'X_latent'.
# NOTE(review): presumably one row per cell in adata order — confirm against the model API.
z = cpa.get_latent_embedding(adata)

In [None]:
# Fit a 2-D UMAP (correlation metric) on the latent vectors and plot it.
# The fit and transform steps are separate so the fit can run on a subsample
# (see the commented-out line); currently the fit uses all rows of z.
down_samp = pd.DataFrame(z)
#down_samp = down_samp.sample(frac=0.75)
umaps = umap.UMAP(n_neighbors=10, min_dist=0.1, n_components=2,
                  metric="correlation").fit(down_samp.values)

# Project every cell with the fitted model and store the coordinates for scanpy plotting
embedding = umaps.transform(z)
embedding = pd.DataFrame(embedding)
adata.obsm['X_umap']=embedding.iloc[:,:2].values
sc.set_figure_params(scanpy=True, dpi=75, dpi_save=75)
sc.pl.umap(adata, color='cell_type', title='',legend_loc='on data')
sc.pl.umap(adata, color='condition', title='')

In [None]:
# Keep the latent embedding on the AnnData object under 'X_basal'
adata.obsm['X_basal']=z

In [None]:
##ROC to evaluate basal state
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve, auc
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

In [None]:
# For every non-control condition, train a logistic-regression classifier that
# separates that condition from 'Control' using the latent embedding, and record
# the ROC curve and AUC on a held-out half of the cells.
adata.obsm['X_latent'] = z
fpr, tpr, roc_auc = {}, {}, {}

cell_conditions = adata.obs['condition']
conditions = list(set(cell_conditions[cell_conditions != 'Control'].values))

for c in conditions:
    in_pair = (cell_conditions == 'Control') | (cell_conditions == c)
    subad = adata[in_pair]
    # Binary target: 0 = Control, 1 = condition c (every cell in subad is one of the two)
    y = np.where(subad.obs['condition'] == c, 1, 0)
    X = subad.obsm['X_latent']
    random_state = np.random.RandomState(0)

    # 50/50 shuffled split into training and test cells
    X_train, X_test, y_train, y_test = train_test_split(
        X, y, test_size=0.5, random_state=0
    )

    # Fit on the training half, score the test half
    classifier = LogisticRegression(random_state=random_state)
    classifier.fit(X_train, y_train)
    y_score = classifier.decision_function(X_test)

    fpr[c], tpr[c], _ = roc_curve(y_test, y_score, pos_label=classifier.classes_[1])
    roc_auc[c] = auc(fpr[c], tpr[c])

In [None]:
# `conditions` was built from a set, so its order is arbitrary; sort it in place
# so the plotting and printing cells below use a fixed order.
conditions.sort()

In [None]:
# Colour list for the ROC curves: the "Paired" palette followed by 8 "hls" colours
colors = [*sns.color_palette("Paired"), *sns.color_palette("hls", 8)]

In [None]:
# Draw one ROC curve per condition (condition vs. Control) plus the chance diagonal
sc.set_figure_params(scanpy=True, dpi=200, dpi_save=200)
plt.figure()
for condition_name, color in zip(conditions, colors):
    plt.plot(
        fpr[condition_name],
        tpr[condition_name],
        color=color,
        lw=1,
        label="{0}".format(condition_name),
    )
plt.grid(False)
plt.plot([0, 1], [0, 1], "k--", lw=1)   # chance level
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("")
legend = plt.legend(loc="lower right", prop={'size': 4.5})
legend.get_frame().set_facecolor('none')   # transparent legend box
plt.show()

In [None]:
# Print the AUC of every plotted condition (same pairing with `colors` as the plot,
# so the list is cut off at len(colors) in exactly the same way)
for condition_name, _ in zip(conditions, colors):
    print("{0} (AUC = {1:0.4f})".format(condition_name, roc_auc[condition_name]))

In [None]:
# Two per-cell outputs from the model; stored below as 'X_pert' and 'X_attn'.
# NOTE(review): zAttn is later given the `attention` labels as column names, so it
# presumably has n_prog columns — confirm against the model API.
zP,zAttn = cpa.get_pert_embedding(adata)

In [None]:
# Keep both outputs on the AnnData object so subsets of cells carry them along
adata.obsm['X_pert']=zP
adata.obsm['X_attn']=zAttn

In [None]:
# UMAPs of the SIM-treated cells, coloured by each attention column.
# NOTE: despite its name, `cloz` holds the cells with condition == 'SIM'.
# Copy the subset because obs columns are added to it below (it would otherwise be a view).
cloz = adata[adata.obs['condition']=='SIM'].copy()
attn = pd.DataFrame(cloz.obsm['X_attn'])
attn.columns = attention

for a in attention:
    cloz.obs[a] = attn[a].values

sc.set_figure_params(scanpy=True, dpi=50, dpi_save=50)
# The colour scale is pinned with vmin/vmax. Previously the first two cells'
# attention values were overwritten with 0 and 1 to force this range, which
# corrupted `attn` — and `attn` is reused below as the regression target.
# NOTE(review): assumes attention values lie in [0, 1]; values outside are clipped in the plot.
sc.pl.umap(cloz, color=attention, ncols=5, frameon=False, cmap='PRGn', vmin=0, vmax=1)

In [None]:
# Subset to DMSO, Control and SIM cells and add a z-scored expression layer
# ('scaled') clipped to the range [-5, 5]
subad = adata[adata.obs['condition'].isin(['DMSO', 'Control', 'SIM'])]
scaled_data = np.clip(sc.pp.scale(subad, copy=True).X, -5, 5)
subad.layers['scaled'] = scaled_data
control = subad[subad.obs['condition'] == 'Control']
vehicle = subad[subad.obs['condition'] == 'DMSO']

In [None]:
# SIM-treated cells from the scaled subset (the variable name `cloz` notwithstanding)
cloz = subad[subad.obs['condition']=='SIM']
#sc.pp.filter_genes(cloz, min_cells=int(cloz.X.shape[0]*0.2))
# Keep genes detected in at least 10% of the SIM cells
sc.pp.filter_genes(cloz, min_cells=int(cloz.X.shape[0]*0.1))
#sc.pp.filter_genes(cloz, min_cells=300)
# Downsample Control and DMSO cells to the number of SIM cells
sc.pp.subsample(control, n_obs=cloz.X.shape[0])
sc.pp.subsample(vehicle, n_obs=cloz.X.shape[0])

In [97]:
# Standardized (zero-mean, unit-variance per gene) expression of the `cloz` cells,
# as a DataFrame with one column per gene.
# `.todense()` returns an np.matrix, which recent scikit-learn versions reject, and
# it does not exist at all when X is already a dense ndarray; use a plain ndarray.
expr_matrix = cloz.X.toarray() if hasattr(cloz.X, 'toarray') else np.asarray(cloz.X)
scaler = StandardScaler()
X = pd.DataFrame(scaler.fit_transform(expr_matrix))
X.columns = cloz.var.index

In [98]:
# Regression target: the 'A2' attention column.
# NOTE(review): `attn` was built from the SIM cells of `adata`, while X comes from the
# SIM cells of `subad`; the rows are matched by position only — confirm the order agrees.
y = attn['A2'].values

In [115]:
# Regress the attention value on standardized gene expression (no intercept).
# The commented lines are alternative estimators that were tried.
#reg = LinearRegression(fit_intercept=False)
#reg = linear_model.Ridge(alpha = 0.5)
#reg = linear_model.RidgeCV(alphas=[0.1, 1.0, 10.0],fit_intercept=False)
reg = linear_model.BayesianRidge(fit_intercept=False)
#reg = linear_model.ARDRegression(fit_intercept=False)
#reg = linear_model.SGDRegressor(fit_intercept=False)
reg.fit(X, y)

In [116]:
# Per-gene regression coefficients (column 0, indexed by gene name) and the
# 50 genes with the largest / smallest coefficients
gene_weights = pd.DataFrame(reg.coef_)
gene_weights.index = cloz.var.index
coefficients = gene_weights[0]
geneA = list(coefficients.sort_values(ascending=False).index[:50])   # most positive
geneB = list(coefficients.sort_values(ascending=True).index[:50])    # most negative

In [117]:
# Expression of the selected genes plus the 'A2' attention value as 'Response',
# with cells ordered from highest to lowest response
expr = (
    X.loc[:, geneA + geneB]
    .assign(Response=attn['A2'].values)
    .sort_values(by=['Response'], ascending=False)
)

In [118]:
# Rescale every column of `expr` to the range [0, 1] for the heatmap
scaler = MinMaxScaler((0, 1))
exprs = pd.DataFrame(scaler.fit_transform(expr), columns=expr.columns)

In [None]:
# Wide figure; the font scale of 0.01 makes tick labels effectively invisible
sns.set(font_scale=0.01, rc={'figure.figsize': (18, 6)})
# Rows are cells (sorted by response), columns are genes followed by 'Response'
sns.heatmap(exprs, cmap='PRGn')

In [120]:
# Narrow the gene lists to the 20 largest (geneA) and 20 smallest (geneB) coefficients
geneA=gene_weights[0].sort_values(ascending = False).index[:20].tolist()
geneB=gene_weights[0].sort_values(ascending = True).index[:20].tolist()

In [None]:
# UMAPs of the geneA genes (scaled expression) for the treated, untreated and vehicle cells.
# `cloz` holds the cells with condition == 'SIM', so that panel is labelled "SIM"
# (it was previously printed as "CLOZ").
sc.set_figure_params(scanpy=True, dpi=50, dpi_save=50)
for group_label, group_cells in (("SIM", cloz), ("Control", control), ("DMSO", vehicle)):
    print(group_label)
    sc.pl.umap(group_cells, color=geneA,
               layer='scaled', vmin=-5, vmax=5, cmap='PRGn', ncols=5)

In [None]:
# UMAPs of the geneB genes (scaled expression) for the treated, untreated and vehicle cells.
# `cloz` holds the cells with condition == 'SIM', so that panel is labelled "SIM"
# (it was previously printed as "CLOZ").
for group_label, group_cells in (("SIM", cloz), ("Control", control), ("DMSO", vehicle)):
    print(group_label)
    sc.pl.umap(group_cells, color=geneB,
               layer='scaled', vmin=-5, vmax=5, cmap='PRGn', ncols=5)

In [None]:
# UMAPs of a hand-picked gene set (scaled expression) for the three groups.
# `cloz` holds the cells with condition == 'SIM', so that panel is labelled "SIM"
# (it was previously printed as "CLOZ").
# Other genes tried earlier: 'SCD', 'LSS'; and the set ['KLF2', 'KLF4', 'SOX5'].
marker_genes = ['MVD','FABP3','FDPS','FASN']
for group_label, group_cells in (("SIM", cloz), ("Control", control), ("DMSO", vehicle)):
    print(group_label)
    sc.pl.umap(group_cells, color=marker_genes,
               layer='scaled', vmin=-5, vmax=5, cmap='PRGn', ncols=5)

In [88]:
# Cells whose vehicle is 'Control' or 'DMSO'; keep genes detected in at least 10%
# of them and add a z-scored 'scaled' layer
vehicle_labels = np.asarray(vehicles)
subad = adata[np.isin(vehicle_labels, ['Control', 'DMSO'])]
sc.pp.filter_genes(subad, min_cells=int(subad.X.shape[0] * 0.1))
subad.layers['scaled'] = sc.pp.scale(subad, copy=True).X

In [89]:
# Gene-by-drug loadings for one slice of the perturbation loadings, and the top
# SIM genes that are also present in `subad` and `cloz`.
# NOTE(review): the slice index is 0 but the variables are named "a2"/"A2" — confirm
# which attention program index 0 corresponds to.
drug_embedding = cpa.get_pert_loadings()
drug_embedding_a2 = drug_embedding[:, 0, :]
weights = cpa.get_loadings()

# (genes x latent) @ (latent x drugs) -> genes x drugs
drug_loading = pd.DataFrame(np.matmul(weights, drug_embedding_a2.T))
drug_loading.index = adata.var.index
drug_loading.columns = drug_names

# 100 genes with the highest SIM loading, restricted to genes kept in both subsets
smileA = drug_loading['SIM'].sort_values(ascending=False)
top_sim_genes = smileA.index[:100].tolist()
gene_list = [gene for gene in top_sim_genes if gene in subad.var.index]
new_gene_list = [gene for gene in gene_list if gene in cloz.var.index]
gene_list_A2 = list(new_gene_list)

In [90]:
# Gene-by-drug loadings for another slice of the perturbation loadings, and the top
# SIM genes that are also present in `subad` and `cloz`.
# NOTE(review): the slice index is 2 but the result is named "A5" (and the variable
# "a2" is reused) — confirm which attention program index 2 corresponds to.
drug_embedding = cpa.get_pert_loadings()
drug_embedding_a2 = drug_embedding[:, 2, :]
weights = cpa.get_loadings()

# (genes x latent) @ (latent x drugs) -> genes x drugs
drug_loading = pd.DataFrame(np.matmul(weights, drug_embedding_a2.T))
drug_loading.index = adata.var.index
drug_loading.columns = drug_names

# 100 genes with the highest SIM loading, restricted to genes kept in both subsets
smileA = drug_loading['SIM'].sort_values(ascending=False)
top_sim_genes = smileA.index[:100].tolist()
gene_list = [gene for gene in top_sim_genes if gene in subad.var.index]
new_gene_list = [gene for gene in gene_list if gene in cloz.var.index]
gene_list_A5 = list(new_gene_list)

In [None]:
# UMAPs of the SIM-treated cells coloured by the two loading-derived gene lists.
# `cloz` holds the cells with condition == 'SIM', so the first panel is labelled
# "SIM" (it was previously printed as "CLOZ").
sc.set_figure_params(scanpy=True, dpi=50, dpi_save=50)
for panel_label, panel_genes in (("SIM", gene_list_A2), ('Attention 5', gene_list_A5)):
    print(panel_label)
    sc.pl.umap(cloz, color=panel_genes,
               layer='scaled', vmin=-5, vmax=5, cmap='PRGn', ncols=5)

In [None]:
# Mean scaled expression per condition for each loading-derived gene list
for loading_genes in (gene_list_A2, gene_list_A5):
    sc.pl.matrixplot(subad, loading_genes, groupby='condition',
                     cmap='PRGn', layer='scaled', vmin=-1, vmax=1)

In [None]:
# Matrix plots for the size-matched Control / SIM / DMSO cells: both loading-derived
# gene lists and the 20 genes with the largest regression coefficients.
sc.set_figure_params(scanpy=True, dpi=75, dpi_save=75)
geneA=gene_weights[0].sort_values(ascending = False).index[:20].tolist()
# Combine the three groups into one object (this overwrites the earlier `subad`).
# NOTE(review): `cloz` was gene-filtered, so the combined object presumably keeps
# only genes shared by all three inputs — confirm the join behaviour of concatenate.
subad = control.concatenate(cloz,vehicle)
#sc.pl.matrixplot(subad, ['MVD','LSS','FABP3','FDPS','FASN'], groupby='condition',#'LSS'
#                 cmap='PRGn',layer='scaled', vmin=-1, vmax=1)
#sc.pl.matrixplot(subad, ['HES1'], groupby='condition',
#                 cmap='PRGn',layer='scaled', vmin=-0.5, vmax=0.5)
sc.pl.matrixplot(subad, gene_list_A2, groupby='condition',
                 cmap='PRGn',layer='scaled', vmin=-1, vmax=1)
sc.pl.matrixplot(subad, gene_list_A5, groupby='condition',
                 cmap='PRGn',layer='scaled', vmin=-1, vmax=1)
sc.pl.matrixplot(subad, geneA, groupby='condition',
                 cmap='PRGn',layer='scaled', vmin=-1, vmax=1)
#subad = control.concatenate(cloz)
#sc.pl.matrixplot(subad, ['ISG15','MX1','IFIT1','IFI6','STAT1'], groupby='condition',
#                 cmap='PRGn',layer='scaled', vmin=-3, vmax=3)
#sc.pl.matrixplot(subad, geneA, groupby='condition',
#                 cmap='PRGn',layer='scaled', vmin=-3, vmax=3)