In [None]:
# import warnings
# warnings.filterwarnings('ignore')

import GEOparse
from tqdm import tqdm
import urllib.request
import random
import os
import numpy as np
import pandas as pd
import scanpy as sc
import anndata

from scvi.dataset import AnnDatasetFromAnnData

import torch
import matplotlib.pyplot as plt
import tensorflow as tf
import sys
import umap

from utils_helper import *

# Restrict the process to the listed GPU(s) before anything touches CUDA;
# inside the process the first visible GPU is then addressed as 'cuda:0'.
gpus = ["1"]
os.environ["CUDA_VISIBLE_DEVICES"] = ','.join(gpus)
device = 'cuda:0'

# Seed every RNG the notebook uses. torch drives the VAE/discriminator
# training below, so it must be seeded as well.
seed = 345
# NOTE: PYTHONHASHSEED only affects interpreters started after this point
# (e.g. worker subprocesses), not the already-running kernel.
os.environ['PYTHONHASHSEED'] = str(seed)
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)
# TF1 exposes tf.set_random_seed; TF2 renamed it to tf.random.set_seed.
_set_tf_seed = getattr(tf, 'set_random_seed', None) or tf.random.set_seed
_set_tf_seed(seed)

# Load Data

In [None]:
# Expression matrix ('path_to_dataset.loom' looks like a placeholder — set the real file).
data = LoomDataset('path_to_dataset.loom')

# Per-cell annotation tables (cell type, sequencing technology), read in order.
celltypes, tech = (
    pd.read_csv(csv_path)
    for csv_path in (
        '/home/mcb/users/mbahra5/project/data/panc8/celltypes.csv',
        '/home/mcb/users/mbahra5/project/data/panc8/tech.csv',
    )
)

In [None]:
# Keep just the value column 'x' of each annotation table (as a Series).
celltypes, tech = celltypes['x'], tech['x']

In [None]:
# Wrap the expression matrix in an AnnData and attach per-cell metadata.
adata = anndata.AnnData(X=data.X)

# The cell-type annotation is stored under both 'cell_types' and 'cell_type';
# later cells read each of these keys.
cell_type_values = celltypes.values
for key in ('cell_types', 'cell_type'):
    adata.obs[key] = cell_type_values

adata.obs['batch_name'] = tech.values

In [None]:
# Exclude the inDrop cells; every downstream step works on the remaining technologies.
keep_cells = ~(adata.obs['batch_name'] == 'indrop')
adata = adata[keep_cells].copy()

In [None]:
# Integer-encode cell types ('labels') and technologies ('batch' / 'batch_indices').
cell_type_codes = adata.obs['cell_types'].astype('category').cat.codes
adata.obs['labels'] = cell_type_codes.values

batch_codes = adata.obs['batch_name'].astype('category').cat.codes
adata.obs['batch'] = batch_codes.values
adata.obs['batch_indices'] = adata.obs['batch'].values

# Distinct counts (missing values count as a value, like len(unique())).
n_labels = adata.obs['cell_types'].nunique(dropna=False)
n_batch = adata.obs['batch'].nunique(dropna=False)

In [None]:
# Scale every cell to the same total count (in place).
# NOTE(review): if the VAE uses a count likelihood (as scVI models usually do) it
# expects raw counts; normalizing here also equalizes library sizes — confirm intended.
sc.pp.normalize_total(adata, inplace=True)

In [None]:
# fraction=1 keeps every cell; presumably this only reorders them reproducibly (seeded) — confirm.
sc.pp.subsample(adata, random_state=seed, fraction=1)

In [None]:
# Convert the annotated AnnData into the scVI dataset object used for training.
dataset = AnnDatasetFromAnnData(adata)

# Preprocess

In [None]:
# sc.pp.log1p(adata)

# Latent Inference

In [None]:
n_epochs = 50
lr = 0.001
eps = 1e-8
use_batches = True
use_cuda = True
n_latent = 10
%matplotlib inline

In [None]:
vae = VAE(
    dataset.nb_genes,
    n_batch=dataset.n_batches * use_batches,  # 0 when use_batches is False
    n_latent=n_latent,
    n_layers=2,
    n_hidden=64,
)

