In [2]:
import keras
from keras import backend as K
from keras.layers import Dense, Input, merge
from keras.engine.topology import Layer
import numpy as np
from tensorflow.contrib.distributions import Categorical, Mixture, MultivariateNormalDiag
import tensorflow as tf
import math

Using TensorFlow backend.


In [9]:
class MixtureDensity(Layer):
    def __init__(self, output_dim, num_mix, **kwargs):
        self.output_dim = output_dim
        self.num_mix = num_mix
        with tf.name_scope('MDNLayer'):
            self.mdn_mus     = Dense(self.num_mix * self.output_dim, name='mdn_mus') # mix*output vals, no activation
            self.mdn_sigmas  = Dense(self.num_mix * self.output_dim, activation=K.exp, name='mdn_sigmas') # mix*output vals exp activation
            self.mdn_pi      = Dense(self.num_mix, activation=K.softmax, name='mdn_pi') # mix vals, softmax
        super(MixtureDensity, self).__init__(**kwargs)

    def build(self, input_shape):
        self.mdn_mus.build(input_shape)
        self.mdn_sigmas.build(input_shape)
        self.mdn_pi.build(input_shape)
        self.trainable_weights = self.mdn_mus.trainable_weights + self.mdn_sigmas.trainable_weights + self.mdn_pi.trainable_weights
        self.non_trainable_weights = self.mdn_mus.non_trainable_weights + self.mdn_sigmas.non_trainable_weights + self.mdn_pi.non_trainable_weights
        self.built = True

    def call(self, x, mask=None):
        m = self.mdn_mus(x)
        s = self.mdn_sigmas(x)
        p = self.mdn_pi(x)
        
        with tf.name_scope('MDNLayer'):
            mdn_out = keras.layers.concatenate([m, s, p], name='mdn_out')
        return mdn_out
        
    def compute_output_shape(self, input_shape):
        return (input_shape[0], self.output_dim)
    
    def get_config(self):
        config = {'output_dim': self.output_dim,                                    
                  'num_mix': self.num_mix}
        base_config = super(MDN, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))

    def get_loss_func(self):
        def unigaussian_loss(y_true, y_pred):
            mix = tf.range(start = 0, limit = self.num_mix)            
            out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[self.num_mix * self.output_dim, self.num_mix, self.num_mix], axis=1, name='mdn_coef_split')
            
            def loss_i(i):         
                batch_size = tf.shape(out_sigma)[0]
                sigma_i = tf.slice(out_sigma, [0, i], [batch_size, 1], name='mdn_sigma_slice')
                pi_i = tf.slice(out_pi, [0, i], [batch_size, 1], name='mdn_pi_slice')        
                mu_i = tf.slice(out_mu, [0, i * self.output_dim], [batch_size, self.output_dim], name='mdn_mu_slice')
                dist = tf.distributions.Normal(loc=mu_i, scale=sigma_i)
                loss = dist.prob(y_true)
                loss = pi_i * loss
                return loss

            result = tf.map_fn(lambda  m: loss_i(m), mix, dtype=tf.float32, name='mix_map_fn')
            
            result = tf.reduce_sum(result, axis=0, keepdims=False)
            result = -tf.log(result)
            result = tf.reduce_mean(result)
            return result

        with tf.name_scope('MDNLayer'):
            return unigaussian_loss


In [37]:
def get_mixture_loss_func(output_dim, num_mixes):
    """Construct a loss functions for the MDN layer parametrised by number of mixtures."""
    
    # Construct a loss function with the right number of mixtures and outputs
    def loss_func(y_true, y_pred):
        out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[num_mixes * output_dim, 
                                                                         num_mixes * output_dim, 
                                                                         num_mixes], 
                                             axis=1, name='mdn_coef_split')
        cat = Categorical(logits=out_pi)
        component_splits = [output_dim] * num_mixes
        mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
        sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
        coll = [MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale
            in zip(mus, sigs)]
        mixture = Mixture(cat=cat, components=coll)
        loss = mixture.log_prob(y_true)
        loss = tf.negative(loss)
        #         loss = reduce_mean(loss)
        print("y_pred:", y_pred.shape)
        print("mu:",  out_mu.shape)
        print("sigma:", out_sigma.shape)
        print("pi:", out_pi.shape)
        print("splits:", component_splits)
        print("Mix:", mixture)
        return loss
    
    # Actually return the loss_func
    with tf.name_scope('MDNLayer'):
        return loss_func
    

