In [None]:
# import warnings
# warnings.filterwarnings('ignore')

import GEOparse
from tqdm import tqdm
import urllib.request
import random
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata

from scvi.dataset import AnnDatasetFromAnnData

import torch
import matplotlib.pyplot as plt
import tensorflow as tf
import sys
import umap

from utils_helper import *

# Seed every random number generator this notebook relies on so that a
# Restart & Run All produces the same results.
seed = 345
# NOTE(review): PYTHONHASHSEED is read at interpreter start-up, so setting it
# here can only influence child processes, not this kernel.
os.environ['PYTHONHASHSEED']=str(seed)
random.seed(seed)
np.random.seed(seed)
tf.set_random_seed(seed)
# PyTorch has its own generator; without this the weights of the networks
# built below are initialised differently on every run.
torch.manual_seed(seed)


# Restrict the process to the listed physical GPU(s); the first visible one is
# then addressed as 'cuda:0'.
gpus = ["1"]
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpus)
device = 'cuda:0'

# Load Data

In [None]:
# Location of the AnnData (.h5ad) file to analyse; the value below is a
# placeholder and must point at the real dataset.
path = "path_to_dataset.h5ad"
# Read the file into `adata`, the object every later cell annotates and plots.
adata = sc.read_h5ad(path)

In [None]:
# Lookup table: patient_id -> age.
id2age = {
    10: 19, 11: 19, 12: 24, 15: 30, 17: 41, 20: 31, 31: 21, 36: 27,
    48: 26, 55: 54, 58: 22, 67: 22, 84: 29, 104: 87, 118: 51, 127: 82,
    133: 42, 134: 42, 135: 18, 142: 47, 149: 77, 150: 52, 167: 30,
    173: 20, 183: 64, 184: 34, 185: 48, 215: 43, 216: 55, 225: 44,
    250: 26, 251: 44, 305: 32, 315: 53,
}
patient_age = adata.obs['patient_id'].map(id2age)
adata.obs['age'] = patient_age
# Decade of life (age / 10, truncated to an integer), stored as a categorical.
age_in_decades = adata.obs['age'].astype('int32') / 10
adata.obs['age_bin'] = age_in_decades.astype('int32').astype('category')

In [None]:
# The cell-type annotation lives in `cell_group`; expose it under both names
# used further down, plus an integer code per cell type.
cell_group = adata.obs['cell_group']
for alias in ('cell_types', 'cell_type'):
    adata.obs[alias] = cell_group.values
adata.obs['labels'] = cell_group.astype('category').cat.codes.values

# Age is used as the "batch" variable; the same values are published under
# the three column names read later in the notebook.
adata.obs['batch_name'] = adata.obs['age'].values
for alias in ('batch', 'batch_indices'):
    adata.obs[alias] = adata.obs['batch_name']

# Number of distinct cell types / distinct batch values.
n_labels = len(adata.obs['cell_types'].unique())
n_batch = len(adata.obs['batch'].unique())

In [None]:
# sc.pp.subsample(adata, fraction=0.2, random_state = seed) # todo

In [None]:
# Wrap the annotated AnnData in the dataset class imported from scvi.dataset;
# the model constructor and the trainer below both take this object.
dataset = AnnDatasetFromAnnData(adata)

# Preprocess

In [None]:
# sc.pp.log1p(adata)

# Latent Inference

In [None]:
# Hyper-parameters for latent inference.
# NOTE(review): the trainer.train(...) calls below pass literal epoch counts
# (150 and 50) instead of reading n_epochs — confirm which is intended.
n_epochs = 50
lr = 0.001  # base learning rate; the train calls scale it per network
eps = 1e-8  # forwarded to trainer.train as `eps`
use_batches = True  # multiplies dataset.n_batches when the VAE is built
use_cuda = True  # forwarded to GANTrainer
n_latent = 10  # size of the latent space
# Render matplotlib figures inside the notebook.
%matplotlib inline

In [None]:
# Build the VAE (class comes from utils_helper). Multiplying by the boolean
# `use_batches` passes n_batch=0 when batch information is switched off.
vae = VAE(dataset.nb_genes, n_batch=dataset.n_batches * use_batches, n_latent=n_latent, n_layers = 2, n_hidden=64)

