In [None]:
import GEOparse
from tqdm import tqdm
import urllib.request
import random
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata

import torch
import matplotlib.pyplot as plt
import tensorflow as tf
import sys
import umap

from utils_helper import VAE, Discriminator, Regressor, GANTrainer, entropy_batch_mixing, clustering_scores, GeneDataset


# Fix every random number generator the notebook relies on so that runs are repeatable.
seed = 345
# Note: PYTHONHASHSEED is read at interpreter start-up, so setting it here only
# affects processes spawned from this kernel, not the kernel itself.
os.environ['PYTHONHASHSEED']=str(seed)
random.seed(seed)
np.random.seed(seed)
# tf.set_random_seed only exists in TensorFlow 1.x; fall back to the 2.x name when it is absent.
if hasattr(tf, 'set_random_seed'):
    tf.set_random_seed(seed)
else:
    tf.random.set_seed(seed)


# Expose a single physical GPU to this process; inside the process it is addressed as cuda:0.
gpus = ["6"]
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpus)
device = 'cuda:0'

# All models below are trained with PyTorch, so its generators must be seeded as well
# (previously only Python, NumPy and TensorFlow were seeded).
torch.manual_seed(seed)

# Load Data

In [None]:
# Absolute, machine-specific location of the AnnData (.h5ad) file; adjust when running elsewhere.
path = "/home/mcb/users/mbahra5/project/data/turecki_types_all.h5ad"
# Read the whole file into `adata`, the object every later cell annotates and reads.
adata = sc.read_h5ad(path)

In [None]:
# Keep a random 20% of the cells (seeded, applied to `adata` itself) to speed up the run.
# Comment this line out to work on the full dataset.
sc.pp.subsample(adata, fraction=0.2, random_state = seed)

In [None]:
# Create batch labels using the age
# Hard-coded lookup from patient_id to donor age (presumably in years — confirm against the cohort metadata).
id2age = {10:19,11:19,12:24,15:30,17:41,20:31,31:21,36:27,48:26,55:54,58:22,67:22,84:29,104:87,118:51,127:82,133:42,134:42,135:18,142:47,149:77,150:52,167:30,173:20,183:64,184:34,185:48,215:43,216:55,225:44,250:26,251:44,305:32,315:53}
# A patient_id missing from id2age maps to NaN here, which the int32 cast on the next line cannot represent.
adata.obs['age'] = adata.obs['patient_id'].map(id2age).astype('float')
# Decade bin: age / 10 truncated by the second int32 cast (19 -> 1, 24 -> 2, ...), stored as a categorical.
adata.obs['age_bin'] = (adata.obs['age'].astype('int32') / 10).astype('int32').astype('category') 

In [None]:
# Copy of the cell-group annotation under the name used by the UMAP plot at the end.
adata.obs['cell_type'] = adata.obs['cell_group'].values
# Integer category code of each cell's cell_group.
adata.obs['labels'] = adata.obs['cell_group'].astype('category').cat.codes.values
# Integer category code of the age decade bin; this is what the notebook uses as the "batch" variable.
adata.obs['batch'] = adata.obs['age_bin'].astype('category').cat.codes.astype('long').values

In [None]:
# Bundle the expression matrix with the label codes, batch codes and raw ages.
# NOTE(review): the meaning of each positional argument is defined by GeneDataset in utils_helper — confirm there.
dataset = GeneDataset(adata.X, adata.obs.labels, adata.obs.batch, adata.obs.age)

# Latent Inference

In [None]:
# Hyper-parameters shared by the model-building and training cells below.
n_epochs = 50  # NOTE(review): the training cells shown below pass n_epochs=30 literally, not this value
lr = 0.001  # base learning rate; the training cells scale it per component
eps = 1e-8  # forwarded as `eps` to trainer.train
use_batches = True  # multiplied into n_batch when the VAE is built
use_cuda = True  # not read by any cell shown in this notebook section
n_latent = 10  # latent dimensionality passed to the VAE, the regressor and the classifiers
batch_size = 128  # forwarded to GANTrainer
%matplotlib inline

