In [None]:
# import warnings
# warnings.filterwarnings('ignore')

import random
from pathlib import Path
from optparse import OptionParser
import numpy as np
import pandas as pd
from sklearn.preprocessing import MinMaxScaler, StandardScaler, RobustScaler
from sklearn.decomposition import PCA
from sklearn.neighbors import KNeighborsClassifier
from sklearn.manifold import TSNE

import keras
from keras.utils import multi_gpu_model
from keras.models import Model, Sequential
from keras.layers import Input, Dense, Dropout, Lambda, Activation, Concatenate, BatchNormalization, ReLU
from keras.optimizers import Adam
from keras import backend as K
from keras.utils import to_categorical
from keras.losses import mse

import tensorflow as tf

import matplotlib.pyplot as plt
# %matplotlib inline

import os
from scvi.dataset import AnnDatasetFromAnnData, RetinaDataset, LoomDataset
import scanpy as sc
import anndata
import matplotlib.pyplot as plt
import sys

sys.path.append("/home/mcb/users/mbahra5/project/")
from utils import entropy_batch_mixing , clustering_scores, classification_acc_measure


# Hyper Parameters

In [None]:
# --- Reproducibility / hardware ---------------------------------------------
seed = 344        # fed to the Python, NumPy, TensorFlow and scanpy RNG calls in this notebook
gpus = ["2"]      # ids joined into CUDA_VISIBLE_DEVICES; more than one enables multi_gpu_model

# --- Loss weighting -----------------------------------------------------------
decoder_loss_weight = 1.      # NOTE: not referenced by the loss built in create_all_models
kl_loss_weight = 0.0          # multiplier of the KL term
distance_loss_weight = 0.0    # multiplier of the contrastive (batch-distance) term
margin = 1.0                  # margin read by contrastive_loss

# --- Optimisation -------------------------------------------------------------
epochs = 50
batch_size = 128

# --- Architecture -------------------------------------------------------------
layers = [128]                # encoder hidden sizes; a 'D' entry inserts Dropout(0.2)
latent_dim = 10
activation = 'relu'
latent_activation = 'linear'  # only referenced from commented-out code
output_activation = 'relu'

# --- Data split / logging -----------------------------------------------------
# Cumulative fractions: [0, split_train) train, [split_train, split_val) validation, rest test.
split_train = 0.9
split_val = 0.95
verbose = 1

In [None]:
# parser = OptionParser()

# parser.add_option("--seed", dest="seed", type="int", default=342)
# parser.add_option("--gpus", dest="gpus", type="string", default='["2"]')
# parser.add_option("--epochs", dest="epochs", type="int", default=50)
# parser.add_option("--batch_size", dest="batch_size", type="int", default=128)
# parser.add_option("--loss_weight", dest="loss_weight", type="float", default=0.)
# parser.add_option("--kl_loss_weight", dest="kl_loss_weight", type="float", default=1.)
# parser.add_option("--margin", dest="margin", type="float", default=1.0)
# parser.add_option("--layers", dest="layers", type="string", default='[128]')
# parser.add_option("--latent_dim", dest="latent_dim", type="int", default=10)
# parser.add_option("--latent_activation", dest="latent_activation", type="string", default='linear')
# parser.add_option("--verbose", dest="verbose", type="int", default=0)

# (options, args) = parser.parse_args()

# seed = options.seed
# gpus = eval(options.gpus)
# decoder_loss_weight = 1.
# distance_loss_weight = options.loss_weight
# kl_loss_weight = options.kl_loss_weight
# margin = options.margin
# epochs = options.epochs
# batch_size = options.batch_size
# layers = eval(options.layers)
# latent_dim = options.latent_dim
# activation = 'relu'
# latent_activation = options.latent_activation
# output_activation = 'relu'
# split_train , split_val = 0.9, 0.95
# verbose = options.verbose


# print(options)

# Initialization

In [None]:
# Environment first: hash seed for child processes and the GPUs TensorFlow may see.
# NOTE(review): PYTHONHASHSEED set at runtime presumably does not affect the
# already-running interpreter - confirm if hash determinism matters here.
os.environ.update(
    PYTHONHASHSEED=str(seed),
    CUDA_VISIBLE_DEVICES=','.join(gpus),
)