In [None]:
# Adversary that reads the latent code. A Regressor is used here; the
# classification variant is kept below as a one-line alternative.
# Both are placed on the `device` chosen in the configuration cell rather than
# on a hard-coded device string.
# disc = Discriminator(n_latent, [2*n_latent], n_batch).to(device)
disc = Regressor(n_latent, [2*n_latent], n_batch).to(device)

In [None]:
# Trainer that couples the VAE with the adversary `disc` on `dataset`.
# NOTE(review): train_size is ~1 and test_size is None, so the later
# calc_scores calls on trainer.test_set / trainer.validation_set may see
# (almost) no cells — confirm how GANTrainer splits the data.
trainer = GANTrainer(
    vae, 
    disc,
    dataset,
    train_size=0.9999999999999,  test_size=None, # train_size=0.9, test_size=0.05,
    use_cuda=use_cuda,
    frequency=5,
    seed = seed
)


In [None]:
# Phase 1: 150 epochs with enc_lr set to zero.
# NOTE(review): presumably enc_lr is the rate of the adversarial update on the
# encoder, i.e. the adversarial signal is off in this phase — confirm in GANTrainer.
history = trainer.train(n_epochs=150, lr=lr, eps=eps, disc_lr=lr, enc_lr=0.0)

In [None]:
# Phase 2: 50 more epochs, now with enc_lr at 5% of the base learning rate.
# `history` is overwritten, so it only holds this second phase afterwards.
history = trainer.train(n_epochs=50, lr=lr, eps=eps, disc_lr=lr, enc_lr=lr * 0.05)

In [None]:
# Plot the first series returned by the most recent trainer.train call.
# NOTE(review): the name suggests this is the training ELBO — confirm against
# what GANTrainer.train returns.
elbo_train = history[0]
# One x position per recorded value (0, 1, 2, ...). The earlier
# np.linspace(0, n, n) stretched the axis to end at n, so the points did not
# sit on their own index.
x = np.arange(len(elbo_train))
plt.plot(x, elbo_train)

In [None]:
# Posterior over every cell of the dataset, in dataset order.
all_cell_indices = np.arange(len(dataset))
posterior = trainer.create_posterior(trainer.model, dataset, indices=all_cell_indices)
# Sequential pass so the rows of `latent` line up with the dataset order.
latent, batches, labels = posterior.sequential().get_latent()

In [None]:
# Store the latent coordinates on the AnnData object; the neighbours, t-SNE,
# UMAP and classifier cells below read them through the "X_scVI" key.
adata.obsm["X_scVI"] = latent

In [None]:
# Move the latent codes to the configured `device` (instead of a hard-coded
# device string) so they can be fed to `disc`.
latent_tensor = torch.tensor(latent, device=device)

In [None]:
# Run the adversary on the latent codes and keep, per cell, the index of its
# largest output.
disc_output = disc(latent_tensor).detach().cpu().numpy()
batch_prediction = disc_output.argmax(axis=1)

# Clustering Scores by kmeans

In [None]:
def calc_scores(input_posterior):
    """Print the batch-mixing entropy and the clustering score of a posterior.

    The latent codes, batch ids and labels are read sequentially from
    ``input_posterior`` and passed to the ``entropy_batch_mixing`` and
    ``clustering_scores`` helpers; the number of clusters comes from the
    notebook-level ``dataset.n_labels``. Nothing is returned.
    """
    latent_codes, batch_ids, label_ids = input_posterior.sequential().get_latent()
    mixing_entropy = entropy_batch_mixing(latent_codes, batch_ids)
    print("Entropy of batch mixing :", mixing_entropy)
    clustering_result = clustering_scores(dataset.n_labels, label_ids, latent_codes)
    print("Clustering ARI = {}".format(clustering_result))

In [None]:
# Scores on the cells the trainer assigned to the training split.
print('Train Set:')
calc_scores(trainer.train_set)

In [None]:
# Scores on the trainer's test split.
# NOTE(review): the trainer was built with train_size ~1 and test_size=None,
# so this split may be (nearly) empty — confirm before trusting the numbers.
print('Test Set:')
calc_scores(trainer.test_set)

