In [1]:
import numpy as np
import tensorflow as tf
from tensorflow.contrib.rnn import RNNCell, GRUCell

In [2]:
class FirstWrapper(RNNCell):
    """Composite RNN cell: ``cell1`` -> dense "prenet" -> ``cell2``.

    At every step the input is fed to ``cell1``; its output goes through a
    stack of ReLU dense layers (each followed by a dropout layer) created
    under the ``prenet`` variable scope; the result is fed to ``cell2``,
    whose output is the output of the wrapper.

    The state is a two-element list ``[cell1_state, cell2_state]``.

    Args:
        cell1: Inner cell applied to the raw inputs.
        cell2: Inner cell applied to the prenet output.
        layer_sizes: Iterable of unit counts, one per prenet dense layer.
        is_training: Forwarded as ``training`` to ``tf.layers.dropout``.
            Defaults to ``False``, which makes every dropout layer an
            identity op (this matches ``tf.layers.dropout``'s own default).
            Pass ``True`` (or a boolean tensor) to actually apply dropout.
        dropout_rate: Forwarded as ``rate`` to ``tf.layers.dropout``.
            Defaults to 0.5, the ``tf.layers.dropout`` default.
    """

    def __init__(self, cell1, cell2, layer_sizes, is_training=False, dropout_rate=0.5):
        super(FirstWrapper, self).__init__()
        self._cell1 = cell1
        self._cell2 = cell2
        self._layer_sizes = layer_sizes
        self._is_training = is_training
        self._dropout_rate = dropout_rate

    @property
    def state_size(self):
        """State sizes of the two wrapped cells, as ``[cell1, cell2]``."""
        return [self._cell1.state_size, self._cell2.state_size]

    @property
    def output_size(self):
        """Output size of the wrapper, i.e. the output size of ``cell2``."""
        return self._cell2.output_size

    def call(self, inputs, state):
        """Run one step: ``cell1`` -> prenet -> ``cell2``.

        Args:
            inputs: Step input passed to ``cell1``.
            state: Two-element sequence ``[cell1_state, cell2_state]``.

        Returns:
            A pair ``(output, new_state)`` where ``output`` is the output of
            ``cell2`` and ``new_state`` is ``[new_cell1_state, new_cell2_state]``.
        """
        output, state1 = self._cell1(inputs, state[0])

        x = output
        with tf.variable_scope('prenet'):
            for i, num in enumerate(self._layer_sizes, 1):
                dense = tf.layers.dense(x, num, tf.nn.relu, name='dense_%d' % (i))
                # `training` must be passed explicitly: tf.layers.dropout
                # defaults to training=False and would silently do nothing.
                x = tf.layers.dropout(
                    dense,
                    rate=self._dropout_rate,
                    training=self._is_training,
                    name='dropout_%d' % (i),
                )

        new_output, state2 = self._cell2(x, state[1])
        return new_output, [state1, state2]

    def zero_state(self, batch_size, dtype):
        """Return ``[cell1_zero_state, cell2_zero_state]`` for the given batch size and dtype."""
        return [self._cell1.zero_state(batch_size, dtype), self._cell2.zero_state(batch_size, dtype)]

In [3]:
# Build the composite cell: GRU(128) -> dense prenet [64, 32] -> GRU(16).
test_cell = FirstWrapper(
    GRUCell(128, name='cell1'),
    GRUCell(16, name='cell2'),
    [64, 32],
)

In [4]:
# Fixed-shape float32 input placeholder with static shape (32, 10, 100).
data = tf.placeholder(
    dtype=tf.float32,
    shape=[32, 10, 100],
    name='input_data',
)

In [5]:
# Unroll the composite cell over the placeholder; the initial state is
# derived from `dtype` because no explicit initial_state is given.
outputs, state = tf.nn.dynamic_rnn(
    cell=test_cell,
    inputs=data,
    dtype=tf.float32,
)

In [6]:
# Display the output tensor to check its static shape; the last dimension
# should equal the output size of the second GRU (16).
outputs

<tf.Tensor 'rnn/transpose_1:0' shape=(32, 10, 16) dtype=float32>

In [7]:
# List the trainable variables created so far to confirm that both GRU cells
# and the prenet dense layers were built under the expected scopes.
tf.trainable_variables()

[<tf.Variable 'rnn/first_wrapper/cell1/gates/kernel:0' shape=(228, 256) dtype=float32_ref>,
 <tf.Variable 'rnn/first_wrapper/cell1/gates/bias:0' shape=(256,) dtype=float32_ref>,
 <tf.Variable 'rnn/first_wrapper/cell1/candidate/kernel:0' shape=(228, 128) dtype=float32_ref>,
 <tf.Variable 'rnn/first_wrapper/cell1/candidate/bias:0' shape=(128,) dtype=float32_ref>,
 <tf.Variable 'rnn/first_wrapper/prenet/dense_1/kernel:0' shape=(128, 64) dtype=float32_ref>,
 <tf.Variable 'rnn/first_wrapper/prenet/dense_1/bias:0' shape=(64,) dtype=float32_ref>,
 <tf.Variable 'rnn/first_wrapper/prenet/dense_2/kernel:0' shape=(64, 32) dtype=float32_ref>,
 <tf.Variable 'rnn/first_wrapper/prenet/dense_2/bias:0' shape=(32,) dtype=float32_ref>,
 <tf.Variable 'rnn/first_wrapper/cell2/gates/kernel:0' shape=(48, 32) dtype=float32_ref>,
 <tf.Variable 'rnn/first_wrapper/cell2/gates/bias:0' shape=(32,) dtype=float32_ref>,
 <tf.Variable 'rnn/first_wrapper/cell2/candidate/kernel:0' shape=(48, 16) dtype=float32_ref>,
 <tf