# Seed every random number generator used by the notebook with the same value.
random.seed(seed)
np.random.seed(seed)
tf.set_random_seed(seed)

# Load Data

In [None]:
# obs column names: raw batch index and its categorical counterpart.
batch_col = 'batch'
batch_col_cat = 'batch_cat'

# RetinaDataset (scvi.dataset) is given this directory as its save_path.
save_path = "/home/mcb/users/mbahra5/project/data/"
dataset = RetinaDataset(save_path=save_path)

In [None]:
# dataset.filter_genes_

In [None]:
# Wrap the expression matrix in an AnnData object and attach per-cell annotations.
adata = anndata.AnnData(X=dataset.X)

# Cell-type name of each cell: label index (first column of dataset.labels) -> dataset.cell_types.
adata.obs['cell_type'] = np.array(
    [dataset.cell_types[dataset.labels[row][0]] for row in range(adata.n_obs)]
)

# Batch index of each cell (first column of dataset.batch_indices).
adata.obs[batch_col] = np.array(
    [dataset.batch_indices[row][0] for row in range(adata.n_obs)]
)

# Preprocess

In [None]:
# fraction=1 keeps every cell; presumably used only to shuffle the row order
# (reproducibly, via random_state) before the positional split below.
sc.pp.subsample(adata, fraction=1, random_state=seed)
sc.pp.log1p(adata)
# sc.pp.scale(adata)

# Categorical version of the batch column; batch_gen reads its integer codes.
adata.obs[batch_col_cat] = adata.obs[batch_col].astype('category')

In [None]:
# Positional split: rows [0, train_end) / [train_end, val_end) / [val_end, n_obs).
train_end = int(split_train * adata.n_obs)
val_end = int(split_val * adata.n_obs)

adata_train = adata[:train_end]
adata_val = adata[train_end:val_end]
adata_test = adata[val_end:]

# Latent Inference

In [None]:
def sampling(args):
    """Draw a latent sample with the reparameterisation trick.

    Args:
        args: pair ``(z_mean, z_log_var)`` of tensors.

    Returns:
        ``z_mean + exp(0.5 * z_log_var) * epsilon`` where ``epsilon`` comes from
        ``K.random_normal`` (default mean 0, std 1) with one row per input row.
    """
    mu, log_var = args
    n_rows = K.shape(mu)[0]        # dynamic batch size
    n_dims = K.int_shape(mu)[1]    # static latent width
    noise = K.random_normal(shape=(n_rows, n_dims))
    sigma = K.exp(0.5 * log_var)
    return mu + sigma * noise

def create_encoder_network(input_shape, layers, latent_dim):
    """Build the variational encoder.

    Args:
        input_shape: shape tuple of one input sample.
        layers: sequence of hidden-layer specs. The string ``'D'`` inserts
            ``Dropout(0.2)``; any other entry is used as the unit count of a
            Dense -> BatchNormalization -> Activation -> Dropout(0.1) stack.
        latent_dim: width of the latent space.

    Returns:
        Keras ``Model`` named ``'encoder'`` mapping the input to
        ``[z_mean, z_log_var, z]``. The hidden activation is read from the
        module-level ``activation`` setting.
    """
    inputs = Input(shape=input_shape, name='encoder_input')

    hidden = inputs
    for spec in layers:
        if spec == 'D':
            hidden = Dropout(0.2)(hidden)
            continue
        hidden = Dense(units=spec)(hidden)
        # todo: momentum for scVI paper is 0.01
        hidden = BatchNormalization(momentum=0.99, epsilon=0.001)(hidden)
        hidden = Activation(activation)(hidden)
        hidden = Dropout(0.1)(hidden)

    # Two linear heads parameterise the posterior; z is sampled from them.
    z_mean = Dense(latent_dim, name='z_mean')(hidden)
    z_log_var = Dense(latent_dim, name='z_log_var')(hidden)
    z = Lambda(sampling, output_shape=(latent_dim,), name='z')([z_mean, z_log_var])

    return Model(inputs, [z_mean, z_log_var, z], name='encoder')


