In [1]:
from models.modules import *

In [2]:
import tensorflow as tf
from tensorflow.contrib.rnn import GRUCell, MultiRNNCell, OutputProjectionWrapper, ResidualWrapper
from tensorflow.contrib.seq2seq import BasicDecoder, BahdanauAttention, AttentionWrapper
from text.symbols import symbols
from hparams import hparams, hparams_debug_string
from models.helpers import TacoTestHelper, TacoTrainingHelper
from models.modules import encoder_cbhg, post_cbhg, prenet
from models.rnn_wrappers import DecoderPrenetWrapper, ConcatOutputAndAttentionWrapper

In [3]:
# Toy graph dimensions: a single example with 10 input symbols and 100 target frames.
batch_size, input_len, output_len = 1, 10, 100

# Build the training-mode graph (selects TacoTrainingHelper in the helper cell).
is_training = True

# `hp` is only a short alias for the shared hparams object; parse('') applies no overrides.
hp = hparams
hp.parse('')

In [4]:
# Graph inputs: integer symbol ids (looked up in the embedding table below) and their
# lengths, plus the spectrogram targets. Target feature sizes are taken from hparams
# so they always match the model's mel output (reshaped to hp.num_mels) and linear
# output (dense layer with hp.num_freq units) instead of hard-coding 80 / 1025.
inputs = tf.placeholder(tf.int32, [batch_size, input_len], 'inputs')
input_lengths = tf.placeholder(tf.int32, [batch_size], 'input_lengths')
mel_targets = tf.placeholder(tf.float32, [batch_size, output_len, hp.num_mels], 'mel_targets')
linear_targets = tf.placeholder(tf.float32, [batch_size, output_len, hp.num_freq], 'linear_targets')

In [5]:
# Symbol embedding table -> encoder pre-net -> CBHG encoder.
embedding_table = tf.get_variable(
    name='embedding',
    shape=[len(symbols), hp.embed_depth],
    dtype=tf.float32,
    initializer=tf.truncated_normal_initializer(stddev=0.5))
embedded_inputs = tf.nn.embedding_lookup(params=embedding_table, ids=inputs)

prenet_outputs = prenet(embedded_inputs, is_training, hp.prenet_depths)
encoder_outputs = encoder_cbhg(
    prenet_outputs, input_lengths, is_training, hp.encoder_depth)

Instructions for updating:
seq_dim is deprecated, use seq_axis instead
Instructions for updating:
batch_dim is deprecated, use batch_axis instead


In [6]:
# Attention RNN: a GRU whose width equals the attention depth.
g_cell = GRUCell(num_units=hp.attention_depth)

In [7]:
# Additive (Bahdanau) attention over the encoder outputs. Passing
# memory_sequence_length masks encoder steps past each example's true length,
# so attention weights are never assigned to padded positions when the fed
# input_lengths are shorter than input_len. Masking adds no trainable variables.
bd_atten = BahdanauAttention(
    hp.attention_depth,
    encoder_outputs,
    memory_sequence_length=input_lengths)

In [8]:
# alignment_history=True records every step's alignments (stacked later into
# `alignments`); output_attention=False makes the wrapper emit the GRU output
# rather than the attention vector.
attention_cell = AttentionWrapper(
    cell=g_cell,
    attention_mechanism=bd_atten,
    alignment_history=True,
    output_attention=False)

In [9]:
# Add the decoder pre-net (layer sizes hp.prenet_depths) in front of the attention cell;
# presumably applied to each step's input — TODO confirm in models.rnn_wrappers.
# NOTE(review): this rebinds `attention_cell`, so re-running this cell alone wraps
# the cell a second time; re-run from the AttentionWrapper cell instead.
attention_cell = DecoderPrenetWrapper(
    attention_cell,
    is_training,
    hp.prenet_depths)

In [10]:
# Wrap the pre-net + attention cell. Judging by the name, its output is the cell output
# concatenated with the attention context — TODO confirm in models.rnn_wrappers.
concat_cell = ConcatOutputAndAttentionWrapper(attention_cell)

In [11]:
# Decoder stack: project the concat cell's output to hp.decoder_depth, then two
# residual GRU layers of that same width (built in the same order as listed).
decoder_cell = MultiRNNCell(
    cells=[OutputProjectionWrapper(cell=concat_cell, output_size=hp.decoder_depth)]
    + [ResidualWrapper(cell=GRUCell(num_units=hp.decoder_depth)) for _ in range(2)],
    state_is_tuple=True)

In [12]:
# Final projection: hp.outputs_per_step mel frames (flattened) per decoder step,
# and an all-zeros initial state for the whole decoder stack.
output_cell = OutputProjectionWrapper(
    cell=decoder_cell,
    output_size=hp.num_mels * hp.outputs_per_step)
decoder_init_state = output_cell.zero_state(batch_size, tf.float32)

In [13]:
# Training mode hands the helper the ground-truth mel_targets (presumably teacher
# forcing); test mode only gets the batch size and the output layout.
helper = (
    TacoTrainingHelper(inputs, mel_targets, hp.num_mels, hp.outputs_per_step)
    if is_training
    else TacoTestHelper(batch_size, hp.num_mels, hp.outputs_per_step)
)

In [14]:
# Unroll the decoder for at most hp.max_iters steps. Of the decoder outputs only the
# RNN output is kept (sample ids and final sequence lengths are discarded).
(decoder_outputs, _), final_decoder_state, _ = tf.contrib.seq2seq.dynamic_decode(
    decoder=BasicDecoder(
        cell=output_cell,
        helper=helper,
        initial_state=decoder_init_state),
    maximum_iterations=hp.max_iters)

In [15]:
# Unflatten the per-step groups of hp.outputs_per_step frames into one frame per row:
# (batch, frames, hp.num_mels); the frame axis is left dynamic.
mel_outputs = tf.reshape(
    tensor=decoder_outputs,
    shape=[batch_size, -1, hp.num_mels])

In [17]:
# CBHG post-net over the predicted mel frames (input width hp.num_mels).
post_outputs = post_cbhg(
    mel_outputs,
    hp.num_mels,
    is_training,
    hp.postnet_depth)

In [18]:
# Project post-net features to the linear-frequency output size (hp.num_freq units).
linear_outputs = tf.layers.dense(inputs=post_outputs, units=hp.num_freq)

In [19]:
# Stack the alignments recorded by the AttentionWrapper (held in the state of the
# first MultiRNNCell layer) and move the decoder-step axis from first to last.
alignments = tf.transpose(
    a=final_decoder_state[0].alignment_history.stack(),
    perm=[1, 2, 0])

In [25]:
# Display the mel target placeholder so its static shape can be compared with mel_outputs.
mel_targets

<tf.Tensor 'mel_targets:0' shape=(1, 100, 80) dtype=float32>

In [26]:
# Check that predicted mel frames line up with the targets on every statically known
# axis (batch and mel channels); the frame axis of mel_outputs is dynamic (None).
assert mel_outputs.shape.as_list()[0] == mel_targets.shape.as_list()[0], "batch mismatch"
assert mel_outputs.shape.as_list()[-1] == mel_targets.shape.as_list()[-1], "mel channel mismatch"
mel_outputs

<tf.Tensor 'Reshape:0' shape=(1, ?, 80) dtype=float32>