In [None]:
# Batch discriminator on the latent code: input size n_latent, [2*n_latent]
# presumably the hidden-layer sizes, n_batch outputs.
# Follow use_cuda (as the trainer does for the VAE) instead of hard-coding a GPU,
# so the VAE and the discriminator end up on the same device.
disc = Discriminator(n_latent, [2 * n_latent], n_batch).to(device if use_cuda else 'cpu')

In [None]:
# NOTE: GANTrainer is defined in a cell further down this notebook — run that cell first.
# train_size just below 1 puts (nearly) every cell in the training split, so the
# test/validation splits scored later are presumably tiny or empty.
trainer = GANTrainer(vae, disc, dataset,
                     train_size=0.9999999999999, test_size=None,
                     use_cuda=use_cuda,
                     frequency=5,
                     seed=seed)


In [None]:
# Phase 1 (warm-up): VAE + discriminator; enc_lr=0.0 effectively disables the adversarial encoder update.
history = trainer.train(n_epochs=20, lr=lr, eps=eps, disc_lr=lr, enc_lr=0.0)

In [None]:
# Phase 2: adversarial training, encoder step at 5% of the base learning rate.
# Uses the configured n_epochs instead of a duplicated literal.
history = trainer.train(n_epochs=n_epochs, lr=lr, eps=eps, disc_lr=lr, enc_lr=0.05 * lr)

In [None]:
# history = (VAE losses, discriminator losses, encoder losses), one mean per epoch.
elbo_train = history[0]
# Integer epoch numbers; np.linspace(0, n, n) spaced points by n/(n-1), not 1.
epochs = np.arange(1, len(elbo_train) + 1)
fig, ax = plt.subplots()
ax.plot(epochs, elbo_train)
ax.set(xlabel='epoch', ylabel='mean VAE loss (train)', title='VAE training loss');

In [None]:
# Embed every cell, in dataset order, with the trained model.
all_cell_indices = np.arange(len(dataset))
posterior = trainer.create_posterior(trainer.model, dataset, indices=all_cell_indices)
ordered_posterior = posterior.sequential()
latent, batches, labels = ordered_posterior.get_latent()

In [None]:

adata.obsm["X_scVI"] = latent

# Predict each cell's batch from its latent code with the trained discriminator.
# eval(): inference-mode layers (training leaves the discriminator in train mode).
# no_grad(): no autograd graph is needed for prediction.
disc.eval()
with torch.no_grad():
    # Place the input on whatever device the discriminator lives on.
    latent_tensor = torch.tensor(latent, device=next(disc.parameters()).device)
    batch_prediction = np.argmax(disc(latent_tensor).cpu().numpy(), axis=1)

# Clustering Scores by kmeans

In [None]:
def calc_scores(input_posterior, n_labels=None):
    """Print batch-mixing entropy and clustering ARI for one posterior split.

    Parameters
    ----------
    input_posterior : Posterior
        Split to score (e.g. ``trainer.train_set``). Its latent codes, batch
        indices and labels are read via ``.sequential().get_latent()``.
    n_labels : int, optional
        Number of labels passed to ``clustering_scores``. Defaults to the
        notebook-level ``dataset.n_labels`` (previous behaviour); pass it
        explicitly to avoid depending on that global.

    If ``get_latent`` yields no cells, a message is printed and the scoring
    helpers are not called.
    """
    latent, batches, labels = input_posterior.sequential().get_latent()
    if len(latent) == 0:
        print("Posterior is empty; skipping scores.")
        return
    if n_labels is None:
        n_labels = dataset.n_labels
    print("Entropy of batch mixing :", entropy_batch_mixing(latent, batches))
    print("Clustering ARI = {}".format(clustering_scores(n_labels, labels, latent)))

In [None]:
print('Train Set:')
calc_scores(input_posterior=trainer.train_set)

In [None]:
print('Test Set:')
calc_scores(input_posterior=trainer.test_set)

In [None]:
print('Validation Set:')
calc_scores(input_posterior=trainer.validation_set)

# t-SNE

In [None]:
# t-SNE on the full scVI latent space. n_pcs is not passed: combined with
# use_rep it either truncates X_scVI to its first n_pcs columns (newer scanpy,
# i.e. 2 of the 10 latent dims) or is ignored.
sc.tl.tsne(adata, use_rep='X_scVI')

In [None]:
# Categorical copies of the true batch codes and of the discriminator's
# predicted batches, so plots colour them as discrete groups.
true_batch = adata.obs['batch']
adata.obs['batch_cat'] = true_batch.astype('category')
predicted_batch = pd.Series(batch_prediction).astype('category')
adata.obs['batch_pred_cat'] = predicted_batch.values