In [None]:
# Scores on the trainer's validation split.
# NOTE(review): the trainer was built with train_size ~1 and test_size=None,
# so this split may be (nearly) empty — confirm before trusting the numbers.
print('Validation Set:')
calc_scores(trainer.validation_set)

In [None]:
# posterior.clustering_scores()

# t-SNE

In [None]:
# t-SNE on the full latent representation.
# `n_pcs` is intentionally not passed: together with `use_rep`, current scanpy
# keeps only the first `n_pcs` columns of the representation, so n_pcs=2 would
# embed just two of the n_latent dimensions.
sc.tl.tsne(adata, use_rep='X_scVI')

In [None]:
# Store the adversary's predicted index per cell as a categorical column so
# scanpy colours it with a discrete palette.
adata.obs['batch_pred_cat'] = pd.Series(batch_prediction).astype('category').values

In [None]:
show_plot = True

# One t-SNE panel per annotation: (obs column, figure size, extra plot options).
tsne_panels = [
    ("cell_types", (10, 9), {}),
    ("batch_name", (9, 8), {"size": 20}),
    ("batch_pred_cat", (9, 8), {"size": 20}),
    ("age", (9, 8), {"size": 20}),
    ("condition", (9, 8), {}),
]
for obs_key, panel_size, plot_options in tsne_panels:
    fig, ax = plt.subplots(figsize=panel_size)
    sc.pl.tsne(adata, color=[obs_key], ax=ax, show=show_plot, **plot_options)

In [None]:
dataset_name = 'Turecki'
method = 'scGAN'
# method = '$scGAN^{-}$(No Adversarial Net)'

# Figures written to disk: (obs column, title template, output path template).
# NOTE(review): the output directory is an absolute, user-specific path.
saved_panels = [
    ("cell_type", '{} - cell type', '/home/mcb/users/mbahra5/project/scVI/pics/{}_{}_celltype.png'),
    ("age", '{} - Batch(Age)', '/home/mcb/users/mbahra5/project/scVI/pics/{}_{}_batch.png'),
    ("condition", '{} - Condition', '/home/mcb/users/mbahra5/project/scVI/pics/{}_{}_condition.png'),
]
for obs_key, title_template, path_template in saved_panels:
    fig, ax = plt.subplots(figsize=(7, 6), dpi=150)
    sc.pl.tsne(adata, color=[obs_key], ax=ax, title=title_template.format(method))
    fig.savefig(path_template.format(dataset_name, method), bbox_inches = 'tight')

In [None]:
# Cluster the latent space into dataset.n_labels groups and keep both the
# score and the per-cell assignment.
# NOTE(review): the variables are named *_kmeans but the call requests
# prediction_algorithm='knn' — confirm which algorithm the helper really runs.
ari_kmeans , clusters_kmeans = clustering_scores_2(dataset.n_labels, labels, latent, prediction_algorithm='knn')

In [None]:
show_plot = True
print(ari_kmeans)

# Attach the cluster assignment to the cells and show it on the t-SNE map.
cluster_column = 'clusters_kmeans'
adata.obs[cluster_column] = clusters_kmeans.astype('object')
fig, ax = plt.subplots(figsize=(9, 8))
sc.pl.tsne(adata, color=[cluster_column], ax=ax, show=show_plot)

In [None]:
from tqdm import tqdm
ari_max, clusters_max = 0.0 , None

# Candidate cluster counts: 2 .. number of annotated labels.
c_n = list(range(2, dataset.n_labels+1))

# For each cluster count, repeat the GMM clustering 20 times and store the
# mean score (despite its name, ari_max_list holds averages).
ari_max_list = []
for n_clusters in c_n:
    ari_runs = [
        clustering_scores(n_clusters, labels, latent, prediction_algorithm='gmm')
        for _ in tqdm(range(20))
    ]
    ari_max_list.append(sum(ari_runs)/len(ari_runs))


In [None]:
# Mean GMM score as a function of the number of clusters.
fig, ax = plt.subplots(figsize=(8,6))
ax.plot(c_n, ari_max_list)
ax.set_ylabel('ARI (with GMM)')
ax.set_xlabel('Number of Clusters')
plt.show()

