In [21]:
import tensorflow as tf
import numpy as np
# Implementing Variational RNN's and variations by subclassing Keras RNN-type Cells

class VRNNCell(tf.keras.layers.GRUCell):
    def __init__(self, units, **kwargs):
        super(VRNNCell, self).__init__(units, **kwargs)
    

    def build(self, input_shape):
        # Taking most of the standard weights from the base GRU class
        super().build((input_shape[0], input_shape[1] + self.units))
        self.encoder_mu_kernel = self.add_weight(shape=(input_shape[-1] + self.units, self.units),
                                                 initializer='uniform',
                                                 name='kernel')
        
        self.encoder_logvar_kernel = self.add_weight(shape=(input_shape[-1] + self.units, self.units),
                                      initializer='uniform',
                                      name='kernel')
        
        self.prior_mu_kernel = self.add_weight(shape=(self.units, self.units),
                              initializer='uniform',
                              name='kernel')
        
        self.prior_logvar_kernel = self.add_weight(shape=(self.units,self.units),
                                      initializer='uniform',
                                      name='kernel')
    def kl_gauss(self, posterior_means, prior_means, posterior_log_var, prior_log_var):   
        kl = prior_log_var - posterior_log_var + (tf.exp(posterior_log_var) + 
                                           tf.square(posterior_means - prior_means)) / tf.exp(prior_log_var) - 1
        kl = 0.5 * tf.reduce_sum(kl, axis=1)
        return kl


    def sample(self, mu, log_var):
        # Sample from unit Normal
        epsilon = tf.random.normal([1, self.units])
        half_constant = tf.convert_to_tensor(np.full((1, self.units), 0.5).astype('float32'))
        # All element-wise computations
        z = tf.math.multiply(half_constant, tf.math.exp(log_var)) + mu
        return z
    
    def call(self, inputs, states, training=False):
        # Some formulations:
        # Generation:
        # z_t ~ N(mu_(0, t), sigma_(0,t)), w here [mu_(0,t), sigma(0,t)] = phi_prior(h_(t-1))
        # Update: 
        # h_t = f_theta(h_(t-1), z_t, x_t) *recurrence equation
        # Inference:
        # z_t ~ N(mu_z, sigma_z), where [mu_z, sigma_z] = phi_post(x_t, h_(t-1))
        #
        # Let the base RNN cell handle the rest and add loss
        
        if training:
            x_t = inputs
            h_prev = states[0]

            p_mu = tf.matmul(h_prev, self.prior_mu_kernel)
            p_logvar = tf.matmul(h_prev, self.prior_logvar_kernel)
            
            input_state_concat = tf.concat([x_t, h_prev], axis=1)
            
            q_mu = tf.matmul(input_state_concat, self.encoder_mu_kernel)
            q_logvar = tf.matmul(input_state_concat, self.encoder_logvar_kernel)
            z_t = self.sample(q_mu, q_logvar)
            
            inp = tf.concat([x_t, z_t], axis=1)
            _, h_next = super().call(inp, h_prev)
            
            output = z_t
            new_state = (h_next, q_mu, p_mu, q_logvar, p_logvar)
            self.add_loss(lambda: tf.reduce_sum(tf.square(inputs)))
            return output, [h_next]
        
        else:
            # Return prior and posterior parameters
            x_t = inputs
            h_prev = states[0]

            p_mu = tf.matmul(h_prev, self.prior_mu_kernel)
            p_logvar = tf.matmul(h_prev, self.prior_logvar_kernel)
            z_t = self.sample(p_mu, p_logvar)
            
            input_state_concat = tf.concat([x_t, h_prev], axis=1)
            
            q_mu = tf.matmul(input_state_concat, self.encoder_mu_kernel)
            q_logvar = tf.matmul(input_state_concat, self.encoder_logvar_kernel)
            
            
            i = tf.concat([x_t, z_t], axis=1)
            _, h_next = super().call(i, h_prev)
            
            output = (z_t, p_mu, p_logvar, q_mu, q_logvar)
            
            return output, h_next
    
   
    def get_config(self):
        return {"units":self.units}

In [22]:
cell = VRNNCell(5)
x = tf.keras.Input((None, 32))
layer = tf.keras.layers.RNN(cell)
y = layer(x)

In [23]:
from tensorflow import keras
batch_size = 64
timesteps = 20

cell = VRNNCell(3)
vrnn = keras.layers.RNN(cell)
input_1 = keras.Input((None, 32))

outputs = vrnn(input_1, training=True)

model = keras.models.Model(input_1, outputs)
model.compile(optimizer="adam", loss="mse", metrics=["accuracy"])


In [24]:
input_1_data = np.random.random((batch_size * 1, timesteps, 32))
target_1_data = np.random.random((batch_size * 1, 3))
model.fit(input_1_data, target_1_data, batch_size=1)