In [None]:
dataset_name = 'Human_Pancreatic_Cells'
# method = 'scGAN'
method = '$scGAN^{-}$(No Adversarial Net)'
# NOTE(review): `method` is also embedded in the file names below, so they contain
# '$', '^', '{', '}' and spaces — consider a filesystem-safe slug.
fig_dir = '/home/mcb/users/mbahra5/project/scVI/pics'
os.makedirs(fig_dir, exist_ok=True)

# (obs column to colour by, title suffix, file-name suffix)
panels = [
    ('cell_types', 'Cell Type', 'celltype'),
    ('batch_name', 'Batch', 'batch'),
]
for color_key, title_suffix, file_suffix in panels:
    fig, ax = plt.subplots(figsize=(7, 6), dpi=150)
    # show=False: save the figure before it is displayed/closed, then show it.
    sc.pl.tsne(adata, color=[color_key], ax=ax, show=False,
               title='{} - {}'.format(method, title_suffix))
    fig.savefig(os.path.join(fig_dir, '{}_{}_{}.png'.format(dataset_name, method, file_suffix)),
                bbox_inches='tight')
    plt.show()


In [None]:
# 30-nearest-neighbour graph on the scVI latent space, then coarse Louvain clustering.
sc.pp.neighbors(adata, n_neighbors=30, use_rep="X_scVI")
sc.tl.louvain(adata, resolution=0.2)

In [None]:
# Louvain clusters on the t-SNE embedding.
show_plot = True
fig, ax = plt.subplots(figsize=(9, 8))
sc.pl.tsne(adata, ax=ax, color=['louvain'], show=show_plot)

In [None]:
from sklearn.metrics import adjusted_rand_score as ARI

# Agreement between annotated cell types and Louvain clusters. ARI is invariant
# to relabelling, so the annotations stored on adata are used directly instead of
# the `labels` array left over from an earlier cell (same cells, same order).
ari_score = ARI(adata.obs['cell_types'], adata.obs['louvain'])
print(ari_score)

# UMAP

In [None]:
# import warnings
# warnings.filterwarnings('ignore')
# Rebuild the neighbour graph with 15 neighbours (replacing the 30-NN graph) for UMAP.
sc.pp.neighbors(adata, n_neighbors=15, use_rep="X_scVI")
sc.tl.umap(adata, min_dist=0.1)

In [None]:
# UMAP coloured by cell type, then by sequencing technology.
show_plot = True
for color_key in ("cell_type", "batch_name"):
    fig, ax = plt.subplots(figsize=(10, 9))
    sc.pl.umap(adata, color=[color_key], ax=ax, show=show_plot)


In [None]:
from scvi.inference import UnsupervisedTrainer
import time
import logging
import sys
import time
from tqdm import trange
from scvi.inference.posterior import Posterior
logger = logging.getLogger(name=__name__)