In [None]:
# Build the VAE and move it to the GPU.
# `dataset.n_batches * use_batches` evaluates to 0 when use_batches is False.
vae = VAE(dataset.nb_genes, n_batch=dataset.n_batches * use_batches, n_latent=n_latent, n_layers = 2,
          n_hidden=64).cuda(device)

In [None]:
# Adversary that operates on the latent code (one hidden layer of width 2*n_latent).
# Moved to the GPU once, through the shared `device` setting, instead of the former
# redundant `.to('cuda:0').cuda(device)` chain with a hard-coded device string.
disc = Regressor(n_latent, [2*n_latent], dataset.n_batches).to(device)

In [None]:
# Couple the VAE and the regressor in a single trainer object.
# NOTE(review): 'continuous' presumably selects the continuous-covariate (regression) mode — confirm in GANTrainer.
trainer = GANTrainer('continuous', vae, disc, dataset, device, batch_size)

In [None]:
# Pretraining
# disc_lr and enc_lr both evaluate to 0.0 here; presumably this trains the VAE alone before
# the adversarial phase — confirm how GANTrainer.train uses these rates.
history = trainer.train(n_epochs=30, lr= lr*1, eps=eps, disc_lr= lr * 0.0, enc_lr = lr* 0.0 )

In [None]:
# First element of the pretraining history (presumably the training ELBO trace — confirm in GANTrainer.train).
elbo_train = history[0]
# Index the trace 0..len-1. The former np.linspace(0, len, len) spaced the points by
# len/(len-1) and ran one unit past the last entry, so the x axis did not match the record index.
x = np.arange(len(elbo_train))
plt.plot(x, elbo_train)
plt.xlabel('history index')
plt.ylabel('ELBO (train)')
plt.title('Pretraining');

In [None]:
# Training with Adversarial loss
# Same trainer as in pretraining, now with disc_lr = lr and enc_lr = 0.05 * lr.
# `history` from the pretraining cell is overwritten by this call.
history = trainer.train(n_epochs=30, lr= lr*1, eps=eps, disc_lr= lr * 1, enc_lr = lr* 0.05)

In [None]:
# trainer.get_latent() yields three tensors; detach each one and copy it to a host NumPy array.
# They are unpacked as latent codes, labels and batches.
latent, labels, batches = [item.detach().cpu().numpy() for item in trainer.get_latent()]

In [None]:
# Store the latent codes so scanpy tools can select them with use_rep="X_scGAN".
# assumes get_latent() returns cells in the same order as adata — TODO confirm
adata.obsm["X_scGAN"] = latent

# t-SNE

In [None]:
# Embed the full latent representation with t-SNE.
# `n_pcs=2` was dropped: combined with `use_rep`, current scanpy keeps only the first
# n_pcs columns of the representation, which would discard most of the latent
# dimensions before the embedding is computed.
sc.tl.tsne(adata, use_rep='X_scGAN')

In [None]:
show_plot = True
# One 9x8 figure per annotation; the dict carries extra keyword arguments for sc.pl.tsne.
# (A "batch" panel with size = 20 was sketched here earlier but left disabled.)
tsne_panels = [("cell_group", {}), ("age", {"size": 20}), ("condition", {})]
for obs_key, extra_kwargs in tsne_panels:
    fig, ax = plt.subplots(figsize=(9, 8))
    sc.pl.tsne(adata, color=[obs_key], ax=ax, show=show_plot, **extra_kwargs)

# Scores

In [None]:
# Batch-mixing score of the latent codes against the `batches` array returned by the trainer
# (the metric itself is implemented in utils_helper).
print("Entropy of batch mixing :", entropy_batch_mixing(latent, batches))

## K-means Clustering Score

In [None]:
# Clustering score of the latent codes against the true labels, using dataset.n_labels clusters.
# NOTE(review): the printed text calls this ARI; confirm what clustering_scores actually returns.
print("Clustering ARI = {}".format(clustering_scores(dataset.n_labels, labels, latent)))

## Louvain Clustering Score

In [None]:
# Build a 30-nearest-neighbour graph on the latent codes, then run Louvain clustering on it.
# The resulting assignment is read from adata.obs['louvain'] by the cells below.
sc.pp.neighbors(adata, use_rep="X_scGAN", n_neighbors=30)
sc.tl.louvain(adata, resolution=0.20)