In [None]:
# Average score over all cluster counts except the first three in c_n.
ari_tail = ari_max_list[3:]
print(sum(ari_tail)/len(ari_tail))

In [None]:
# Neighbourhood graph on the latent representation (30 neighbours), then
# Louvain community detection; the clusters end up in adata.obs['louvain'],
# which the cells below read.
sc.pp.neighbors(adata, use_rep="X_scVI", n_neighbors=30)
sc.tl.louvain(adata, resolution=0.20)

In [None]:
# Colour every Louvain cluster by how strongly it is enriched for 'Suicide'
# cells: -log10 of a one-sided hypergeometric p-value.
# The p-values are computed in this cell so that it runs on a fresh kernel;
# it used to read `pvalues`, which is only assigned by a cell further down.
from scipy.stats import hypergeom

is_suicide = (adata.obs['condition'] == 'Suicide').values
suicide_numbers_total = is_suicide.sum()

pvalues = []
louvain_color = np.ones(adata.shape[0])
for cluster in adata.obs['louvain'].cat.categories:
    temp_mask = (adata.obs['louvain'] == cluster).values
    # P(at least the observed number of 'Suicide' cells in a random draw of
    # the cluster's size), hence k = observed - 1 for the survival function.
    p = hypergeom.sf(
        k=is_suicide[temp_mask].sum() - 1,
        M=adata.shape[0],
        n=suicide_numbers_total,
        N=temp_mask.sum(),
        loc=0,
    )
    pvalues.append(p)
    louvain_color[temp_mask] = -np.log10(p)
adata.obs['louvain_color'] = louvain_color

show_plot = True
fig, ax = plt.subplots(figsize=(7, 6),  dpi=150)
sc.pl.tsne(adata, color=['louvain_color'], ax=ax, show=show_plot, color_map='Reds',  title='{} - Cluster Enrichment'.format(method))
fig.savefig('/home/mcb/users/mbahra5/project/scVI/pics/{}_{}_louvain_bypvalue.png'.format(dataset_name,method), bbox_inches = 'tight')


In [None]:
# t-SNE map coloured by Louvain cluster label.
show_plot = True
fig, ax = plt.subplots(figsize=(9, 8))
sc.pl.tsne(adata, color=['louvain'], ax=ax, show=show_plot)

In [None]:
from sklearn.metrics import adjusted_rand_score as ARI
# Agreement between the annotated labels and the Louvain clusters.
louvain_assignment = adata.obs['louvain']
ari_score = ARI(labels, louvain_assignment)
print('Lovain ARI score={}'.format(ari_score))

# Differential Gene Expression 

In [None]:
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    """Torch dataset that pairs the rows of a table with their labels.

    ``data`` and ``label`` are indexed positionally with ``.iloc``, so both
    are expected to be pandas objects of the same length.
    """

    def __init__(self, data, label):
        self.data, self.label = data, label

    def __len__(self):
        """Number of samples, i.e. number of rows in ``data``."""
        return len(self.data)

    def __getitem__(self, index):
        """Return ``(row values as an array, label)`` for position ``index``."""
        features = self.data.iloc[index].values
        target = self.label.iloc[index]
        return features, target
    

In [None]:
import torch.nn.functional as F
def train(model, train_loader, optimizer, epochs):
    """Train ``model`` with the negative log-likelihood loss.

    Parameters
    ----------
    model : torch.nn.Module
        Network whose output is passed to ``F.nll_loss``.
    train_loader : iterable
        Yields ``(data, target)`` batches; both are moved to the
        notebook-level ``device`` before the forward pass.
    optimizer : torch.optim.Optimizer
        Optimizer stepping the parameters of ``model``.
    epochs : int
        Number of full passes over ``train_loader``.

    Returns
    -------
    list of float
        Loss of every 10th batch (batch index 0, 10, 20, ...) of each epoch.
    """
    model.train()
    history = []
    # Exactly `epochs` passes; the previous range(1, epochs) ran one fewer.
    for _ in tqdm(range(epochs)):
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)

            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 10 == 0:
                history.append(loss.item())
    return history

