In [1]:
from tensorflow.keras.layers import Dense, Flatten, Layer, Reshape, Concatenate
import tensorflow as tf
import numpy as np

In [12]:
class Autoregressive_input_layer(Layer):
    """
    Autoregressive inputs refer to inputs were we need to maintain autoregressive structure as in MADE paper
    non_autoregressive inputs are fully connected (this represents h in IAF paper)
    """
    def __init__(self, autoregressive_input_dim, non_autoregressive_input_dim, units=32):
        super(Autoregressive_input_layer, self).__init__()
        assert units >= autoregressive_input_dim
        weight_initialiser = tf.random_normal_initializer()
        self.autoregressive_weights = \
            tf.Variable(initial_value=weight_initialiser(shape=(autoregressive_input_dim, units),
                        dtype="float32"), trainable=True)
        self.autoregressive_weights_mask = np.zeros((autoregressive_input_dim, units))
        # TODO can do this without a lengthy loop
        for i in range(autoregressive_input_dim):
            for j in range(units):
                units_corresponding_max_autoregressive_input_index = j//autoregressive_input_dim
                if units_corresponding_max_autoregressive_input_index > i:
                    self.autoregressive_weights_mask[i, j] = 1
        self.autoregressive_weights_mask = tf.convert_to_tensor(self.autoregressive_weights_mask, dtype="float32")
        #self.L_mask = np.zeros((autoregressive_input_dim, units)).astype("float32")
        #self.L_mask = np.tril(self.L_mask, k=-1) # we can do this if units=autoregressive_input_dim

        self.non_autoregressive_weights = \
            tf.Variable(initial_value=weight_initialiser(shape=(non_autoregressive_input_dim, units),
                                                         dtype="float32"), trainable=True)
        self.biases = tf.Variable(initial_value=tf.zeros_initializer()(shape=(units, ), dtype="float32"),
                                  trainable=True)

In [18]:
input_layer = Autoregressive_input_layer(2, 3, units = 5)

In [19]:
input_layer.autoregressive_weights

<tf.Variable 'Variable:0' shape=(2, 5) dtype=float32, numpy=
array([[-0.06765819,  0.12378453,  0.03438913, -0.01846073,  0.04576434],
       [ 0.06933082, -0.08079021, -0.00821319,  0.01058174, -0.01392082]],
      dtype=float32)>

In [20]:
input_layer.autoregressive_weights_mask

<tf.Tensor: shape=(2, 5), dtype=float32, numpy=
array([[0., 0., 1., 1., 1.],
       [0., 0., 0., 0., 1.]], dtype=float32)>

In [None]:
latent_z = np.array([1, 100, 10000])