def create_decoder_network(layers, latent_dim, output_shape, batch_input_shape):
    """Build the decoder that reconstructs expression from latent code + batch.

    Args:
        layers: the same layer spec list given to the encoder. All entries
            except the last are mirrored in reverse order; ``'D'`` dropout
            markers (understood by the encoder) are skipped here instead of
            being passed to ``Dense`` as a unit count.
        latent_dim: width of the latent input.
        output_shape: shape tuple of one reconstructed sample; its first entry
            is the number of output units.
        batch_input_shape: shape tuple of the batch-indicator input.

    Returns:
        Keras ``Model`` named ``'decoder'`` taking ``[batch_input, latent_inputs]``.
        Hidden and output activations are read from the module-level
        ``activation`` and ``output_activation`` settings.
    """
    batch_input = Input(shape=batch_input_shape, name='batch_input')
    latent_inputs = Input(shape=(latent_dim,), name='decoder_input')
    # The batch indicator is concatenated in front of the latent code.
    x = Concatenate(axis=-1)([batch_input, latent_inputs])

    # Slice first, then drop markers, so specs without 'D' build exactly the
    # same stack as before.
    hidden_units = [spec for spec in reversed(layers[:-1]) if spec != 'D']
    for units in hidden_units:
        x = Dense(units=units)(x)
        x = BatchNormalization(momentum=0.99, epsilon=0.001)(x) # todo: momentum for scVI paper is 0.01
        x = Activation(activation)(x)

    outputs = Dense(units=output_shape[0], activation=output_activation)(x)

    # Instantiate Decoder Model
    decoder = Model([batch_input, latent_inputs], outputs, name='decoder')

    return decoder


def create_all_models(input_shape, batch_input_shape, layers, latent_dim):
    """Assemble encoder, decoder and the compiled training model.

    The training model takes four inputs (``input_1``, ``input_2``,
    ``batch_input_1``, ``isSameBatch``) and outputs the distance between the
    two sampled embeddings plus the reconstruction of ``input_1``. Its loss is
    attached with ``add_loss``, so it is compiled without a ``loss`` argument.

    Args:
        input_shape: shape tuple of one expression sample.
        batch_input_shape: shape tuple of the batch-indicator input.
        layers: layer spec list forwarded to the encoder and decoder builders.
        latent_dim: width of the latent space.

    Returns:
        Tuple ``(model, encoder, decoder)``.
    """
    encoder = create_encoder_network(input_shape, layers, latent_dim)

    batch_input_1 = Input(shape=batch_input_shape, name='batch_input_1')
    input_1 = Input(shape=input_shape, name='input_1')
    input_2 = Input(shape=input_shape, name='input_2')
    isSameBatch = Input(shape=(1,), name='isSameBatch')


    # Both inputs go through the same encoder; only the sampled z (index 2)
    # of the second input is kept.
    z_mean, z_log_var, em1_layer = encoder(input_1)
    em2_layer = encoder(input_2)[2]

    decoder = create_decoder_network(layers, latent_dim, input_shape, batch_input_shape)
    distance_layer = Lambda(euclidean_distance, output_shape=eucl_dist_output_shape, name='distance_layer')(
        [em1_layer, em2_layer])
    # Only the first sample is reconstructed, conditioned on its batch indicator.
    decoder_layer = decoder([batch_input_1, em1_layer])

    model = Model(inputs=[input_1, input_2,batch_input_1,isSameBatch], outputs=[distance_layer, decoder_layer])
    # Wrap for data-parallel training only when more than one GPU id is configured.
    if len(gpus)>1:
        model = multi_gpu_model(model, gpus=len(gpus)) 

    reconstruction_loss = mse(input_1, decoder_layer)
    original_dim = input_shape[0]
    # KL term computed from z_mean / z_log_var, summed over latent dimensions
    # and divided by the input dimensionality.
    kl_loss = 1 + z_log_var - K.square(z_mean) - K.exp(z_log_var)
    kl_loss = -0.5 * K.sum(kl_loss, axis=-1) / original_dim

    batch_corr_loss = contrastive_loss(isSameBatch, distance_layer)

    # Weighted sum of the three terms. The reconstruction term is unweighted:
    # decoder_loss_weight is not applied here.
    vae_loss = K.mean(reconstruction_loss + kl_loss_weight * kl_loss + distance_loss_weight * batch_corr_loss)
    model.add_loss(vae_loss)

    model.compile(optimizer='adam')

    # embedding_preds = embedding_network.predict(x, batch_size=64)
    # Wsave = model.get_weights()

    return model, encoder, decoder