In [None]:
# Plot the t-SNE embedding again, this time coloured by Louvain cluster.
show_plot = True
fig, ax = plt.subplots(figsize=(9, 8))
sc.pl.tsne(adata, color=['louvain'], ax=ax, show=show_plot)

In [None]:
# Adjusted Rand index between the labels returned by the trainer and the Louvain clusters.
from sklearn.metrics import adjusted_rand_score as ARI
ari_score = ARI(labels, adata.obs['louvain'])
print("Louvain Clustering ARI = {}".format(ari_score))

# Significant Genes

In [None]:
from torch.utils.data import DataLoader, Dataset

class MyDataset(Dataset):
    """Map-style dataset pairing the rows of a pandas frame with pandas labels.

    Both `data` and `label` are addressed positionally through `.iloc`, so
    they only need to share the same length and ordering.
    """

    def __init__(self, data, label):
        self.data, self.label = data, label

    def __len__(self):
        # One sample per row of `data`.
        return len(self.data)

    def __getitem__(self, index):
        # Features are returned as a NumPy array, the label as whatever `.iloc` yields.
        features = self.data.iloc[index].values
        return features, self.label.iloc[index]
    

In [None]:
import torch.nn.functional as F
def train(model, train_loader, optimizer, epochs):
    """Fit `model` on `train_loader` with the negative log-likelihood loss.

    Parameters
    ----------
    model : torch module whose output is passed straight to `F.nll_loss`.
    train_loader : iterable yielding (data, target) batches.
    optimizer : optimizer already bound to the parameters of `model`.
    epochs : number of full passes over `train_loader`.

    Returns
    -------
    list of float
        The loss recorded on every 10th batch (batch index 0, 10, 20, ...) of each epoch.

    Batches are moved to the module-level `device` before the forward pass.
    """
    history = []
    # range(epochs): the former range(1, epochs) silently ran one pass fewer than requested.
    for epoch in tqdm(range(epochs)):
        model.train()
        for batch_idx, (data, target) in enumerate(train_loader):
            data, target = data.to(device), target.to(device)
            optimizer.zero_grad()
            output = model(data)

            loss = F.nll_loss(output, target)
            loss.backward()
            optimizer.step()
            if batch_idx % 10 == 0:
                history.append(loss.item())

    return history


# Automatic DE (differential expression) per cluster

In [None]:
# Reference gene-disease tables (tab-separated, machine-specific absolute paths).
psygenet = pd.read_csv("/home/mcb/users/mbahra5/project/scVI/psygenet_v02.txt", sep='\t')
disgenet = pd.read_csv("/home/mcb/users/mbahra5/project/scVI/curated_gene_disease_associations.tsv.gz", sep='\t')
# Row masks selecting the disorders of interest; mask1|mask2 is used together on PsyGeNET below.
mask1 = psygenet['PsychiatricDisorder']=='Depressive disorders' 
mask2 = psygenet['PsychiatricDisorder']=='Schizophrenia spectrum and other psychotic disorders' 
mask1_dis = disgenet['diseaseName']=='Major Depressive Disorder' 
# Gene identifiers in the column order of adata; indexed by gradient rank in the cells below.
genes = adata.var.index.values

In [None]:

# For every Louvain cluster: (1) train a 2-class classifier on the latent codes to predict
# `condition`, (2) back-propagate the classifier output through the VAE encoder down to the
# input genes, (3) rank genes by absolute summed gradient and intersect the top 50 with the
# PsyGeNET / DisGeNET gene lists.
overlaps=[]
overlaps_dis = []
for cluster in adata.obs['louvain'].cat.categories:
    adata_filter = adata[(adata.obs['louvain']==cluster).values]
    # The last 5 cells of the cluster are split off; adata_test is not used anywhere in this cell.
    adata_train = adata_filter[:-5]
    adata_test = adata_filter[-5:]
    
    # Classifier input: latent codes; target: integer code of `condition`.
    mydataset = MyDataset(data = pd.DataFrame(adata_train.obsm['X_scGAN']) , label=adata_train.obs['condition'].cat.codes.astype('long'))
    train_loader = torch.utils.data.DataLoader(mydataset, batch_size=512)
    # NOTE(review): the device string is hard-coded here instead of using `device`.
    mdd_classifier = Discriminator(n_latent, [2*n_latent, 2*n_latent], 2).to('cuda:0')
    
    optimizerMDD= torch.optim.Adam(mdd_classifier.parameters(), lr = lr)
    history = train(mdd_classifier, train_loader, optimizerMDD, epochs=30)
    
    # Dense expression rows for the gradient pass. The `batch` labels are loaded alongside,
    # but below they are only reshaped and never fed to a model.
    gene_dataset = MyDataset(data = pd.DataFrame(adata_train.X.toarray()) , label=adata_train.obs['batch'].astype('long'))
    gene_dataset_loader = torch.utils.data.DataLoader(gene_dataset, batch_size=512)

    grads = []
    for sample_batch, batch_index in tqdm(gene_dataset_loader):
        # Clear parameter gradients left over from the previous mini-batch.
        vae.zero_grad()
        optimizerMDD.zero_grad()
        
        sample_batch = sample_batch.to(device)
        batch_index = batch_index.to(device)
        # Track gradients with respect to the input expression values themselves.
        sample_batch.requires_grad=True
        batch_index = batch_index.reshape((-1,1))
        
        z = vae.sample_from_posterior_z(sample_batch, give_mean=True)
        output = mdd_classifier(z)
        # presumably the classifier emits log-probabilities (it is trained with nll_loss) — confirm
        output = torch.exp(output)
        # Back-propagate column 0 of the output, summed over the mini-batch, to the inputs.
        output.sum(dim=0)[0].backward()
        grads.append(sample_batch.grad.detach().cpu().numpy())    
    
    # Gradients are summed over cells first and only then made absolute, so contributions of
    # opposite sign cancel before the magnitude is taken.
    grads_per_gene_abs = np.absolute(np.concatenate(grads).sum(axis=0))
    print('grad shape for cluster {} is ={}'.format(cluster,grads_per_gene_abs.shape))
    # 50 genes with the largest aggregated gradient magnitude.
    genes_top = genes[grads_per_gene_abs.argsort()[-50:]]
    
    overlap = np.intersect1d(psygenet[mask1|mask2]['Gene_Symbol'].values, genes_top)
    overlap_dis = np.intersect1d(disgenet[mask1_dis]['geneSymbol'].values, genes_top)
    overlaps.append(overlap)
    overlaps_dis.append(overlap_dis)
    print(overlap)
    print(overlap_dis)

In [None]:
# List, per position in `overlaps`, the top-gradient genes that also appear in PsyGeNET.
for cluster_pos, cluster_overlap in enumerate(overlaps):
    joined_genes = ', '.join(list(cluster_overlap))
    print("DEGs for cluster {}:{}".format(cluster_pos, joined_genes))


# Automatic DEGs across all clusters

In [None]:

# Same per-cluster classifier + input-gradient procedure as the previous section, but here
# only the per-gene gradient magnitudes are collected (one array per cluster) so that they
# can be aggregated across clusters afterwards.
# `overlaps` and `overlaps_dis` are reset to empty lists and not filled in this cell.
overlaps=[]
overlaps_dis = []
grads_all_clusters =[]
for cluster in adata.obs['louvain'].cat.categories:
    adata_filter = adata[(adata.obs['louvain']==cluster).values]
    # The last 5 cells of the cluster are split off; adata_test is not used anywhere in this cell.
    adata_train = adata_filter[:-5]
    adata_test = adata_filter[-5:]
    
    # Classifier input: latent codes; target: integer code of `condition`.
    mydataset = MyDataset(data = pd.DataFrame(adata_train.obsm['X_scGAN']) , label=adata_train.obs['condition'].cat.codes.astype('long'))
    train_loader = torch.utils.data.DataLoader(mydataset, batch_size=512)
    # NOTE(review): the device string is hard-coded here instead of using `device`.
    mdd_classifier = Discriminator(n_latent, [2*n_latent, 2*n_latent], 2).to('cuda:0')
    
    optimizerMDD= torch.optim.Adam(mdd_classifier.parameters(), lr = lr)
    history = train(mdd_classifier, train_loader, optimizerMDD, epochs=30)
    
    # Dense expression rows for the gradient pass; the `batch` labels are only reshaped below.
    gene_dataset = MyDataset(data = pd.DataFrame(adata_train.X.toarray()) , label=adata_train.obs['batch'].astype('long'))
    gene_dataset_loader = torch.utils.data.DataLoader(gene_dataset, batch_size=512)

    grads = []
    for sample_batch, batch_index in tqdm(gene_dataset_loader):
        # Clear parameter gradients left over from the previous mini-batch.
        vae.zero_grad()
        optimizerMDD.zero_grad()
        
        sample_batch = sample_batch.to(device)
        batch_index = batch_index.to(device)
        # Track gradients with respect to the input expression values themselves.
        sample_batch.requires_grad=True
        batch_index = batch_index.reshape((-1,1))
        
        z = vae.sample_from_posterior_z(sample_batch, give_mean=True)
        output = mdd_classifier(z)
        # presumably the classifier emits log-probabilities (it is trained with nll_loss) — confirm
        output = torch.exp(output)
        # Back-propagate column 0 of the output, summed over the mini-batch, to the inputs.
        output.sum(dim=0)[0].backward()
        grads.append(sample_batch.grad.detach().cpu().numpy())    
    
    # Mean over cells (the per-cluster cell above uses the sum), then absolute value;
    # opposite-signed gradients cancel before the magnitude is taken.
    grads_per_gene_abs = np.absolute(np.concatenate(grads).mean(axis=0))
    grads_all_clusters.append(grads_per_gene_abs)


In [None]:
print('Overlaps with PsyGeNet:')
# Per-gene score: the largest gradient magnitude observed in any cluster.
grads_agg = np.array(grads_all_clusters).max(axis=0)
gene_rank = grads_agg.argsort()
psygenet_symbols = psygenet[mask1|mask2]['Gene_Symbol'].values
for top in [5, 10 , 50]:
    # Highest-scoring `top` genes, then count how many are listed in PsyGeNET.
    genes_top = genes[gene_rank[-top:]]
    overlap = np.intersect1d(psygenet_symbols, genes_top)
    print("DEGs for top {}: {}".format(top, len(overlap)))
    

In [None]:
print('Overlaps with DisGeNet:')
# Per-gene score: the largest gradient magnitude observed in any cluster.
grads_agg = np.array(grads_all_clusters).max(axis=0)
gene_rank = grads_agg.argsort()
disgenet_symbols = disgenet[mask1_dis]['geneSymbol'].values
for top in [50, 100 ]:
    # Highest-scoring `top` genes, then count how many are listed in DisGeNET.
    genes_top = genes[gene_rank[-top:]]
    overlap_dis = np.intersect1d(disgenet_symbols, genes_top)
    print("DEGs for top {}: {}".format(top, len(overlap_dis)))

In [None]:
# Table of the 50 highest-scoring genes, with a flag for PsyGeNET membership.
top = 50
grads_agg = np.array(grads_all_clusters).max(axis=0)
top_idx = grads_agg.argsort()[-top:]
genes_top = pd.DataFrame(genes[top_idx], columns=['gene'])
genes_top['grad'] = grads_agg[top_idx]
psygenet_symbols = psygenet[mask1|mask2]['Gene_Symbol'].values
genes_top['in_pygenet'] = genes_top['gene'].isin(psygenet_symbols)
# Displayed as the cell output, strongest gradient first.
genes_top.sort_values('grad', ascending=False)

## Overlap with GWAS

In [None]:
# GWAS association table (tab-separated, machine-specific absolute path).
gwas = pd.read_csv("/home/mcb/users/mbahra5/project/scVI/gwas-association-downloaded_2020-04-22-EFO_0003761-withChildTraits.tsv", sep='\t')
# Earlier exact-match variant, superseded by the substring match below:
# mask_gwas = gwas['DISEASE/TRAIT']=='Major depressive disorder'
# Keep rows whose trait text contains the phrase with either an upper- or lower-case initial "M".
mask_gwas = gwas['DISEASE/TRAIT'].apply(lambda x : ('Major depressive disorder' in x) | ('major depressive disorder' in x))
gwas = gwas[mask_gwas]