In [39]:
OUTPUT_DIMENSION = 3
NUMBER_MIXTURES = 5
LSTM_SIZE = 64

model = keras.Sequential()
model.add(keras.layers.LSTM(LSTM_SIZE, batch_input_shape=(None,10,OUTPUT_DIMENSION), return_sequences=True))
model.add(keras.layers.LSTM(LSTM_SIZE))
model.add(MixtureDensity(OUTPUT_DIMENSION, NUMBER_MIXTURES))
model.compile(loss=get_mixture_loss_func(OUTPUT_DIMENSION,NUMBER_MIXTURES), optimizer=keras.optimizers.Adam())
model.summary()

y_pred: (?, 35)
mu: (?, 15)
sigma: (?, 15)
pi: (?, 5)
splits: [3, 3, 3, 3, 3]
Mix: tf.distributions.Mixture("Mixture", batch_shape=(?,), event_shape=(3,), dtype=float32)
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm_11 (LSTM)               (None, 10, 64)            17408     
_________________________________________________________________
lstm_12 (LSTM)               (None, 64)                33024     
_________________________________________________________________
mixture_density_6 (MixtureDe (None, 3)                 2275      
Total params: 52,707
Trainable params: 52,707
Non-trainable params: 0
_________________________________________________________________


In [35]:
# Testing creation of a mixture model loss function.

model = keras.Sequential()
model.add(keras.layers.LSTM(64, batch_input_shape=(None,10,3), return_sequences=True))
model.add(keras.layers.LSTM(64))
m = MixtureDensity(3, 5)
model.add(m)
# model.compile(loss=m.get_loss_func(), optimizer=keras.optimizers.Adam())
model.summary()

num_mixes = 5
output_dim = 3
y_pred = model.output
# Start messing around with model.output.
out_mu, out_sigma, out_pi = tf.split(y_pred, num_or_size_splits=[num_mixes * output_dim, num_mixes * output_dim, num_mixes], axis=1, name='mdn_coef_split')
cat = Categorical(logits=out_pi)
component_splits = [output_dim] * num_mixes
mus = tf.split(out_mu, num_or_size_splits=component_splits, axis=1)
sigs = tf.split(out_sigma, num_or_size_splits=component_splits, axis=1)
coll = [MultivariateNormalDiag(loc=loc, scale_diag=scale) for loc, scale
    in zip(mus, sigs)]
mixture = Mixture(cat=cat, components=coll)

print("y_pred:", y_pred.shape)
print("mu:",  out_mu.shape)
print("sigma:", out_sigma.shape)
print("pi:", out_pi.shape)
print("splits:", component_splits)
print("Mix:", mixture)




# MultivariateNormalDiag(loc=out_mu, scale_diag=out_sigma)

_________________________________________________________________
Layer (type)                 Output Shape              Param #   
lstm_7 (LSTM)                (None, 10, 64)            17408     
_________________________________________________________________
lstm_8 (LSTM)                (None, 64)                33024     
_________________________________________________________________
mixture_density_4 (MixtureDe (None, 3)                 2275      
Total params: 52,707
Trainable params: 52,707
Non-trainable params: 0
_________________________________________________________________
y_pred: (?, 35)
mu: (?, 15)
sigma: (?, 15)
pi: (?, 5)
splits: [3, 3, 3, 3, 3]
Mix: tf.distributions.Mixture("Mixture", batch_shape=(?,), event_shape=(3,), dtype=float32)