In [None]:
# Classifier outputs for the latent codes of the train / held-out cells.
# NOTE(review): mdd_classifier, adata_train and adata_test are first assigned
# inside the per-cluster loop of a later cell ("Automatic DE per cluster"),
# so this cell fails on a fresh kernel and otherwise reflects only the last
# cluster processed — move it below that loop.
preds_train = mdd_classifier(torch.tensor(adata_train.obsm['X_scVI']).to(device))
preds_test = mdd_classifier(torch.tensor(adata_test.obsm['X_scVI']).to(device))

In [None]:
def _condition_accuracy(class_scores, cells):
    """Fraction of cells whose arg-max class equals their `condition` code."""
    predicted = class_scores.argmax(dim=1).detach().cpu().numpy()
    expected = cells.obs['condition'].cat.codes.astype('long')
    return (predicted == expected).mean()


acc_train = _condition_accuracy(preds_train, adata_train)
print("acc train: {}".format(acc_train))

acc_test = _condition_accuracy(preds_test, adata_test)
print("acc test: {} ".format(acc_test))

## Automatic DE per cluster

In [None]:
# Reference gene-disease tables (tab separated).
psygenet = pd.read_csv("/home/mcb/users/mbahra5/project/scVI/psygenet_v02.txt", sep='\t')
disgenet = pd.read_csv("/home/mcb/users/mbahra5/project/scVI/curated_gene_disease_associations.tsv.gz", sep='\t')

# Row filters for the disorders of interest.
psygenet_disorder = psygenet['PsychiatricDisorder']
mask1 = psygenet_disorder.eq('Depressive disorders')
mask2 = psygenet_disorder.eq('Schizophrenia spectrum and other psychotic disorders')
mask1_dis = disgenet['diseaseName'].eq('Major Depressive Disorder')

# Gene names of the expression matrix, in column order.
genes = adata.var_names.values

In [None]:

# For every Louvain cluster: train a small condition classifier on the latent
# codes, back-propagate its class-0 output through the VAE to the input genes,
# and intersect the 50 genes with the largest absolute summed gradient with the
# reference gene sets.

# Reference gene sets do not depend on the cluster; compute them once.
psygenet_genes = psygenet[mask1|mask2]['Gene_Symbol'].values
disgenet_genes = disgenet[mask1_dis]['geneSymbol'].values

overlaps=[]
overlaps_dis = []
for cluster in adata.obs['louvain'].cat.categories:
    adata_filter = adata[(adata.obs['louvain']==cluster).values]
    # The last five cells of the cluster are held out; the rest are used below.
    adata_train = adata_filter[:-5]
    adata_test = adata_filter[-5:]

    mydataset = MyDataset(data = pd.DataFrame(adata_train.obsm['X_scVI']) , label=adata_train.obs['condition'].cat.codes.astype('long'))
    train_loader = torch.utils.data.DataLoader(mydataset, batch_size=512)
    # Use the configured `device`, like the tensors moved inside the loop below.
    mdd_classifier = Discriminator(n_latent, [2*n_latent, 2*n_latent], 2).to(device)

    optimizerMDD= torch.optim.Adam(mdd_classifier.parameters(), lr = lr)
    history = train(mdd_classifier, train_loader, optimizerMDD, epochs=30)

    # Expression values paired with the batch index the VAE expects.
    gene_dataset = MyDataset(data = pd.DataFrame(adata_train.X.toarray()) , label=adata_train.obs['batch_indices'].astype('long'))
    gene_dataset_loader = torch.utils.data.DataLoader(gene_dataset, batch_size=512)

    grads = []
    for sample_batch, batch_index in tqdm(gene_dataset_loader):
        vae.zero_grad()
        optimizerMDD.zero_grad()

        sample_batch = sample_batch.to(device)
        batch_index = batch_index.to(device)
        # Track gradients with respect to the input expression values.
        sample_batch.requires_grad=True
        batch_index = batch_index.reshape((-1,1))

        reconst_loss, kl_divergence, z = vae(sample_batch, batch_index)
        output = mdd_classifier(z)
        # The classifier is trained with nll_loss, so exp() is applied to its
        # output before summing column 0 over the batch.
        output = torch.exp(output)
        output.sum(dim=0)[0].backward()
        grads.append(sample_batch.grad.detach().cpu().numpy())

    # One score per gene: absolute value of the gradient summed over cells.
    grads_per_gene_abs = np.absolute(np.concatenate(grads).sum(axis=0))
    print('grad shape for cluster {} is ={}'.format(cluster,grads_per_gene_abs.shape))
    genes_top = genes[grads_per_gene_abs.argsort()[-50:]]

    overlap = np.intersect1d(psygenet_genes, genes_top)
    overlap_dis = np.intersect1d(disgenet_genes, genes_top)
    overlaps.append(overlap)
    overlaps_dis.append(overlap_dis)
    print(overlap)
    print(overlap_dis)

