# In [17]:
import tensorflow as tf

from tensorflow.python.layers import base as base_layer
from tensorflow.python.ops.rnn_cell_impl import _LayerRNNCell
from tensorflow.python.ops.math_ops import tanh, sigmoid
from tensorflow.python.ops.init_ops import constant_initializer, zeros_initializer

# In [19]:
# Start from an empty default graph so that re-running this notebook cell does
# not keep the ops/variables created by earlier executions.
tf.reset_default_graph()

class WaveRNN(_LayerRNNCell):
    """Gated recurrent cell intended to implement equation (2) of the WaveRNN paper.

    The variables created in ``build`` follow a GRU layout:

    * a gate kernel/bias mapping ``[inputs, state]`` to the reset (r) and
      update (u) gates, and
    * a candidate kernel/bias producing the candidate state ``e``.

    Args:
        number_units: Size of the hidden state; also the output size.
        kernel_intialiser: Initialiser used for both kernels. (The misspelt
            name is kept so existing callers keep working.) When ``None`` the
            choice is left to ``add_variable``.
        bias_initialiser: Initialiser used for both biases. When ``None`` the
            gate bias starts at 1.0 and the candidate bias at 0.0.
        reuse: Forwarded to the base cell as ``_reuse``.
        name: Name of the layer.
    """

    def __init__(self,
                 number_units,
                 kernel_intialiser=None,
                 bias_initialiser=None,
                 reuse=None,
                 name=None):

        super(WaveRNN, self).__init__(_reuse=reuse,
                                      name=name)

        # Inputs must be two dimensional.
        self.input_spec = base_layer.InputSpec(ndim=2)

        self.number_units = number_units
        self.kernel_initialiser = kernel_intialiser
        self.bias_initialiser = bias_initialiser

    @property
    def state_size(self):
        """Size of the recurrent state h."""
        return self.number_units

    @property
    def output_size(self):
        """Size of the per-step output (the new state h)."""
        return self.number_units

    def build(self, inputs_shape):
        """Create the gate and candidate variables.

        Args:
            inputs_shape: Static shape of the (2-D) cell inputs.

        Raises:
            ValueError: If the last dimension of the inputs is not known.
        """
        input_depth = inputs_shape[1].value
        if input_depth is None:
            raise ValueError(
                "Expected inputs.shape[-1] to be known, saw shape: %s"
                % inputs_shape)
        # Remembered so `call` can separate the input rows (I_e) from the
        # recurrent rows (R_e) of the candidate kernel.
        self._input_depth = input_depth

        # TODO: Change kernels and biases to match the vars needed in
        #       equation (2) from the paper (e.g. the masked I* matrices).

        # Gate kernel: [inputs, state] -> [r, u].
        gate_kernel_name = 'WaveRNN/gate/kernel'
        gate_kernel_shape = [input_depth + self.number_units, 2 * self.number_units]
        gate_kernel_initialiser = self.kernel_initialiser
        self.gate_kernel = self.add_variable(gate_kernel_name,
                                             shape=gate_kernel_shape,
                                             initializer=gate_kernel_initialiser)
        # Gate bias; the 1.0 default biases r and u towards 1 at the start.
        gate_bias_name = 'WaveRNN/gate/bias'
        gate_bias_shape = [2 * self.number_units]
        gate_bias_initialiser = (self.bias_initialiser
                                 if self.bias_initialiser is not None
                                 else constant_initializer(1.0, dtype=self.dtype))
        self.gate_bias = self.add_variable(gate_bias_name,
                                           shape=gate_bias_shape,
                                           initializer=gate_bias_initialiser)
        # Candidate kernel: rows [0, input_depth) act on inputs, the remaining
        # number_units rows act on the previous state.
        candidate_kernel_name = 'WaveRNN/candidate/kernel'
        candidate_kernel_shape = [input_depth + self.number_units, self.number_units]
        candidate_kernel_initialiser = self.kernel_initialiser
        self.candidate_kernel = self.add_variable(candidate_kernel_name,
                                                  shape=candidate_kernel_shape,
                                                  initializer=candidate_kernel_initialiser)
        # Candidate bias.
        candidate_bias_name = 'WaveRNN/candidate/bias'
        candidate_bias_shape = [self.number_units]
        candidate_bias_initialiser = (self.bias_initialiser
                                      if self.bias_initialiser is not None
                                      else zeros_initializer(dtype=self.dtype))
        self.candidate_bias = self.add_variable(candidate_bias_name,
                                                shape=candidate_bias_shape,
                                                initializer=candidate_bias_initialiser)

        self.built = True

    def call(self, inputs, state):
        """Run one step of the gated recurrence from equation (2).

        Implements the recurrent core of the equation:

            u_t = sigmoid(R_u h_{t-1} + I_u x_t + b_u)
            r_t = sigmoid(R_r h_{t-1} + I_r x_t + b_r)
            e_t = tanh(r_t * (R_e h_{t-1}) + I_e x_t + b_e)
            h_t = u_t * h_{t-1} + (1 - u_t) * e_t

        TODO: input masking (I*) and the coarse/fine split of h_t with the
        output softmax layers are not implemented yet.

        Args:
            inputs: 2-D tensor of shape [batch, input_depth].
            state: 2-D tensor of shape [batch, number_units] holding h_{t-1}.

        Returns:
            A ``(output, new_state)`` tuple; both are h_t.
        """
        # Reset and update gates from the concatenated [x_t, h_{t-1}].
        gate_inputs = tf.matmul(tf.concat([inputs, state], 1), self.gate_kernel)
        gate_inputs = tf.nn.bias_add(gate_inputs, self.gate_bias)
        reset, update = tf.split(sigmoid(gate_inputs),
                                 num_or_size_splits=2,
                                 axis=1)

        # Unlike a standard GRU, WaveRNN applies the reset gate *after* the
        # recurrent projection, so the candidate kernel is split into I_e / R_e.
        input_kernel, recurrent_kernel = tf.split(
            self.candidate_kernel,
            num_or_size_splits=[self._input_depth, self.number_units],
            axis=0)
        candidate = (tf.matmul(inputs, input_kernel)
                     + reset * tf.matmul(state, recurrent_kernel))
        candidate = tanh(tf.nn.bias_add(candidate, self.candidate_bias))

        new_h = update * state + (1 - update) * candidate
        return new_h, new_h