In [None]:
# Flatten the MAPPED_GENE column into one list of gene symbols.
# An entry may name several genes separated by ' - ', '; ' or ', '.
gwas_genes = []
for g in gwas['MAPPED_GENE']:
    # Missing entries carry no gene; skip them explicitly instead of relying on str(nan).
    if pd.isna(g):
        continue
    g = str(g).replace(' - ',', ').replace('; ',', ')
    # Split unconditionally: entries naming a single gene were previously dropped because
    # they contain no ', ' separator.
    gwas_genes.extend(symbol for symbol in g.split(', ') if symbol)

In [None]:
# Count how many of the top-ranked genes also appear among the GWAS-mapped genes.
print('Overlaps with GWAS:')
for top in [5, 10 , 50, 100]:
    # NOTE(review): clusters are aggregated with sum() here, whereas the PsyGeNET / DisGeNET
    # cells above use max() — confirm which aggregation is intended.
    grads_agg = np.array(grads_all_clusters).sum(axis=0)
    genes_top = genes[grads_agg.argsort()[-top:]]

    overlap = np.intersect1d(gwas_genes, genes_top)
    # Same expression as `overlap`; the value is not used in this cell.
    overlap_dis = np.intersect1d(gwas_genes, genes_top)
#     print("DEGs for top {}: {}".format(top, ', '.join(list(overlap))))
    print("DEGs for top {}: {}".format(top, len(overlap)))
    print(overlap)

# Cluster Enrichment Analysis

In [None]:
# One-sided hypergeometric enrichment test per Louvain cluster:
# is the 'Suicide' condition over-represented among the cells of the cluster?
from scipy.stats import hypergeom
suicide_numbers_total = (adata.obs.condition=='Suicide').values.sum()
pvalues = []
for cluster in adata.obs['louvain'].cat.categories:
    adata_filter = adata[(adata.obs['louvain']==cluster).values]
    suicide_numbers = (adata_filter.obs.condition=='Suicide').values.sum()
    # sf(k - 1) = P(X >= k). M: all cells, n: all 'Suicide' cells, N: cells in this cluster.
    p_value = hypergeom.sf(k = suicide_numbers-1, M = adata.shape[0] , n = suicide_numbers_total, N = adata_filter.shape[0], loc=0)
    print('P-value of Hyper test for cluster {} = {}'.format(cluster,p_value))
    pvalues.append(p_value)

In [None]:
# Colour every cell by -log10 of its cluster's enrichment p-value.
# Cells whose cluster is not matched below keep the initial value 1.
louvain_color = np.ones(adata.shape[0])
for cluster, p in enumerate(pvalues):
    # The position in `pvalues` is matched against the cluster name as a string, i.e. this
    # relies on the Louvain categories being named '0', '1', ... in the order they were tested.
    temp_mask = adata.obs['louvain'].values==str(cluster)
    # A p-value of exactly 0 would produce inf here.
    louvain_color[temp_mask] = -np.log10(p)
adata.obs['louvain_color'] = louvain_color

show_plot = True
fig, ax = plt.subplots(figsize=(5, 4),  dpi=150)
sc.pl.tsne(adata, color=['louvain_color'], ax=ax, show=show_plot, color_map='Reds')
# Disabled export; `dataset_name` and `method` are not defined in the cells shown here.
# fig.savefig('/home/mcb/users/mbahra5/project/scVI/pics/{}_{}_louvain_bypvalue.png'.format(dataset_name,method), bbox_inches = 'tight')


# DE across all clusters by Lmer

In [None]:
from pymer4.models import Lm, Lmer

# Per Louvain cluster, fit one model per gene relating log1p expression to `condition`,
# and collect [p-value, estimate, gene name] of the second coefficient row.
# Work on a log1p-transformed copy so `adata` itself keeps its original values.
adata_log = sc.pp.log1p(adata, copy=True)
pvals=[]

