In [118]:
import numpy as np
import tensorflow as tf
# Implementing variational RNNs (VRNNs) and their variants by subclassing Keras RNN cell classes

class VRNNCell(tf.keras.layers.GRUCell):
    """GRU cell extended with a per-step latent variable z_t (variational RNN).

    Formulation implemented by this cell:

    * Prior (generation):    z_t ~ N(p_mu, exp(p_logvar)),
      with [p_mu, p_logvar] linear in h_(t-1).
    * Posterior (inference): z_t ~ N(q_mu, exp(q_logvar)),
      with [q_mu, q_logvar] linear in concat([x_t, h_(t-1)]).
    * Recurrence:            h_t = GRU(concat([x_t, z_t]), h_(t-1)).

    The step output depends on the ``training`` flag (see ``call``).

    NOTE(review): no KL / reconstruction loss is added by this cell, and the
    training-mode output exposes only z_t; any variational loss has to be
    computed by the caller.
    """

    def __init__(self, units, **kwargs):
        """Create the cell.

        Args:
            units: Size of both the hidden state h_t and the latent z_t.
            **kwargs: Forwarded unchanged to ``tf.keras.layers.GRUCell``.
        """
        super().__init__(units, **kwargs)

    def build(self, input_shape):
        """Create the GRU weights plus the prior / posterior projection kernels.

        Args:
            input_shape: Per-step input shape; its last entry is the feature size.
        """
        input_dim = input_shape[-1]

        # The inherited GRU consumes concat([x_t, z_t]), so its input is widened
        # by `units` (the size of z_t).
        super().build((input_shape[0], input_dim + self.units))

        # Each projection gets a unique weight name so it cannot collide with
        # the others or with the GRU's own 'kernel'.
        self.encoder_mu_kernel = self.add_weight(
            shape=(input_dim + self.units, self.units),
            initializer='uniform',
            name='encoder_mu_kernel')

        self.encoder_logvar_kernel = self.add_weight(
            shape=(input_dim + self.units, self.units),
            initializer='uniform',
            name='encoder_logvar_kernel')

        self.prior_mu_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='prior_mu_kernel')

        self.prior_logvar_kernel = self.add_weight(
            shape=(self.units, self.units),
            initializer='uniform',
            name='prior_logvar_kernel')

    def sample(self, mu, log_var):
        """Draw z ~ N(mu, exp(log_var)) with the reparameterisation trick.

        Args:
            mu: Mean tensor.
            log_var: Log-variance tensor, same shape as ``mu``.

        Returns:
            ``mu + exp(0.5 * log_var) * epsilon`` with ``epsilon ~ N(0, I)``.
        """
        # Noise takes the shape of `mu`, so every element gets its own draw.
        epsilon = tf.random.normal(tf.shape(mu), dtype=mu.dtype)
        # exp(0.5 * log_var) is the standard deviation; all ops are element-wise.
        return mu + tf.math.exp(0.5 * log_var) * epsilon

    def call(self, inputs, states, training=False):
        """Run one VRNN step.

        Args:
            inputs: Step input x_t.
            states: Sequence whose first element is the previous hidden state.
            training: When truthy, z_t is sampled from the posterior; otherwise
                it is sampled from the prior.

        Returns:
            ``(output, [h_next])`` where ``output`` is ``z_t`` when training and
            the tuple ``(z_t, p_mu, p_logvar, q_mu, q_logvar)`` otherwise.

        NOTE(review): the non-training output is a 5-tuple; confirm the wrapping
        RNN layer is expected to receive a nested output here.
        """
        x_t = inputs
        h_prev = states[0]

        # Posterior parameters from concat([x_t, h_(t-1)]); needed in both modes.
        input_state_concat = tf.concat([x_t, h_prev], axis=1)
        q_mu = tf.matmul(input_state_concat, self.encoder_mu_kernel)
        q_logvar = tf.matmul(input_state_concat, self.encoder_logvar_kernel)

        if training:
            # Inference path: sample the latent from the posterior.
            z_t = self.sample(q_mu, q_logvar)
            output = z_t
        else:
            # Generation path: sample from the prior and expose both sets of
            # distribution parameters.
            p_mu = tf.matmul(h_prev, self.prior_mu_kernel)
            p_logvar = tf.matmul(h_prev, self.prior_logvar_kernel)
            z_t = self.sample(p_mu, p_logvar)
            output = (z_t, p_mu, p_logvar, q_mu, q_logvar)

        # Let the base GRU perform the recurrence on concat([x_t, z_t]).
        # `training` is forwarded so any GRU dropout follows the same flag; the
        # first return value is the new hidden state tensor.
        gru_input = tf.concat([x_t, z_t], axis=1)
        h_next, _ = super().call(gru_input, h_prev, training=training)

        return output, [h_next]

    def get_config(self):
        """Return the full GRUCell config (includes ``units``).

        The constructor forwards ``**kwargs`` to ``GRUCell``, so the base config
        is returned whole rather than only ``units``.
        """
        return super().get_config()

In [119]:
# Smoke test: trace a 5-unit VRNN cell, wrapped in a Keras RNN layer, on a
# symbolic variable-length input with 32 features per step.
x = tf.keras.Input(shape=(None, 32))
cell = VRNNCell(units=5)
layer = tf.keras.layers.RNN(cell)
y = layer(x)

In [121]:
from tensorflow import keras
batch_size = 64
timesteps = 20

# Wrap a 3-unit VRNN cell in an RNN layer over variable-length, 32-feature sequences.
cell = VRNNCell(3)
vrnn = keras.layers.RNN(cell)
input_1 = keras.Input((None, 32))

# training=True is fixed at call time so the cell takes its `training` branch,
# whose step output is the single tensor z_t.
outputs = vrnn(input_1, training=True)

model = keras.models.Model(input_1, outputs)
# The loss is MSE on continuous targets, so track mean absolute error;
# classification accuracy carries no meaning for a regression objective.
model.compile(optimizer="adam", loss="mse", metrics=["mae"])


In [122]:
# Random smoke-test data: `batch_size` sequences of `timesteps` steps with 32
# features each, and one 3-dimensional target per sequence.
# NOTE(review): `batch_size` is used here as the sample count while fit() runs
# with a mini-batch of 1 — confirm this is intended.
input_1_data = np.random.random((batch_size, timesteps, 32))
target_1_data = np.random.random((batch_size, 3))
model.fit(x=input_1_data, y=target_1_data, batch_size=1)



<tensorflow.python.keras.callbacks.History at 0x17c8a6b50>

In [31]:
# Scratch float32 all-ones tensors with shapes (1, 4) and (4, 3).
a = tf.convert_to_tensor(np.ones((1, 4), dtype='float32'))
b = tf.convert_to_tensor(np.ones((4, 3), dtype='float32'))