InaccessibleTensorError: in user code:

    /Users/justinlee/Desktop/Master/Thesis/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:805 train_function  *
        return step_function(self, iterator)
    /Users/justinlee/Desktop/Master/Thesis/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:795 step_function  **
        outputs = model.distribute_strategy.run(run_step, args=(data,))
    /Users/justinlee/Desktop/Master/Thesis/venv/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:1259 run
        return self._extended.call_for_each_replica(fn, args=args, kwargs=kwargs)
    /Users/justinlee/Desktop/Master/Thesis/venv/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:2730 call_for_each_replica
        return self._call_for_each_replica(fn, args, kwargs)
    /Users/justinlee/Desktop/Master/Thesis/venv/lib/python3.7/site-packages/tensorflow/python/distribute/distribute_lib.py:3417 _call_for_each_replica
        return fn(*args, **kwargs)
    /Users/justinlee/Desktop/Master/Thesis/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:788 run_step  **
        outputs = model.train_step(data)
    /Users/justinlee/Desktop/Master/Thesis/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/training.py:756 train_step
        y, y_pred, sample_weight, regularization_losses=self.losses)
    /Users/justinlee/Desktop/Master/Thesis/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py:1477 losses
        loss_tensor = regularizer()
    /Users/justinlee/Desktop/Master/Thesis/venv/lib/python3.7/site-packages/tensorflow/python/keras/engine/base_layer.py:1553 _tag_callable
        loss = loss()
    /var/folders/q_/3vc146c52nbdzbv061fwl40w0000gn/T/tmp03szqkil.py:33 <lambda>
        ag__.converted_call(ag__.ld(self).add_loss, (ag__.autograph_artifact((lambda : ag__.converted_call(ag__.ld(tf).reduce_sum, (ag__.converted_call(ag__.ld(tf).square, (ag__.ld(inputs),), None, fscope),), None, fscope))),), None, fscope)
    /Users/justinlee/Desktop/Master/Thesis/venv/lib/python3.7/site-packages/tensorflow/python/ops/gen_math_ops.py:10175 square  **
        "Square", x=x, name=name)
    /Users/justinlee/Desktop/Master/Thesis/venv/lib/python3.7/site-packages/tensorflow/python/framework/op_def_library.py:750 _apply_op_helper
        attrs=attr_protos, op_def=op_def)
    /Users/justinlee/Desktop/Master/Thesis/venv/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:588 _create_op_internal
        inp = self.capture(inp)
    /Users/justinlee/Desktop/Master/Thesis/venv/lib/python3.7/site-packages/tensorflow/python/framework/func_graph.py:638 capture
        % (tensor, tensor.graph, self))

    InaccessibleTensorError: The tensor 'Tensor("model_1/rnn_13/while/TensorArrayV2Read/TensorListGetItem:0", shape=(1, 32), dtype=float32)' cannot be accessed here: it is defined in another function or code block. Use return values, explicit Python locals or TensorFlow collections to access it. Defined in: FuncGraph(name=model_1_rnn_13_while_body_3890, id=6236714576); accessed from: FuncGraph(name=train_function, id=6237353552).
    


In [11]:
vrnn(input_1_data, training=True)

<tf.Tensor: shape=(64, 3), dtype=float32, numpy=
array([[0.4311453 , 0.4606725 , 0.47704563],
       [0.52808094, 0.6057584 , 0.57910377],
       [0.46645677, 0.5354863 , 0.54857904],
       [0.48584574, 0.48395637, 0.48264694],
       [0.50491244, 0.48578846, 0.38260394],
       [0.36745906, 0.5687973 , 0.54191303],
       [0.5932691 , 0.45675448, 0.46148163],
       [0.59921736, 0.41047204, 0.5037082 ],
       [0.48115534, 0.56768554, 0.4255946 ],
       [0.46520266, 0.42949104, 0.51351035],
       [0.39687178, 0.50279284, 0.52661777],
       [0.52611357, 0.46462557, 0.51956886],
       [0.5612152 , 0.4806883 , 0.47820187],
       [0.5042616 , 0.5365728 , 0.4162539 ],
       [0.5764899 , 0.39615476, 0.44143268],
       [0.5510044 , 0.45897853, 0.5210611 ],
       [0.49854225, 0.48809254, 0.4138353 ],
       [0.53186536, 0.46432456, 0.4532896 ],
       [0.5432695 , 0.4237479 , 0.49098617],
       [0.40357178, 0.47799325, 0.40562063],
       [0.5237373 , 0.4983244 , 0.5457534 ],
      

In [13]:
vrnn(input_1_data, training=False)

(<tf.Tensor: shape=(64, 3), dtype=float32, numpy=
 array([[0.48544663, 0.50757957, 0.4699906 ],
        [0.50136703, 0.50307465, 0.49709392],
        [0.48719135, 0.50693065, 0.4732785 ],
        [0.4811251 , 0.5139315 , 0.45654422],
        [0.48267457, 0.5189062 , 0.45178062],
        [0.4924169 , 0.5194061 , 0.46450275],
        [0.4944159 , 0.5097243 , 0.47875574],
        [0.4922774 , 0.5230228 , 0.4596571 ],
        [0.4844971 , 0.51291144, 0.46177593],
        [0.4838476 , 0.52520555, 0.44552037],
        [0.4860503 , 0.50793403, 0.46988666],
        [0.48626116, 0.51788074, 0.45810208],
        [0.5016757 , 0.51223445, 0.48537055],
        [0.4957752 , 0.527188  , 0.4598345 ],
        [0.48279372, 0.5172111 , 0.45413494],
        [0.4909631 , 0.5158375 , 0.46712944],
        [0.5005782 , 0.5119957 , 0.48472178],
        [0.48909238, 0.52358216, 0.45513284],
        [0.4848975 , 0.51882964, 0.4550278 ],
        [0.49085367, 0.5058539 , 0.47870076],
        [0.48968858, 0.500719 