for cluster in adata_log.obs['louvain'].cat.categories:
    print('cluster=' + str(cluster))
    mask = (adata_log.obs['louvain']==cluster).values
    # Slice into a separate variable. The former `adata_log = adata_log[mask]` replaced the
    # full data set with the first cluster's cells, so every later cluster produced an
    # all-False mask and an empty subset.
    adata_cluster = adata_log[mask]

    df = pd.DataFrame()
    df['condition'] = adata_cluster.obs['condition'].values
    df['patient_id'] = adata_cluster.obs['patient_id'].values

    # All genes are tested; a highly-variable-gene filter was tried here earlier but is disabled.
    genes_highvar = adata_cluster.var.index.values

    pvals_per_cluster = []
    for i , gene_name in enumerate(genes_highvar):
        gene = adata_cluster[:,i]
        # Flatten the single-gene column to 1-D; densify first when X is stored sparse
        # (other cells call adata.X.toarray(), so a sparse X is expected).
        expr = gene.X
        if hasattr(expr, 'toarray'):
            expr = expr.toarray()
        df['gene'] = np.asarray(expr).ravel()

        # NOTE(review): the formula contains a random-effect term (1|patient_id) but is fitted
        # with Lm, not Lmer, although the section title mentions Lmer — confirm which is intended.
        model = Lm('gene ~ 1 + condition + (1|patient_id)',data=df)
        model.fit(summarize=False)
        # Row 1 of the coefficient table (row 0 is presumably the intercept — confirm);
        # .iloc makes the positional access explicit.
        pvals_per_cluster.append([model.coefs['P-val'].iloc[1] , model.coefs['Estimate'].iloc[1], gene_name])

    pvals.append(pvals_per_cluster)
    

In [None]:
# One data frame per cluster, built from the [p-value, estimate, gene] rows collected above.
gene_scores = [
    pd.DataFrame(cluster_rows, columns=['pval','estimate','gene'])
    for cluster_rows in pvals
]
  

In [None]:
# Stack the per-cluster tables and reduce to one row per gene.
# The minimum is taken per column independently: `pval` is the smallest p-value across
# clusters, while `estimate` is the smallest estimate across clusters, which is not
# necessarily the estimate from the cluster that produced that p-value.
genes_pvalues = pd.concat(gene_scores, axis=0).groupby('gene').min()

In [None]:
# Number of PsyGeNET genes among the `top` genes with the smallest p-values.
for top in [5, 10 , 50 , 100]:
    genes_top = genes_pvalues.sort_values('pval').index[:top].values
    overlap = np.intersect1d(psygenet[mask1|mask2]['Gene_Symbol'].values, genes_top)
    # Computed but not printed here; the DisGeNET counts are reported by the next cell.
    overlap_dis = np.intersect1d(disgenet[mask1_dis]['geneSymbol'].values, genes_top)
    print("DEGs for top {}: {}".format(top, len(overlap)))

In [None]:
# Number of DisGeNET (Major Depressive Disorder) genes among the `top` genes with the
# smallest p-values.
for top in [50, 100 , 200]:
    genes_top = genes_pvalues.sort_values('pval').index[:top].values
#     overlap = np.intersect1d(psygenet[mask1|mask2]['Gene_Symbol'].values, genes_top)
    overlap_dis = np.intersect1d(disgenet[mask1_dis]['geneSymbol'].values, genes_top)
    print("DEGs for top {}: {}".format(top, len(overlap_dis)))
#     print(overlap)
#     print(overlap_dis)

# UMAP

In [None]:
# import warnings
# warnings.filterwarnings('ignore')
# Rebuild the neighbour graph on the latent codes with 15 neighbours, then compute UMAP.
# NOTE(review): with scanpy's default keys this replaces the 30-neighbour graph that the
# Louvain clustering above was computed on — confirm that this is acceptable.
sc.pp.neighbors(adata, use_rep="X_scGAN", n_neighbors=15)
sc.tl.umap(adata, min_dist=0.1)

In [None]:
show_plot = True
# One 10x9 UMAP figure per annotation: cell type first, then the age-bin batch code.
for obs_key in ("cell_type", "batch"):
    fig, ax = plt.subplots(figsize=(10, 9))
    sc.pl.umap(adata, color=[obs_key], ax=ax, show=show_plot)
