# In [None]:  (Jupyter cell prompt left over from the notebook export)
from __future__ import print_function

import numpy as np
import matplotlib.pyplot as plt
from scipy.stats import norm

from keras.layers import Input, Dense, Lambda
from keras.models import Model
from keras import backend as K
from keras import metrics
from keras.datasets import mnist

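##########################################################################
# vae
# 
# Builds and trains a variational autoencoder on Y, then projects the
# training data onto the latent space.
#
# This function is heavily derived from an external example whose reference
# is missing here (presumably the Keras variational autoencoder example -
# TODO confirm and cite).
# 
# Inputs:
#   - Y: matrix of counts with genes as columns and individual cells as rows
#   - loss_fun: loss function for autoencoder ('poisson',
#               'negative_binomial' or 'gaussian'; any other value falls
#               back to binary cross-entropy)
#   - latent_dim: number of latent dimensions to project data into
#   - intermediate_dim: number of dimensions in the hidden layer
#   - batch_size: batch size used for training autoencoder
#   - epochs: epochs for training autoencoder
#   - epsilon_std: standard deviation of the Gaussian noise intended for
#                  sampling the latent variable
#   - nbshape: shape parameter for negative binomial loss. Defaults to 1
# 
# Outputs:
#   - x_encoded: Encoded transformation of input
#   - z_mean: Keras tensor of the latent-mean layer (symbolic, not an array)
#
##########################################################################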

def vae(Y, loss_fun, latent_dim, intermediate_dim, 
        batch_size, epochs, epsilon_std, nbshape=1):
    """Train a variational autoencoder on Y and encode Y into latent space.

    Args:
        Y: count matrix; rows are samples (cells), columns are features.
        loss_fun: name of the reconstruction loss, forwarded to
            ``preprocess`` and ``compute_model_loss``.
        latent_dim: number of latent dimensions.
        intermediate_dim: width of the encoder and decoder hidden layers.
        batch_size: batch size for training and for the final prediction.
        epochs: number of training epochs.
        epsilon_std: standard deviation of the Gaussian noise used when
            sampling the latent variable.
        nbshape: shape parameter for the negative binomial loss.

    Returns:
        Tuple ``(x_encoded, z_mean)``: ``x_encoded`` is the encoder's
        prediction (latent means) for the preprocessed training data and
        ``z_mean`` is the symbolic Keras tensor of the latent-mean layer.
    """
    x_train = preprocess(Y, loss_fun)
    original_dim = x_train.shape[1]

    def _sample_latent(args):
        # Reparameterization trick: z = mean + sigma * epsilon. Defined as a
        # closure so the latent_dim and epsilon_std passed to vae() are the
        # values actually used, rather than module-level globals.
        mean, log_var = args
        epsilon = K.random_normal(shape=(K.shape(mean)[0], latent_dim),
                                  mean=0., stddev=epsilon_std)
        return mean + K.exp(log_var / 2) * epsilon

    # Encoder: input -> hidden -> (z_mean, z_log_var)
    x = Input(shape=(original_dim,))
    h = Dense(intermediate_dim, activation='relu')(x)
    z_mean = Dense(latent_dim)(h)
    z_log_var = Dense(latent_dim)(h)

    # note that "output_shape" isn't necessary with the TensorFlow backend
    z = Lambda(_sample_latent, output_shape=(latent_dim,))([z_mean, z_log_var])

    # Decoder: z -> hidden -> reconstruction.
    # The layers are instantiated separately so they can be reused later.
    decoder_h = Dense(intermediate_dim, activation='relu')
    decoder_mean = Dense(original_dim, activation='sigmoid')
    h_decoded = decoder_h(z)
    x_decoded_mean = decoder_mean(h_decoded)

    # instantiate VAE model
    model = Model(x, x_decoded_mean)

    # VAE loss = reconstruction loss + KL divergence between the approximate
    # posterior N(z_mean, exp(z_log_var)) and the standard normal prior.
    model_loss = compute_model_loss(loss_fun, x, x_decoded_mean, nbshape)
    kl_loss = - 0.5 * K.sum(1 + z_log_var - K.square(z_mean) - K.exp(z_log_var), axis=-1)
    vae_loss = K.mean(model_loss + kl_loss)

    # The loss is attached with add_loss, so compile/fit take no targets.
    model.add_loss(vae_loss)
    model.compile(optimizer='rmsprop')
    model.summary()

    # fit the model with the given data (20% held out for validation)
    model.fit(x_train,
        shuffle=True,
        epochs=epochs,
        batch_size=batch_size,
        validation_split=0.2)

    # Build a model to project inputs on the latent space. It must be wired
    # from the symbolic input tensor `x`; the NumPy array `x_train` is data,
    # not a graph input, and is only passed to predict().
    encoder = Model(x, z_mean)
    x_encoded = encoder.predict(x_train, batch_size=batch_size)

    return x_encoded, z_mean

##########################################################################
# preprocess
#
# Inputs:
#   - Y: matrix of counts with genes as columns and individual cells as rows
#   - loss_fun: loss function for autoencoder
#
# Outputs:
#   - x_train: preprocessed matrix of counts for training the autoencoder
##########################################################################

def preprocess(Y, loss_fun):
    """Convert the count matrix Y into a float32 training array.

    NOTE(review): the original body was an unfinished ``if`` statement (a
    SyntaxError), so the intended transformation is unknown. This is a
    conservative completion - confirm it against the intended pipeline.

    Args:
        Y: array-like count matrix.
        loss_fun: name of the reconstruction loss the data will be used with.

    Returns:
        ``numpy.ndarray`` of dtype float32. For the count-based losses
        ('poisson', 'negative_binomial') the values are returned unchanged;
        for any other loss they are divided by the global maximum so they
        lie in [0, 1].
    """
    x_train = np.asarray(Y, dtype='float32')

    # Count likelihoods are assumed to need the raw counts - TODO confirm.
    if loss_fun in ('poisson', 'negative_binomial'):
        return x_train

    # Guard against empty input and all-zero matrices (no division by zero).
    peak = x_train.max() if x_train.size else 0.0
    if peak > 0:
        x_train = x_train / peak

    return x_train

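##########################################################################
# sampling
#
# Inputs:
#   - args: pair (z_mean, z_log_var) holding the mean and log-variance
#           tensors of the latent distribution
# 
# Outputs:
#   - z_mean + exp(z_log_var / 2) * epsilon, where epsilon is Gaussian
#     noise drawn with K.random_normal
##########################################################################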

def sampling(args, epsilon_std=1.0):
    """Draw a latent sample with the reparameterization trick.

    Args:
        args: pair ``(z_mean, z_log_var)`` of Keras tensors.
        epsilon_std: standard deviation of the Gaussian noise. Defaults to
            1.0; when used in a ``Lambda`` layer it can be supplied through
            ``Lambda(sampling, arguments={'epsilon_std': ...})``.

    Returns:
        Tensor ``z_mean + exp(z_log_var / 2) * epsilon``.
    """
    z_mean, z_log_var = args
    # Take the noise shape from z_mean itself instead of relying on
    # module-level `latent_dim` / `epsilon_std` globals, which are not
    # defined in this module and would raise NameError.
    batch = K.shape(z_mean)[0]
    dim = K.int_shape(z_mean)[1]
    epsilon = K.random_normal(shape=(batch, dim), mean=0.,
                              stddev=epsilon_std)
    return z_mean + K.exp(z_log_var / 2) * epsilon

##########################################################################
# compute_model_loss
# 
# Inputs:
#   - loss_fun: loss function for autoencoder ('poisson',
#               'negative_binomial' or 'gaussian'; any other value falls
#               back to binary cross-entropy)
#   - x: preprocessed input to the autoencoder
#   - x_decoded_mean: mean of the x decoded by the autoencoder
#   - nbshape: shape parameter for negative binomial loss (no default here;
#              vae() supplies 1 by default)
# 
# Outputs:
#   - loss: reconstruction loss from the autoencoding, scaled by the
#           input dimensionality
##########################################################################

def compute_model_loss(loss_fun, x, x_decoded_mean, nbshape, original_dim=None):
    """Return the reconstruction loss scaled by the input dimensionality.

    Args:
        loss_fun: 'poisson', 'negative_binomial' or 'gaussian'; any other
            value falls back to binary cross-entropy.
        x: input tensor of the autoencoder.
        x_decoded_mean: decoder output tensor.
        nbshape: shape parameter forwarded to the negative binomial loss.
        original_dim: number of input features. When omitted it is read
            from the last axis of ``x``.

    Returns:
        Loss tensor: the selected loss multiplied by ``original_dim``.

    Note:
        ``poisson``, ``negative_binomial`` and ``gaussian`` are not defined
        in the visible part of this file - presumably they live elsewhere in
        the notebook; verify before use.
    """
    if original_dim is None:
        # `original_dim` used to be read as a global that only exists as a
        # local inside vae(); derive it from the input tensor instead.
        original_dim = K.int_shape(x)[-1]
        if original_dim is None:
            # Static size unknown: fall back to the dynamic shape.
            original_dim = K.cast(K.shape(x)[-1], K.floatx())

    if loss_fun == 'poisson':
        per_feature = poisson(x, x_decoded_mean)
    elif loss_fun == 'negative_binomial':
        per_feature = negative_binomial(x, x_decoded_mean, nbshape)
    elif loss_fun == 'gaussian':
        per_feature = gaussian(x, x_decoded_mean)
    else:
        per_feature = metrics.binary_crossentropy(x, x_decoded_mean)

    return original_dim * per_feature