In [None]:
# One line per cluster listing its overlapping genes.
for cluster_position, cluster_overlap in enumerate(overlaps):
    gene_listing = ', '.join(cluster_overlap)
    print("DEGs for cluster {}:{}".format(cluster_position, gene_listing))


## Automatic DE across all clusters

In [None]:

# Same gradient-based gene scoring as the per-cluster analysis, but here the
# per-gene scores of every cluster are collected in grads_all_clusters so they
# can be aggregated across clusters afterwards.
overlaps=[]
overlaps_dis = []
grads_all_clusters =[]
for cluster in adata.obs['louvain'].cat.categories:
    adata_filter = adata[(adata.obs['louvain']==cluster).values]
    # The last five cells of the cluster are held out; the rest are used below.
    adata_train = adata_filter[:-5]
    adata_test = adata_filter[-5:]

    mydataset = MyDataset(data = pd.DataFrame(adata_train.obsm['X_scVI']) , label=adata_train.obs['condition'].cat.codes.astype('long'))
    train_loader = torch.utils.data.DataLoader(mydataset, batch_size=512)
    # Use the configured `device`, like the tensors moved inside the loop below.
    mdd_classifier = Discriminator(n_latent, [2*n_latent, 2*n_latent], 2).to(device)

    optimizerMDD= torch.optim.Adam(mdd_classifier.parameters(), lr = lr)
    history = train(mdd_classifier, train_loader, optimizerMDD, epochs=30)

    # Expression values paired with the batch index the VAE expects.
    gene_dataset = MyDataset(data = pd.DataFrame(adata_train.X.toarray()) , label=adata_train.obs['batch_indices'].astype('long'))
    gene_dataset_loader = torch.utils.data.DataLoader(gene_dataset, batch_size=512)

    grads = []
    for sample_batch, batch_index in tqdm(gene_dataset_loader):
        vae.zero_grad()
        optimizerMDD.zero_grad()

        sample_batch = sample_batch.to(device)
        batch_index = batch_index.to(device)
        # Track gradients with respect to the input expression values.
        sample_batch.requires_grad=True
        batch_index = batch_index.reshape((-1,1))

        reconst_loss, kl_divergence, z = vae(sample_batch, batch_index)
        output = mdd_classifier(z)
        # The classifier is trained with nll_loss, so exp() is applied to its
        # output before summing column 0 over the batch.
        output = torch.exp(output)
        output.sum(dim=0)[0].backward()
        grads.append(sample_batch.grad.detach().cpu().numpy())

    # One score per gene: absolute value of the gradient averaged over cells
    # (the per-cluster analysis sums instead of averaging).
    grads_per_gene_abs = np.absolute(np.concatenate(grads).mean(axis=0))
    grads_all_clusters.append(grads_per_gene_abs)


In [None]:
print('Overlaps with PsyGeNet:')
# Everything that does not depend on `top` is computed once, outside the loop.
# NOTE(review): clusters are combined with max() here but with sum() in the
# GWAS overlap cell — confirm which aggregation is intended.
grads_agg = np.array(grads_all_clusters).max(axis=0)
gene_rank = grads_agg.argsort()  # ascending: the strongest genes come last
psygenet_genes = psygenet[mask1|mask2]['Gene_Symbol'].values
disgenet_genes = disgenet[mask1_dis]['geneSymbol'].values

for top in [5, 10 , 50]:
    genes_top = genes[gene_rank[-top:]]

    overlap = np.intersect1d(psygenet_genes, genes_top)
    # Computed for inspection only; the DisGeNET overlap is not printed.
    overlap_dis = np.intersect1d(disgenet_genes, genes_top)
    print("DEGs for top {}: {}".format(top, len(overlap)))
    