def euclidean_distance(vects):
    """Row-wise Euclidean distance between two tensors.

    Args:
        vects: pair ``(x, y)`` of tensors.

    Returns:
        Tensor with the distance along axis 1, keeping that axis with size 1.
    """
    left, right = vects
    squared_diff = K.square(left - right)
    summed = K.sum(squared_diff, axis=1, keepdims=True)
    # Floor the value under the square root at K.epsilon().
    return K.sqrt(K.maximum(summed, K.epsilon()))


def eucl_dist_output_shape(shapes):
    """Output shape of the distance Lambda layer: ``(batch, 1)``.

    Args:
        shapes: pair of input shapes; only the first entry of the first shape
            (the batch dimension) is used.
    """
    first_shape, _second_shape = shapes
    batch_dim = first_shape[0]
    return (batch_dim, 1)


def contrastive_loss(y_true, y_pred):
    '''Contrastive loss from Hadsell-et-al.'06
    http://yann.lecun.com/exdb/publis/pdf/hadsell-chopra-lecun-06.pdf

    Entries with ``y_true == 0`` contribute the squared distance ``y_pred**2``;
    entries with ``y_true == 1`` contribute ``max(margin - y_pred, 0)**2``.
    ``margin`` is read from the module-level setting. Returns the mean over
    all entries.
    '''
    squared_distance = K.square(y_pred)
    squared_margin_gap = K.square(K.maximum(margin - y_pred, 0))
    per_pair = (1 - y_true) * squared_distance + y_true * squared_margin_gap
    return K.mean(per_pair)


In [None]:
# TF1 session with allow_growth enabled, registered as the Keras backend
# session before any model is built.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
sess = tf.Session(config=config)
K.set_session(sess)

# Width of the batch indicator = number of distinct batch values in the
# training split; input width = number of columns of X.
batch_input_shape = (len(adata_train.obs[batch_col].unique()),)
input_shape = (adata_train.X.shape[1],)

model, encoder, decoder = create_all_models(input_shape, batch_input_shape, layers, latent_dim)
# Earlier plain-autoencoder variant, kept for reference:
# encoder = create_encoder_network(input_shape, layers, latent_dim)
# decoder = create_decoder_network(layers, latent_dim, input_shape)
# model = Model(encoder.inputs, decoder(encoder(encoder.inputs)), name='autoencoder')
# model.compile(loss='mse', optimizer='adam')
model.summary()


In [None]:
def batch_gen(adata_x):
    """Return an endless generator of paired mini-batches for ``fit_generator``.

    Each item is ``(inputs, {})`` where ``inputs`` maps the model's input names to:

    * ``input_1``       - rows ``i : i + batch_size`` of ``adata_x.X`` (sequential window),
    * ``input_2``       - an equally sized window starting at a random row,
    * ``batch_input_1`` - one-hot encoding of the first window's batch codes,
    * ``isSameBatch``   - element-wise ``batch_1 == batch_2``.

    The target dict is empty; no targets are supplied to the model.

    Args:
        adata_x: AnnData-like object exposing ``X`` and ``obs[batch_col_cat]``.

    Raises:
        ValueError: if ``adata_x`` does not hold more than ``batch_size`` rows.
            Previously this case looped forever without yielding anything.
    """
    # Validate eagerly (before the generator is handed to Keras).
    if adata_x.X.shape[0] <= batch_size:
        raise ValueError(
            "batch_gen needs more than batch_size={} rows, got {}".format(
                batch_size, adata_x.X.shape[0]))

    def _pairs():
        while True:
            # The upper bound excludes any window that would reach the last row,
            # so every yielded window holds exactly batch_size rows.
            for i in range(0, adata_x.X.shape[0] - batch_size, batch_size):
                batch_1 = adata_x.obs[i:i + batch_size][batch_col_cat].cat.codes.values
                sample_1 = adata_x.X[i:i + batch_size]
                batch_1_onehot = to_categorical(batch_1, num_classes=batch_input_shape[0])

                # Partner window: same length, random start.
                j = random.randrange(0, adata_x.X.shape[0] - batch_size)
                batch_2 = adata_x.obs[j:j + batch_size][batch_col_cat].cat.codes.values
                sample_2 = adata_x.X[j:j + batch_size]

                pair_y = batch_1 == batch_2

                yield {'input_1': sample_1, 'input_2': sample_2, 'batch_input_1': batch_1_onehot, 'isSameBatch': pair_y}, {}

    return _pairs()
            
