In [1]:
import gc
import time
import umap

import torch
import collections
import numpy as np
import pandas as pd
import scanpy as sc

import seaborn as sns
import matplotlib.pyplot as plt

In [2]:
from sklearn import linear_model
from sklearn.preprocessing import MinMaxScaler
from sklearn.preprocessing import StandardScaler
from sklearn.linear_model import LinearRegression

In [3]:
import scvi
from cellcap.scvi_module import CellCap

Global seed set to 0
  new_rank_zero_deprecation(
  return new_rank_zero_deprecation(*args, **kwargs)
  doc = func(self, args[0].__doc__, *args[1:], **kwargs)
  doc = func(self, args[0].__doc__, *args[1:], **kwargs)


In [4]:
def downsample_to_smallest_category(adata,column="cell_type",random_state=None,
                                    min_cells=15,keep_small_categories=False):
    """Downsample every large category of ``adata.obs[column]`` to a common size.

    Categories with at least ``min_cells`` cells are randomly subsampled
    (without replacement) to the size of the smallest such category.
    Categories with fewer than ``min_cells`` cells are kept whole when
    ``keep_small_categories`` is true and dropped otherwise.

    Parameters
    ----------
    adata
        Object exposing a pandas ``obs`` table and supporting boolean-mask
        indexing followed by ``.copy()``.
    column
        Name of the ``obs`` column holding the category labels.
    random_state
        Forwarded to ``DataFrame.sample`` for every category.
    min_cells
        Minimum size for a category to take part in the downsampling.
    keep_small_categories
        Whether categories smaller than ``min_cells`` are retained unchanged.

    Returns
    -------
    A copy of ``adata`` restricted to the selected cells.

    Raises
    ------
    ValueError
        If no category has at least ``min_cells`` cells.
    """
    counts = adata.obs[column].value_counts(sort=False)
    large = counts[counts >= min_cells]
    if large.empty:
        raise ValueError(
            "no category in '{0}' has at least {1} cells".format(column, min_cells)
        )
    min_size = int(large.min())

    sample_selection = np.zeros(adata.obs.shape[0], dtype=bool)
    for sample, num_cells in counts.items():
        members = adata.obs[adata.obs[column] == sample]
        # Strict '<' keeps this test consistent with 'counts >= min_cells' above:
        # a category with exactly min_cells cells counts as large.
        if num_cells < min_cells:
            if not keep_small_categories:
                continue
            chosen = members.index
        else:
            chosen = members.sample(min_size, random_state=random_state).index
        sample_selection |= adata.obs.index.isin(chosen)
    return adata[sample_selection].copy()

Prepare the data for model training

In [5]:
# Load the AnnData object; the path is relative to the notebook's working directory.
adata = sc.read_h5ad('../data/scLevyAll_neuron20.h5ad')

In [6]:
# Balance the conditions: those above the 3000-cell threshold are subsampled to a
# common size, smaller ones are kept whole (keep_small_categories=True).
# random_state is left at its default (None), so the draw depends on NumPy's global
# RNG state at the time this cell runs.
adata = downsample_to_smallest_category(adata,column="condition",
                                        min_cells=3000,keep_small_categories=True)
# Cells per condition after balancing.
collections.Counter(adata.obs['condition'])

Counter({'INFa': 3426,
         'DMSO': 3426,
         'CLOZ': 3426,
         'HALO': 3426,
         'AZT': 3426,
         'EFA': 2452,
         'ATOR': 2629,
         'ISRD': 3426,
         'GluN2a': 3426,
         'Control': 3426,
         'INFy': 3426,
         'PBS': 3426,
         'H2O2': 1046,
         'TNFa': 3426,
         'GLUT': 3426,
         'SIM': 1365})

In [7]:
# Annotation table for the perturbations; its 'Perturbation' and 'Vehicle' columns are
# used in later cells to map each condition to its vehicle control.
drugtype = pd.read_csv('../data/LevyDrug_class.csv',header=0,index_col=False)
drugtype

Unnamed: 0,Perturbation,Category,Vehicle
0,A1CYTO,Inflammatory response,PBS
1,ATOR,Cholesterol biosynthesis,DMSO
2,AZT,Oxidative stress,DMSO
3,C1Q,Inflammatory response,PBS
4,CLOZ,Antipsychotic,DMSO
5,DMSO,Vehicle control,DMSO
6,EFA,Cholesterol biosynthesis,DMSO
7,GluN2a,NMDA/excitability,PBS
8,GLUT,NMDA/excitability,PBS
9,H2O2,Oxidative stress,PBS


In [8]:
# Integer-encode the condition labels: `codes` holds one integer per cell and
# `uniques[k]` is the label that code k stands for (order of first appearance).
codes, uniques = pd.factorize(adata.obs['condition'])
uniques = list(uniques)
# Keep the integer code alongside the string label (note the capitalised column name).
adata.obs['Condition']=codes
#drugY = np.zeros((len(codes),len(uniques)))
#for i in range(len(codes)):
#    j = codes[i]
#    drugY[i,j]=1
#drugY[:,uniques.index('Control')]=0
#drugY[:,uniques.index('DMSO')]=0
#drugY[:,uniques.index('PBS')]=0
#drugY = drugY[:,np.sum(drugY,0)>0]

In [9]:
# Condition names in `uniques` order with 'Control' removed; the vehicle
# conditions (DMSO, PBS) are deliberately left in the list.
drug_names = uniques.copy()
drug_names.remove('Control')
#drug_names.remove('DMSO')
#drug_names.remove('PBS')

In [10]:
#contY = np.zeros((len(codes),len(uniques)))
#for i in range(len(codes)):
#    j = codes[i]
#    sub_drugtype = drugtype[drugtype['Perturbation']==uniques[j]]
#    index = uniques.index(sub_drugtype['Vehicle'].values[0])
#    contY[i,index]=1
#contY[:,uniques.index('Control')]=0
#contY = contY[:,np.sum(contY,0)>0]

In [11]:
# Multi-hot target matrix: for every cell, flag the column of its own condition
# and the column of that condition's vehicle (taken from `drugtype`).
# The vehicle is resolved once per unique condition instead of once per cell.
vehicle_by_condition = (
    drugtype.drop_duplicates('Perturbation')   # first match wins, as with .values[0]
    .set_index('Perturbation')['Vehicle']
)
unmapped = [u for u in uniques if u not in vehicle_by_condition.index]
if unmapped:
    raise KeyError("conditions missing from drugtype['Perturbation']: {0}".format(unmapped))
# vehicle_column[k] is the column index of the vehicle for condition code k.
vehicle_column = np.array(
    [uniques.index(vehicle_by_condition[u]) for u in uniques], dtype=int
)

cell_rows = np.arange(len(codes))
target_label = np.zeros((len(codes),len(uniques)))
target_label[cell_rows, codes] = 1
target_label[cell_rows, vehicle_column[codes]] = 1
# 'Control' carries no perturbation signal: clear its column, then drop every
# column that is empty across all cells.
target_label[:,uniques.index('Control')]=0
target_label = target_label[:,np.sum(target_label,0)>0]

In [13]:
# Annotate every cell with the vehicle of its condition. The lookup in `drugtype`
# is done once per unique condition rather than once per cell.
vehicle_by_condition = (
    drugtype.drop_duplicates('Perturbation')   # first match wins, as with .values[0]
    .set_index('Perturbation')['Vehicle']
)
unmapped = [u for u in uniques if u not in vehicle_by_condition.index]
if unmapped:
    raise KeyError("conditions missing from drugtype['Perturbation']: {0}".format(unmapped))
vehicle_of_code = [vehicle_by_condition[u] for u in uniques]
vehicles = [vehicle_of_code[j] for j in codes]
adata.obs['Vehicle']=vehicles

In [14]:
#print(drugY.shape)
#print(contY.shape)
# Sanity-check the target matrix dimensions before attaching it to the AnnData object.
print(target_label.shape)
#adata.obsm['X_drug']=drugY
#adata.obsm['X_cont']=contY
# Stored under 'X_target'; this key is passed to CellCap.setup_anndata later on.
adata.obsm['X_target']=target_label

(48604, 15)


In [15]:
# Integer-encode the donor labels, then expand the codes into a one-hot matrix
# with one row per cell and one column per distinct donor.
donor_codes, donor_uniques = pd.factorize(adata.obs['DONOR'])
donor_uniques = list(donor_uniques)

donorY = np.zeros((len(donor_codes), len(donor_uniques)))
donorY[np.arange(len(donor_codes)), donor_codes] = 1

adata.obsm['X_donor'] = donorY

In [16]:
# (number of cells, number of distinct donors)
print(donorY.shape)

(48604, 39)


In [17]:
# NOTE: `collections` is already imported in the first cell; this re-import is redundant.
import collections
# Cells per cell type.
collections.Counter(adata.obs['cell_type'])

Counter({'Neuron': 48604})

In [18]:
# Cells per condition (repeats the count shown right after downsampling).
collections.Counter(adata.obs['condition'])

Counter({'INFa': 3426,
         'DMSO': 3426,
         'CLOZ': 3426,
         'HALO': 3426,
         'AZT': 3426,
         'EFA': 2452,
         'ATOR': 2629,
         'ISRD': 3426,
         'GluN2a': 3426,
         'Control': 3426,
         'INFy': 3426,
         'PBS': 3426,
         'H2O2': 1046,
         'TNFa': 3426,
         'GLUT': 3426,
         'SIM': 1365})

Train CellCap model

In [19]:
# Register the count layer and the two design matrices with the model class.
CellCap.setup_anndata(adata,layer="counts",
                      target_key='X_target',donor_key='X_donor')

# Take the design-matrix widths from the data instead of hard-coding them, so the
# model stays in sync if the set of conditions or donors changes.
# (For the run shown in this notebook these evaluate to 15 and 39.)
n_drug = adata.obsm['X_target'].shape[1]
n_donor = adata.obsm['X_donor'].shape[1]

cellcap = CellCap(adata, n_latent=20, n_layers=3,
                  n_drug=n_drug,n_donor=n_donor,gene_likelihood='nb',n_prog=15)

[34mINFO    [0m Generating sequential column names                                                  
[34mINFO    [0m Generating sequential column names                                                  
[34mINFO    [0m Generating sequential column names                                                  


In [None]:
# Fit the model (at most 1000 epochs, batch size 1024).
# NOTE: the saved output below stops at epoch 354, so the stored run had not finished.
cellcap.train(max_epochs=1000,batch_size=1024)

GPU available: True (cuda), used: True
TPU available: False, using: 0 TPU cores
IPU available: False, using: 0 IPUs
HPU available: False, using: 0 HPUs
  rank_zero_deprecation(
LOCAL_RANK: 0 - CUDA_VISIBLE_DEVICES: [0]


Epoch 354/1000:  35%|███▌      | 353/1000 [21:37<36:49,  3.41s/it, loss=5.36e+03, v_num=1]  

Plot latent space

In [None]:
# Per-cell latent embedding from the trained model, stored under 'X_basal'.
# NOTE(review): presumably the basal (perturbation-free) state, judging by the key — confirm.
z = cellcap.get_latent_embedding(adata)
adata.obsm['X_basal']=z

In [None]:
# Build a cosine-distance neighbour graph on the 'X_basal' representation and embed it with UMAP.
sc.pp.neighbors(adata, n_neighbors=15, use_rep='X_basal', random_state=0,metric='cosine')
sc.tl.umap(adata, min_dist=0.15)#,method='rapids')
sc.set_figure_params(scanpy=True, dpi=75, dpi_save=75)
# Same embedding, coloured by cell type and by condition.
sc.pl.umap(adata, color='cell_type', title='',legend_loc='on data')
sc.pl.umap(adata, color='condition', title='')

In [None]:
sc.set_figure_params(scanpy=True, dpi=75, dpi_save=75)
# One UMAP panel per metadata column, drawn in this order.
for covariate in ('DONOR', 'Sex', 'Clinical', 'Age'):
    sc.pl.umap(adata, color=covariate, title='')

In [None]:
##ROC to evaluate basal state
from sklearn.metrics import roc_auc_score
from sklearn.metrics import roc_curve, auc
from sklearn.linear_model import LogisticRegression
from sklearn.model_selection import train_test_split

In [None]:
# For every non-control condition, train a logistic-regression classifier to tell
# its cells from 'Control' cells using the embedding `z`, and record the ROC curve
# and AUC measured on a held-out half of the cells.
adata.obsm['X_latent']=z
fpr, tpr, roc_auc = {}, {}, {}

is_treated = adata.obs['condition'] != 'Control'
conditions = list(set(adata.obs.loc[is_treated, 'condition'].values))

for c in conditions:
    # Restrict to the two classes being compared.
    subad = adata[adata.obs['condition'].isin(['Control', c])]
    # Binary labels: 0 = Control, 1 = condition c.
    y = (subad.obs['condition'] == c).to_numpy().astype(int)
    X = subad.obsm['X_latent']
    random_state = np.random.RandomState(0)

    # 50/50 shuffled split into training and test cells.
    X_train, X_test, y_train, y_test = train_test_split(X, y, test_size=0.5, random_state=0)

    classifier = LogisticRegression(random_state=random_state)
    classifier.fit(X_train, y_train)
    y_score = classifier.decision_function(X_test)

    fpr[c], tpr[c], _ = roc_curve(y_test, y_score, pos_label=classifier.classes_[1])
    roc_auc[c] = auc(fpr[c], tpr[c])

In [None]:
# `conditions` was built from a set, so sort it in place to get a stable order for plots and printouts.
conditions.sort()

In [None]:
# Concatenate two seaborn palettes so there are enough distinct colours for all plotted series.
colors = list(sns.color_palette("Paired"))+list(sns.color_palette("hls", 8))

In [None]:
# Overlay the ROC curves of all conditions, one colour per condition.
sc.set_figure_params(scanpy=True, dpi=200, dpi_save=200)
plt.figure()
for cond, color in zip(conditions, colors):
    plt.plot(fpr[cond], tpr[cond], color=color, lw=1, label="{0}".format(cond))
plt.grid(False)
# Chance-level diagonal for reference.
plt.plot([0, 1], [0, 1], "k--", lw=1)
plt.xlim([0.0, 1.0])
plt.ylim([0.0, 1.0])
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("")
legend = plt.legend(loc="lower right", prop={'size': 4.5})
legend.get_frame().set_facecolor('none')
plt.show()

In [None]:
# Report the AUC of every condition. Iterating `conditions` directly (rather than
# zipping with the colour list) guarantees that no condition is silently skipped
# when there are fewer colours than conditions.
for cond in conditions:
    print("{0} (AUC = {1:0.4f})".format(cond, roc_auc[cond]))

In [None]:
# Retrieve further per-cell outputs of the trained model. What each array represents
# is inferred from the obsm keys used below — confirm against the CellCap API.
zP, zAttn = cellcap.get_pert_embedding(adata)
zA = cellcap.get_embedding(adata)

adata.obsm['X_pert']=zP
adata.obsm['X_attn']=zAttn
# NOTE: this replaces the 'X_latent' entry written by the ROC cell (which stored `z` there).
adata.obsm['X_latent']=zA

In [None]:
# Model dimensions (same values as passed to the CellCap constructor) and the
# column labels 'A1' .. 'A<n_prog>' used for the attention matrix.
n_latent = 20
n_prog = 15
attention = [f"A{k}" for k in range(1, n_prog + 1)]

In [None]:
# Collect the per-epoch training history of every logged quantity whose name
# contains both 'alpha' and 'train' (first column of each history table).
ard = [
    cellcap.history[key].iloc[:, 0].values
    for key in cellcap.history.keys()
    if 'alpha' in key and 'train' in key
]
if not ard:
    raise ValueError("no 'alpha' training history found in cellcap.history")

# Plain NumPy is enough here; the former torch round trip did not change the values.
ard = np.asarray(ard, dtype='float32')
e = ard.shape[1]  # number of recorded epochs
# Arrange as (drug, program, epoch). The first axis is later indexed with
# drug_names.index(...), so its length is taken from drug_names rather than hard-coded.
ard = ard.reshape(len(drug_names), n_prog, e)

In [None]:
# Training trajectories of the 'alpha' values for CLOZ: one line per program.
cloz = ard[drug_names.index('CLOZ')]
epochs = np.arange(cloz.shape[1])

sc.set_figure_params(scanpy=True, dpi=150, dpi_save=150)
plt.figure()
for p, color in zip(range(n_prog), colors):
    plt.plot(epochs, cloz[p], color=color, lw=1, label="{0}".format(p+1))
plt.grid(False)
plt.title("")
# Legend placed outside the axes, to the right.
legend = plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
legend.get_frame().set_facecolor('none')
plt.show()

In [None]:
# Training trajectories of the 'alpha' values for DMSO: one line per program.
# (The variable keeps the name `cloz` from the previous cell but holds the DMSO slice.)
cloz = ard[drug_names.index('DMSO')]
epochs = np.arange(cloz.shape[1])

sc.set_figure_params(scanpy=True, dpi=150, dpi_save=150)
plt.figure()
for p, color in zip(range(n_prog), colors):
    plt.plot(epochs, cloz[p], color=color, lw=1, label="{0}".format(p+1))
plt.grid(False)
plt.title("")
# Legend placed outside the axes, to the right.
legend = plt.legend(bbox_to_anchor=(1.02, 1), loc='upper left', borderaxespad=0)
legend.get_frame().set_facecolor('none')
plt.show()

In [None]:
# Project the perturbation loadings through the model's loading matrix to get
# one score per (adata.var entry, drug) pair.
drug_embedding = cellcap.get_pert_loadings()
weights = cellcap.get_loadings()
drug_loading = pd.DataFrame(
    np.matmul(weights, drug_embedding.T),
    index=adata.var.index,
    columns=drug_names,
)

In [None]:
# Rank the rows (adata.var entries) by their CLOZ score, highest first, and show the top 100 names.
w = drug_loading['CLOZ'].sort_values(ascending = False)
w.index[:100]

In [None]:
# Same projection as for the perturbation loadings, but using the response loadings.
# NOTE: this reuses the names `drug_embedding` and `drug_loading`; from here on
# `drug_loading` has default integer column labels instead of drug names.
drug_embedding = cellcap.get_resp_loadings()
weights = cellcap.get_loadings()
drug_loading = pd.DataFrame(
    np.matmul(weights, drug_embedding.T),
    index=adata.var.index,
)

In [None]:
# Rank the rows (adata.var entries) by their value in column 14 (zero-based) and show the top 100 names.
# NOTE(review): 14 looks like n_prog - 1, i.e. the last program — confirm.
w = drug_loading.iloc[:,14].sort_values(ascending = False)
w.index[:100]

In [None]:
# For every drug, show the UMAP of its cells coloured by each attention column.
sc.set_figure_params(scanpy=True, dpi=50, dpi_save=50)  # loop-invariant, set once
for d in drug_names:
    # .copy() materialises the subset so the obs columns added below are written
    # to a real AnnData object rather than to a view of `adata`.
    cloz = adata[adata.obs['condition']==d].copy()
    attn = pd.DataFrame(cloz.obsm['X_attn'])
    attn.columns = attention

    for a in attention:
        cloz.obs[a] = attn[a].values

    print(d)
    # Pin the colour scale to [0, 1] with vmin/vmax so panels are comparable across
    # drugs. This replaces overwriting the first two cells with 0 and 1, which
    # altered plotted data and failed for subsets with fewer than two cells.
    # Values outside [0, 1] are clipped to the ends of the colour map.
    sc.pl.umap(cloz, color=attention, ncols=len(attention), frameon=False,
               cmap='coolwarm', vmin=0, vmax=1)

In [None]:
# Recompute the neighbour graph and UMAP on the 'X_latent' representation; this runs on
# the same `adata`, so it replaces the graph/embedding computed from 'X_basal' earlier.
sc.pp.neighbors(adata, n_neighbors=15, use_rep='X_latent', random_state=0,metric='cosine')
sc.tl.umap(adata, min_dist=0.15)#,method='rapids')
sc.set_figure_params(scanpy=True, dpi=75, dpi_save=75)
# Same embedding, coloured by cell type and by condition (labels drawn on the plot).
sc.pl.umap(adata, color='cell_type', title='',legend_loc='on data')
sc.pl.umap(adata, color='condition', title='',legend_loc='on data',legend_fontsize=7.5)