### Overlap with GWAS

In [None]:
# GWAS catalogue download (tab separated).
gwas = pd.read_csv("/home/mcb/users/mbahra5/project/scVI/gwas-association-downloaded_2020-04-22-EFO_0003761-withChildTraits.tsv", sep='\t')


def _mentions_mdd(trait):
    """True when the trait text names major depressive disorder (either capitalisation of the first letter)."""
    return ('Major depressive disorder' in trait) or ('major depressive disorder' in trait)


# Boolean row filter over the trait column.
mask_gwas = gwas['DISEASE/TRAIT'].apply(_mentions_mdd)

In [None]:
# Keep only the rows selected by mask_gwas.
# NOTE(review): this overwrites `gwas` with its own subset, so the unfiltered
# table is only recoverable by re-running the loading cell.
gwas = gwas[mask_gwas]

In [None]:
# Flatten the MAPPED_GENE column into one list of gene symbols.
# An entry may hold several symbols separated by ' - ', '; ' or ', '.
# Entries with a single symbol are kept as well: they used to be dropped
# because only strings containing ', ' were added.
gwas_genes = []
for g in gwas['MAPPED_GENE']:
    if pd.isna(g):
        # No mapped gene for this association; skip it so the string 'nan'
        # is not collected as a gene symbol.
        continue
    g = str(g).replace(' - ',', ').replace('; ',', ')
    gwas_genes.extend(g.split(', '))

In [None]:
print('Overlaps with GWAS:')
# Aggregation and ranking do not depend on `top`; compute them once.
# NOTE(review): clusters are combined with sum() here but with max() in the
# PsyGeNet overlap cell — confirm which aggregation is intended.
grads_agg = np.array(grads_all_clusters).sum(axis=0)
gene_rank = grads_agg.argsort()  # ascending: the strongest genes come last

for top in [5, 10 , 50, 100]:
    genes_top = genes[gene_rank[-top:]]

    # The former `overlap_dis` line repeated this exact intersection under a
    # DisGeNET name and was never used, so it has been removed.
    overlap = np.intersect1d(gwas_genes, genes_top)
    print("DEGs for top {}: {}".format(top, len(overlap)))
    print(overlap)

# DE by LMM

In [None]:
# Log-transform the expression matrix for the linear-model analysis below.
# NOTE(review): this is called without copy=True, so it appears to change
# adata.X itself; re-running the cell would then transform the data again, and
# the gradient-based DE cells above read adata.X, so they should not be re-run
# after this point without reloading the data.
sc.pp.log1p(adata)

In [None]:
from pymer4.models import Lm, Lmer

# Per-cluster differential expression with an ordinary linear model: for each
# of the 5000 most variable genes of a cluster, regress expression on
# condition, age bin and patient id, and record the p-value and estimate of
# the second coefficient row.
# NOTE(review): row 1 of model.coefs is assumed to be the `condition` term —
# confirm against the design-matrix column order.
pvals=[]

# Column names for the latent dimensions (IV0, IV1, ...), sized from the data.
latent_columns = ["IV"+str(x) for x in range(adata.obsm['X_scVI'].shape[1])]