# Number of generator steps that make up one epoch (floor division, so a
# trailing partial batch is not counted).
steps_per_epoch = adata_train.X.shape[0] // batch_size
steps_per_epoch_val = adata_val.X.shape[0] // batch_size


class ShuffleData(keras.callbacks.Callback):
    """Keras callback that re-subsamples ``adata_train`` in place before each epoch."""

    def on_epoch_begin(self, epoch, logs=None):
        """Reorder the training split; ``epoch`` and ``logs`` are unused.

        ``logs`` defaults to ``None`` rather than a shared mutable ``{}``.
        """
        # fraction=1 keeps every cell. The same random_state is passed on every
        # epoch - presumably the same permutation is re-applied to the current
        # row order each time; confirm this is the intended shuffling scheme.
        sc.pp.subsample(adata_train, fraction=1, random_state=seed)

In [None]:
# Endless batch streams for the training and validation splits.
train_batches = batch_gen(adata_train)
val_batches = batch_gen(adata_val)

# ShuffleData reorders adata_train at the start of every epoch.
history = model.fit_generator(
    train_batches,
    steps_per_epoch=steps_per_epoch,
    epochs=epochs,
    verbose=verbose,
    callbacks=[ShuffleData()],
    validation_data=val_batches,
    validation_steps=steps_per_epoch_val,
)

In [None]:
# Per-epoch loss curves recorded by fit_generator.
train_history = history.history['loss']
val_history = history.history['val_loss']
x = np.array(range(len(train_history)))

# Label both curves and the axes so the figure can be read on its own.
fig, ax = plt.subplots()
ax.plot(x, train_history, 'blue', label='train')
ax.plot(x, val_history, 'red', label='validation')
ax.set_xlabel('epoch')
ax.set_ylabel('loss')
ax.set_title('Training vs. validation loss')
ax.legend();
# ax.set_ylim(min(train_history)-0.01, 1)

In [None]:
# The encoder returns [z_mean, z_log_var, z]; z_mean (index 0) is used as the embedding.
latent = encoder.predict(adata_train.X, batch_size=batch_size)[0]
labels = adata_train.obs['cell_type']
# Add the trailing axis on the underlying NumPy array: multi-dimensional
# indexing of a pandas Series (obj[:, None]) is deprecated and later removed.
batches = adata_train.obs[batch_col].values[:, None]

In [None]:
# Store the embedding under obsm["X_DPFE"], the key later passed as use_rep
# to the t-SNE and neighbours calls.
adata_train.obsm["X_DPFE"] = latent

# Scores

In [None]:
def calc_scores(input_adata):
    """Print batch-mixing entropy and clustering score for one data split.

    Args:
        input_adata: AnnData-like object with ``X``, ``obs['cell_type']`` and
            ``obs[batch_col]``.

    The embedding is the encoder's first output (``z_mean``). The scoring
    helpers ``entropy_batch_mixing`` and ``clustering_scores`` come from the
    project's ``utils`` module.
    """
    latent = encoder.predict(input_adata.X, batch_size=batch_size)[0]
    labels = input_adata.obs['cell_type']
    # Column vector of batch indices. Index the NumPy array, not the Series:
    # Series[:, None] is deprecated in pandas 1.x and raises in later versions.
    batches = input_adata.obs[batch_col].values[:, None]
    print("Entropy of batch mixing :", entropy_batch_mixing(latent, batches))
    print("Clustering ARI = {}".format(clustering_scores(dataset.n_labels, labels, latent)))

In [None]:
# Scores for the training split.
print('Train Set:')
calc_scores(adata_train)