class GANTrainer(UnsupervisedTrainer):
    """UnsupervisedTrainer extended with an adversarial batch discriminator.

    Every minibatch performs three updates:
      1. the VAE (all trainable model parameters) on its reconstruction + KL loss;
      2. the discriminator ``disc`` (``n_disc_iters`` steps) to predict
         ``batch_index`` from the latent code ``z`` (NLL loss);
      3. the encoder ``model.z_encoder`` to *maximise* that NLL.

    NOTE: this cell appears after the cell that instantiates GANTrainer; it must
    be executed first.

    Parameters
    ----------
    model : VAE-like module whose forward returns ``(reconst_loss, kl_divergence, z)``
        and which exposes ``z_encoder``.
    disc : torch module applied to ``z``. Trained with ``NLLLoss``, so it is
        assumed to output per-batch log-probabilities — confirm.
    gene_dataset, train_size, test_size, **kwargs : forwarded to UnsupervisedTrainer.
    """

    def __init__(
        self,
        model,
        disc,
        gene_dataset,
        train_size, test_size,
        **kwargs
    ):
        self.disc = disc
        super().__init__(model, gene_dataset, train_size=train_size, test_size=test_size, **kwargs)
        # Build the splits only for GANTrainer itself; subclasses build their own.
        if type(self) is GANTrainer:
            self.train_set, self.test_set, self.validation_set = self.train_test_validation(
                model, gene_dataset, train_size, test_size
            )

    @staticmethod
    def _mean_or_nan(values):
        """Arithmetic mean of ``values``, or ``nan`` when the list is empty."""
        return sum(values) / len(values) if values else float('nan')

    def train(self, n_epochs=20, lr=1e-3, eps=0.01, params=None, enc_lr=1e-3, disc_lr=1e-3,
              n_disc_iters=10):
        """Run ``n_epochs`` of joint VAE / discriminator / encoder training.

        Parameters
        ----------
        n_epochs : int
            Number of passes over ``self.data_loaders_loop()``.
        lr, eps : float
            Adam learning rate / epsilon for the VAE optimizer.
        params : iterable of parameters, optional
            Parameters for the VAE optimizer; defaults to every model parameter
            with ``requires_grad``.
        enc_lr : float
            Adam learning rate of the adversarial encoder step (0 effectively
            disables it).
        disc_lr : float
            Adam learning rate of the discriminator.
        n_disc_iters : int
            Discriminator updates per minibatch (default 10, the previous value).

        Returns
        -------
        (vae_loss_list, D_loss_list, E_loss_list)
            Per-epoch mean VAE loss, discriminator NLL and encoder loss; ``nan``
            for an epoch in which every minibatch was skipped (< 3 cells).
        """
        self.model.train()
        self.disc.train()

        if params is None:
            params = filter(lambda p: p.requires_grad, self.model.parameters())

        optimizer = self.optimizer = torch.optim.Adam(params, lr=lr, eps=eps, weight_decay=self.weight_decay)
        optimizerE = self.optimizerE = torch.optim.Adam(self.model.z_encoder.parameters(), lr=enc_lr, weight_decay=self.weight_decay)
        optimizerD = self.optimizerD = torch.optim.Adam(self.disc.parameters(), lr=disc_lr, weight_decay=self.weight_decay)

        self.n_epochs = n_epochs
        nll_loss = torch.nn.NLLLoss(reduction='none')

        vae_loss_list, E_loss_list, D_loss_list = [], [], []
        with trange(n_epochs, desc="training", file=sys.stdout, disable=not self.show_progbar) as pbar:
            # trange advances itself on each iteration; no manual pbar.update().
            for self.epoch in pbar:
                vae_losses, E_losses, D_losses = [], [], []
                self.on_epoch_begin()

                for tensors_list in self.data_loaders_loop():
                    # Skip tiny minibatches (fewer than 3 cells).
                    if tensors_list[0][0].shape[0] < 3:
                        continue

                    sample_batch, local_l_mean, local_l_var, batch_index, _ = tensors_list[0]
                    batch_target = batch_index.view(-1)

                    # (1) VAE update
                    self.model.zero_grad()
                    reconst_loss, kl_divergence, z = self.model(sample_batch, local_l_mean, local_l_var, batch_index)
                    loss = torch.mean(reconst_loss + self.kl_weight * kl_divergence)
                    vae_losses.append(loss.item())
                    # Keep the graph: the encoder step (3) backpropagates through z again.
                    loss.backward(retain_graph=True)
                    optimizer.step()

                    # (2) Discriminator updates on a detached z: only the
                    # discriminator is stepped here, so there is no need to
                    # backpropagate into the encoder on every iteration.
                    z_detached = z.detach()
                    for _ in range(n_disc_iters):
                        self.disc.zero_grad()
                        D_loss = torch.mean(nll_loss(self.disc(z_detached), batch_target))
                        D_losses.append(D_loss.item())
                        D_loss.backward()
                        optimizerD.step()

                    # (3) Encoder update: maximise the NLL of the *updated*
                    # discriminator on the (non-detached) latent code.
                    # NOTE(review): z's graph predates optimizer.step() above;
                    # recent PyTorch versions may reject this backward — confirm.
                    self.model.z_encoder.zero_grad()
                    E_loss = -torch.mean(nll_loss(self.disc(z), batch_target))
                    E_losses.append(E_loss.item())
                    E_loss.backward()
                    optimizerE.step()

                vae_loss_list.append(self._mean_or_nan(vae_losses))
                D_loss_list.append(self._mean_or_nan(D_losses))
                E_loss_list.append(self._mean_or_nan(E_losses))

        self.model.eval()
        self.disc.eval()
        return vae_loss_list, D_loss_list, E_loss_list