for cluster in adata.obs['louvain'].cat.categories:
    print('cluster=' + str(cluster))
    mask = (adata.obs['louvain']==cluster).values
    # Real copy rather than a view: highly_variable_genes writes into .var.
    adata_log = adata[mask].copy()

    # The latent representation is stored under 'X_scVI'; the previous key
    # 'X_' is not set anywhere in this notebook.
    df = pd.DataFrame(adata_log.obsm['X_scVI'], columns=latent_columns, index=adata_log.obs.index)
    df['condition'] = adata_log.obs['condition'].values
    df['patient_id'] = adata_log.obs['patient_id'].values
    # The formula below refers to age_bin, so it has to be a column of df.
    df['age_bin'] = adata_log.obs['age_bin'].values

    sc.pp.highly_variable_genes(adata_log, n_top_genes=5000)
    adata_log = adata_log[:,adata_log.var.highly_variable]
    genes_highvar = adata_log.var.index.values

    # Densify once per cluster so every gene is a plain 1-D column; a sparse
    # slice cannot be assigned to a DataFrame column.
    expression = adata_log.X
    if hasattr(expression, 'toarray'):
        expression = expression.toarray()
    expression = np.asarray(expression)

    pvals_per_cluster = []
    for i , gene_name in enumerate(genes_highvar):
        df['gene'] = expression[:, i]

        model = Lm('gene ~ 1 + condition + age_bin + patient_id',data=df)
        # Mixed-model alternative using the latent dimensions as covariates:
        # model = Lmer('gene ~ 1 + condition + (1|patient_id) + IV0 + IV1 + IV2 + IV3 + IV4 + IV5 + IV6 + IV7 + IV8 + IV9',data=df)
        model.fit(summarize=False)
        # .iloc: the coefficient table is indexed by term name, so the row
        # must be selected by position explicitly.
        pvals_per_cluster.append([model.coefs['P-val'].iloc[1] , model.coefs['Estimate'].iloc[1], gene_name])

    pvals.append(pvals_per_cluster)
    

# DE across all clusters by LMM

In [None]:
# One (pval, estimate, gene) table per cluster, in cluster order.
score_columns = ['pval','estimate','gene']
gene_scores = [
    pd.DataFrame(cluster_rows, columns=score_columns)
    for cluster_rows in pvals
]
  

In [None]:
# For each gene keep the row of the cluster in which its p-value is smallest,
# so `estimate` is the effect that belongs to the reported `pval`.
# groupby('gene').min() took the minimum of each column independently and
# could pair a p-value with an estimate from a different cluster.
# The result is still indexed by gene, sorted by gene name; rows with a
# missing p-value sort last, so a gene is only NaN if all its p-values are.
genes_pvalues = (
    pd.concat(gene_scores, axis=0, ignore_index=True)
    .sort_values('pval', kind='mergesort')
    .drop_duplicates('gene', keep='first')
    .set_index('gene')
    .sort_index()
)

In [None]:
# Genes ordered from the smallest to the largest p-value.
genes_by_pvalue = genes_pvalues.sort_values('pval').index
psygenet_genes = psygenet[mask1|mask2]['Gene_Symbol'].values
disgenet_genes = disgenet[mask1_dis]['geneSymbol'].values

for top in [5, 10 , 50 , 100]:
    genes_top = genes_by_pvalue[:top].values
    overlap = np.intersect1d(psygenet_genes, genes_top)
    overlap_dis = np.intersect1d(disgenet_genes, genes_top)
    print("DEGs for top {}: {}".format(top, len(overlap)))

# Cluster Enrichment Analysis

In [None]:
from scipy.stats import hypergeom
# One-sided hypergeometric test per Louvain cluster: is the cluster enriched
# for 'Suicide' cells relative to the whole dataset?
is_suicide = (adata.obs.condition=='Suicide').values
suicide_numbers_total = is_suicide.sum()
n_cells_total = adata.shape[0]

pvalues = []
for cluster in adata.obs['louvain'].cat.categories:
    in_cluster = (adata.obs['louvain']==cluster).values
    suicide_numbers = is_suicide[in_cluster].sum()
    # sf(k - 1) is the probability of at least k 'Suicide' cells in a draw of
    # the cluster's size.
    p_value = hypergeom.sf(k = suicide_numbers-1, M = n_cells_total , n = suicide_numbers_total, N = in_cluster.sum(), loc=0)
    print('P-value of Hyper test for cluster {} = {}'.format(cluster,p_value))
    pvalues.append(p_value)

# UMAP

In [None]:
# import warnings
# warnings.filterwarnings('ignore')
# Rebuild the neighbourhood graph on the latent representation with 15
# neighbours (the Louvain step above used 30) and compute the UMAP embedding.
sc.pp.neighbors(adata, use_rep="X_scVI", n_neighbors=15)
sc.tl.umap(adata, min_dist=0.1)

In [None]:
show_plot = True
# One UMAP panel per annotation column.
for obs_key in ("cell_type", "batch_name"):
    fig, ax = plt.subplots(figsize=(10, 9))
    sc.pl.umap(adata, color=[obs_key], ax=ax, show=show_plot)