In [None]:
# Scores for the validation split.
print('Validation Set:')
calc_scores(adata_val)

In [None]:
# Scores for the held-out test split.
print('Test Set:')
calc_scores(adata_test)

In [None]:
# Scores for the complete, unsplit dataset.
print('Total Data Set:')
calc_scores(adata)

In [None]:
# posterior.clustering_scores()

# t-SNE

In [None]:
# n_pcs is deliberately not passed: combined with use_rep, recent scanpy
# releases keep only the first n_pcs columns of the representation, which
# would reduce the latent embedding to its first two dimensions.
sc.tl.tsne(adata_train, use_rep='X_DPFE')

In [None]:
show_plot = True
# One 8x7 figure per colouring: cell type first, then batch.
for color_key in ("cell_type", batch_col):
    fig, ax = plt.subplots(figsize=(8, 7))
    sc.pl.tsne(adata_train, color=[color_key], ax=ax, show=show_plot)

# UMAP

In [None]:
import warnings
# Suppresses every warning from this point on, not only those raised by the calls below.
warnings.filterwarnings('ignore')
# Neighbourhood computation on the stored latent embedding, followed by UMAP.
sc.pp.neighbors(adata_train, use_rep="X_DPFE", n_neighbors=15)
sc.tl.umap(adata_train, min_dist=0.1)

In [None]:
show_plot = True
# One 7x6 figure per colouring: cell type first, then batch.
for color_key in ("cell_type", batch_col):
    fig, ax = plt.subplots(figsize=(7, 6))
    sc.pl.umap(adata_train, color=[color_key], ax=ax, show=show_plot)


# Test

In [None]:
def add_latent(adata_x):
    """Encode ``adata_x.X`` and store the embedding in ``adata_x.obsm["X_DPFE"]``.

    The encoder returns ``[z_mean, z_log_var, z]``; only ``z_mean`` (index 0)
    is stored, matching how the embedding is taken everywhere else in this
    notebook. Storing the whole list would not be a valid ``obsm`` entry.
    """
    latent = encoder.predict(adata_x.X, batch_size=batch_size)[0]
    adata_x.obsm["X_DPFE"] = latent

In [None]:
add_latent(adata_val)
# n_pcs is deliberately not passed: combined with use_rep, recent scanpy
# releases keep only the first n_pcs columns of the representation.
sc.tl.tsne(adata_val, use_rep='X_DPFE')

fig, ax = plt.subplots(figsize=(8, 7))
sc.pl.tsne(adata_val, color=["cell_type"], ax=ax, show=True)
# fig, ax = plt.subplots(figsize=(8, 7))
# sc.pl.tsne(adata_val, color=[batch_col], ax=ax, show=True)

In [None]:
add_latent(adata_test)
# n_pcs is deliberately not passed: combined with use_rep, recent scanpy
# releases keep only the first n_pcs columns of the representation.
sc.tl.tsne(adata_test, use_rep='X_DPFE')

fig, ax = plt.subplots(figsize=(8, 7))
sc.pl.tsne(adata_test, color=["cell_type"], ax=ax, show=True)
# fig, ax = plt.subplots(figsize=(8, 7))
# sc.pl.tsne(adata_test, color=[batch_col], ax=ax, show=True)

In [None]:
# Non-null counts of the remaining obs columns for every (cell type, batch) pair.
meta = adata.obs
meta.groupby(by=['cell_type', 'batch']).agg('count')

In [None]:
from keras.utils import plot_model
# Writes a diagram of the encoder, including tensor shapes, to vae_mlp_encoder.png.
plot_model(encoder, to_file='vae_mlp_encoder.png', show_shapes=True)

# Classification Acc Measure

In [None]:
# Inputs for the classification measure: z_mean embedding of the training
# split and its cell types encoded as integer category codes.
adata_x = adata_train
latent_x = encoder.predict(adata_x.X, batch_size=batch_size)[0]
labels_x = adata_x.obs['cell_type'].astype('category').cat.codes


In [None]:
# classification_acc_measure is imported from the project's utils module;
# how it scores the embedding is not visible in this notebook.
classification_acc_measure(latent_x, labels_x)