In [6]:
import keras
from keras import backend as K
from keras.layers import Dense, Input, merge
from keras.engine.topology import Layer
import numpy as np
from tensorflow.contrib.distributions import Categorical, Mixture, MultivariateNormalDiag
import tensorflow as tf
import math

In [27]:
class MixtureDensity(Layer):
    def __init__(self, output_dim, num_mix, **kwargs):
        self.output_dim = output_dim
        self.num_mix = num_mix
        with tf.name_scope('MDNLayer'):
            self.mdn_mus     = Dense(self.num_mix * self.output_dim, name='mdn_mus') # no activation
            self.mdn_sigmas  = Dense(self.num_mix, activation=K.exp, name='mdn_sigmas') # no activation 
            self.mdn_pi      = Dense(self.num_mix, activation=K.softmax, name='mdn_pi') # no activation
        super(MixtureDensity, self).__init__(**kwargs)

    def build(self, input_shape):
        self.mdn_mus.build(input_shape)
        self.mdn_sigmas.build(input_shape)
        self.mdn_pi.build(input_shape)
        self.trainable_weights = self.mdn_mus.trainable_weights + self.mdn_sigmas.trainable_weights + self.mdn_pi.trainable_weights
        self.non_trainable_weights = self.mdn_mus.non_trainable_weights + self.mdn_sigmas.non_trainable_weights + self.mdn_pi.non_trainable_weights
        self.built = True

    def call(self, x, mask=None):
        m = self.mdn_mus(x)
        s = self.mdn_sigmas(x)
        p = self.mdn_pi(x)
        
        with tf.name_scope('MDNLayer'):
            mdn_out = keras.layers.concatenate([m, s, p], name='mdn_out')
        return mdn_out
        
    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)
    
    def get_config(self):
        config = {'output_dim': self.output_dim,                                    
                  'num_mix': self.num_mix}
        base_config = super(MDN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def get_loss_func(self):
        def multivariate_loss(y_true, y_pred):
            mix = tf.range(start = 0, limit = self.num_mix)            
            out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[self.num_mix * self.output_dim, self.num_mix, self.num_mix], axis=1, name='mdn_coef_split')
                tf.distributions.MultivariateNormalDiag

        def unigaussian_loss(y_true, y_pred):
            mix = tf.range(start = 0, limit = self.num_mix)            
            out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[self.num_mix * self.output_dim, self.num_mix, self.num_mix], axis=1, name='mdn_coef_split')
            
            def loss_i(i):         
                batch_size = tf.shape(out_sigma)[0]
                sigma_i = tf.slice(out_sigma, [0, i], [batch_size, 1], name='mdn_sigma_slice')
                pi_i = tf.slice(out_pi, [0, i], [batch_size, 1], name='mdn_pi_slice')        
                mu_i = tf.slice(out_mu, [0, i * self.output_dim], [batch_size, self.output_dim], name='mdn_mu_slice')
                dist = tf.distributions.Normal(loc=mu_i, scale=sigma_i)
                loss = dist.prob(y_true)
                loss = pi_i * loss
                return loss

            result = tf.map_fn(lambda  m: loss_i(m), mix, dtype=tf.float32, name='mix_map_fn')
            
            result = tf.reduce_sum(result, axis=0, keepdims=False)
            result = -tf.log(result)
            result = tf.reduce_mean(result)
            return result
        

        with tf.name_scope('MDNLayer'):
            return unigaussian_loss


In [None]:
from tensorflow.contrib.distributions import Categorical, Mixture, MultivariateNormalDiag

def get_mixture_loss_func(num_mixes, output_dim):
    """Construct a loss functions for the MDN layer parametrised by number of mixtures."""
    def loss_func(y_true, y_pred):
        mix = tf.range(start = 0, limit = self.num_mix)
        mus, sigmas, pis = tf.split(value=y_pred, num_or_size_splits=[num_mixes * output_dim, num_mixes, num_mixes], axis=1, name="mixture_split")
        cat = Categorical(logits=pis)
        locs = 
        
    with tf.name_scope('MDNLayer'):
        return loss_func
    
    

In [28]:
model = keras.Sequential()
model.add(keras.layers.LSTM(64, batch_input_shape=(None,10,3), return_sequences=True))
model.add(keras.layers.LSTM(64))
m = MixtureDensity(3, 5)
model.add(m)
model.compile(loss=m.get_loss_func(), optimizer=keras.optimizers.Adam())
model.summary()

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm_25 (LSTM)               (None, 10, 64)            17408     
_________________________________________________________________
lstm_26 (LSTM)               (None, 64)                33024     
_________________________________________________________________
mixture_density_12 (MixtureD (None, 3)                 1625      
Total params: 52,057
Trainable params: 52,057
Non-trainable params: 0
_________________________________________________________________
