In [1]:
import tensorflow as tf
from tensorflow.python.layers import core as layers_core
from tensorflow.python.layers import normalization as layers_norm
import numpy as np
import time
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.contrib.seq2seq import AttentionWrapperState
from tensorflow.contrib.seq2seq import AttentionMechanism
from tensorflow.python.ops import rnn_cell_impl
_zero_state_tensors = rnn_cell_impl._zero_state_tensors
from tensorflow.python.util import nest
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import variable_scope
from tensorflow.contrib.seq2seq import BahdanauAttention
from tensorflow.python.ops import nn_ops
import collections, math
from tensorflow.python.ops import init_ops

In [2]:
def _compute_attention(attention_mechanism, cell_output, previous_alignments,
                       attention_layer, attention_dropout_layer, training):
    """Computes the attention and alignments for a given attention_mechanism."""
    alignments = attention_mechanism(cell_output, previous_alignments=previous_alignments)

    expanded_alignments = array_ops.expand_dims(alignments, 1)
    context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
    context = array_ops.squeeze(context, [1])

    if attention_layer is not None:
        #attention = attention_layer(array_ops.concat([cell_output, context], 1))
        attention = attention_dropout_layer(attention_layer(array_ops.concat([cell_output, context], 1)), training=training)
    else:
        attention = context

    return attention, alignments    
    
class MyAttentionWrapper(tf.contrib.seq2seq.AttentionWrapper):
    
    def __init__(self, cell, attention_mechanism, keep_prob, training, attention_layer_size=None, alignment_history=False, cell_input_fn=None,
                output_attention=True, initial_cell_state=None, name=None):
        super(MyAttentionWrapper, self).__init__(cell, attention_mechanism, attention_layer_size, alignment_history, cell_input_fn,
                output_attention, initial_cell_state, name)
        
        self.keep_prob = keep_prob
        self.training = training
        
        super(tf.contrib.seq2seq.AttentionWrapper, self).__init__(name=name)
        if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
            raise TypeError("cell must be an RNNCell, saw type: %s" % type(cell).__name__)
        if isinstance(attention_mechanism, (list, tuple)):
            self._is_multi = True
            attention_mechanisms = attention_mechanism
            for attention_mechanism in attention_mechanisms:
                if not isinstance(attention_mechanism, AttentionMechanism):
                    raise TypeError("attention_mechanism must contain only instances of "
                                    "AttentionMechanism, saw type: %s" % type(attention_mechanism).__name__)
        else:
            self._is_multi = False
            if not isinstance(attention_mechanism, AttentionMechanism):
                raise TypeError("attention_mechanism must be an AttentionMechanism or list of "
                                "multiple AttentionMechanism instances, saw type: %s" % type(attention_mechanism).__name__)
            attention_mechanisms = (attention_mechanism,)

        if cell_input_fn is None:
            cell_input_fn = (lambda inputs, attention: array_ops.concat([inputs, attention], -1))
        else:
            if not callable(cell_input_fn):
                raise TypeError("cell_input_fn must be callable, saw type: %s" % type(cell_input_fn).__name__)

        if attention_layer_size is not None:
            attention_layer_sizes = tuple(attention_layer_size if isinstance(attention_layer_size, (list, tuple))
                                              else (attention_layer_size,))
            if len(attention_layer_sizes) != len(attention_mechanisms):
                raise ValueError("If provided, attention_layer_size must contain exactly one "
                                "integer per attention_mechanism, saw: %d vs %d" % (len(attention_layer_sizes), len(attention_mechanisms)))
            self._attention_layers = tuple(layers_core.Dense(attention_layer_size, name="attention_layer", use_bias=False, activation=tf.tanh)
                                                  for attention_layer_size in attention_layer_sizes)
            self._attention_dropout_layers = tuple(layers_core.Dropout(rate=1-self.keep_prob, name="attention_dropout_layer")
                                                  for attention_layer_size in attention_layer_sizes)
            self._attention_layer_size = sum(attention_layer_sizes)
        else:
            self._attention_layers = None
            self._attention_dropout_layers = None
            self._attention_layer_size = sum(attention_mechanism.values.get_shape()[-1].value
                                                      for attention_mechanism in attention_mechanisms)
            
        self._cell = cell
        self._attention_mechanisms = attention_mechanisms
        self._cell_input_fn = cell_input_fn
        self._output_attention = output_attention
        self._alignment_history = alignment_history
        with ops.name_scope(name, "AttentionWrapperInit"):
            if initial_cell_state is None:
                self._initial_cell_state = None
            else:
                final_state_tensor = nest.flatten(initial_cell_state)[-1]
                state_batch_size = (final_state_tensor.shape[0].value or array_ops.shape(final_state_tensor)[0])
                error_message = (
                "When constructing AttentionWrapper %s: " % self._base_name +
                "Non-matching batch sizes between the memory "
                "(encoder output) and initial_cell_state.  Are you using "
                "the BeamSearchDecoder?  You may need to tile your initial state "
                "via the tf.contrib.seq2seq.tile_batch function with argument "
                "multiple=beam_width.")
                with ops.control_dependencies(
                    self._batch_size_checks(state_batch_size, error_message)):
                    self._initial_cell_state = nest.map_structure(lambda s: array_ops.identity(s, name="check_initial_cell_state"),
                                                              initial_cell_state)
        
    
    @property
    def state_size(self):
        return AttentionWrapperState(
            cell_state=self._cell.state_size,
            time=tensor_shape.TensorShape([]),
            attention=self._attention_layer_size,
            alignments=self._item_or_tuple(
                a.alignments_size for a in self._attention_mechanisms),
            alignment_history=self._item_or_tuple(
                () for _ in self._attention_mechanisms)) # sometimes a TensorArray
    
    def zero_state(self, batch_size, dtype):
        with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
            if self._initial_cell_state is not None:
                cell_state = self._initial_cell_state
            else:
                cell_state = self._cell.zero_state(batch_size, dtype)
            error_message = (
                "When calling zero_state of AttentionWrapper %s: " % self._base_name +
                "Non-matching batch sizes between the memory "
                "(encoder output) and the requested batch size.  Are you using "
                "the BeamSearchDecoder?  If so, make sure your encoder output has "
                "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and "
                "the batch_size= argument passed to zero_state is "
                "batch_size * beam_width.")
            with ops.control_dependencies(self._batch_size_checks(batch_size, error_message)):
                cell_state = nest.map_structure(
                  lambda s: array_ops.identity(s, name="checked_cell_state"),
                  cell_state)                
            
            return AttentionWrapperState(
                cell_state=cell_state,
                time=array_ops.zeros([], dtype=dtypes.int32),
                attention=_zero_state_tensors(self._attention_layer_size, batch_size,
                                        dtype),
                alignments=self._item_or_tuple(
                    attention_mechanism.initial_alignments(batch_size, dtype)
                    for attention_mechanism in self._attention_mechanisms),
                alignment_history=self._item_or_tuple(
                    tensor_array_ops.TensorArray(dtype=dtype, size=0,
                                                 dynamic_size=True)
                    if self._alignment_history else ()
                    for _ in self._attention_mechanisms))
    
    def call(self, inputs, state):

        if not isinstance(state, AttentionWrapperState):
            raise TypeError("Expected state to be instance of MyAttentionWrapperState. "
                          "Received type %s instead."  % type(state))
        
        # Step 1: Calculate the true inputs to the cell based on the
        # previous attention value.
        cell_inputs = self._cell_input_fn(inputs, state.attention)
        cell_state = state.cell_state
        cell_output, next_cell_state = self._cell(cell_inputs, cell_state)

        cell_batch_size = (
            cell_output.shape[0].value or array_ops.shape(cell_output)[0])
        error_message = (
            "When applying AttentionWrapper %s: " % self.name +
            "Non-matching batch sizes between the memory "
            "(encoder output) and the query (decoder output).  Are you using "
            "the BeamSearchDecoder?  You may need to tile your memory input via "
            "the tf.contrib.seq2seq.tile_batch function with argument "
            "multiple=beam_width.")
        with ops.control_dependencies(
            self._batch_size_checks(cell_batch_size, error_message)):
            cell_output = array_ops.identity(
              cell_output, name="checked_cell_output")

        if self._is_multi:
            previous_alignments = state.alignments
            previous_alignment_history = state.alignment_history
        else:
            previous_alignments = [state.alignments]
            previous_alignment_history = [state.alignment_history]

        all_alignments = []
        all_attentions = []
        all_histories = []
        for i, attention_mechanism in enumerate(self._attention_mechanisms):
            attention, alignments = _compute_attention(
            attention_mechanism, cell_output, previous_alignments[i],
            self._attention_layers[i] if self._attention_layers else None,
            self._attention_dropout_layers[i] if self._attention_dropout_layers else None,
            self.training)
            alignment_history = previous_alignment_history[i].write(
            state.time, alignments) if self._alignment_history else ()

            all_alignments.append(alignments)
            all_histories.append(alignment_history)
            all_attentions.append(attention)

        attention = array_ops.concat(all_attentions, 1)
        
        next_state = AttentionWrapperState(
            time=state.time + 1,
            cell_state=next_cell_state,
            attention=attention,
            alignments=self._item_or_tuple(all_alignments),
            alignment_history=self._item_or_tuple(all_histories))
        
        if self._output_attention:
            return attention, next_state
        else:
            return cell_output, next_state

In [3]:
phn_61 = ['aa', 'ae', 'ah', 'ao', 'aw', 'ax', 'ax-h', 'axr', 'ay', 'b', 'bcl', 'ch', 'd', 'dcl', 'dh', 'dx', 'eh', 'el', 'em', 'en', 'eng', 'epi', 'er', 'ey', 'f', 'g', 'gcl', 'h#', 'hh', 'hv', 'ih', 'ix', 'iy', 'jh', 'k', 'kcl', 'l', 'm', 'n', 'ng', 'nx', 'ow', 'oy', 'p', 'pau', 'pcl', 'q', 'r', 's', 'sh', 't', 'tcl', 'th', 'uh', 'uw', 'ux', 'v', 'w', 'y', 'z', 'zh']
phn_39 = ['ae', 'ao', 'aw', 'ax', 'ay', 'b', 'ch', 'd', 'dh', 'dx', 'eh', 'er', 'ey', 'f', 'g', 'h#', 'hh', 'ix', 'iy', 'jh', 'k', 'l', 'm', 'n', 'ng', 'ow', 'oy', 'p', 'r', 's', 't', 'th', 'uh', 'uw', 'v', 'w', 'y', 'z', 'zh']
mapping = {'ah': 'ax', 'ax-h': 'ax', 'ux': 'uw', 'aa': 'ao', 'ih': 'ix', 'axr': 'er', 'el': 'l', 'em': 'm', 'en': 'n', 'nx': 'n', 'eng': 'ng', 'sh': 'zh', 'hv': 'hh', 'bcl': 'h#', 'pcl': 'h#', 'dcl': 'h#', 'tcl': 'h#', 'gcl': 'h#', 'kcl': 'h#', 'q': 'h#', 'epi': 'h#', 'pau': 'h#'}

TRAIN_FILE = './data/fbank/train.tfrecords'
DEV_FILE = './data/fbank/dev.tfrecords'
TEST_FILE = './data/fbank/test.tfrecords'
checkpoints_path = './model/seq2seq_debug/ckpt'
ft_checkpoints_path = './model/seq2seq_debug/finetunning/ckpt'
#final_checkpoints_path = './model/seq2seq_debug/final/ckpt'

feat_type = 'fbank'
feats_dim = 39 if feat_type=='mfcc' else 123
labels_sos_id = len(phn_61)
labels_eos_id = len(phn_61) + 1
num_classes = len(phn_61) + 2

num_unit_encoder = 256
num_unit_decoder = 256
learning_rate = 0.001
n_hidden_layer = 3
beam_width = 10
batch_size = 32
epochs = 100

In [4]:
class Model(object):
    def __init__(self, batch_size, num_unit_encoder, num_unit_decoder, n_hidden_layer, feats_dim, num_classes, labels_sos_id, labels_eos_id,
                 learning_rate=0.001, optimizer=None, beam_width=10, phn_61=None, phn_39=None, mapping=None, file_type=None, model_type=None):
        iterator = self._get_iterator(batch_size, feats_dim, labels_sos_id, labels_eos_id, file_type, model_type)
        self.iterator_initializer = iterator.initializer
        
        batched_data = iterator.get_next()
        features = batched_data[0]
        labels_in = batched_data[1]
        labels_out = batched_data[2]
        feats_seq_len = tf.to_int32(batched_data[3])
        labels_in_seq_len = tf.to_int32(batched_data[4])
        
        decoded = self._compute_dynamic_decode(features, labels_in, feats_seq_len, labels_in_seq_len, num_unit_encoder,
                                               num_unit_decoder, n_hidden_layer, num_classes, beam_width, model_type)
        self.saver = tf.train.Saver() # save graph until decoded
        
        self.optimizer = optimizer
        self.model_type = model_type
        
        if model_type == 'train':
            logits = decoded
            self.loss = self._compute_loss(labels_out, labels_in_seq_len, logits)
            self.update_step = self._get_update_step(self.loss, learning_rate, optimizer)
        elif model_type == 'eval':
            logits = decoded
            self.loss = self._compute_loss(labels_out, labels_in_seq_len, logits)
        else: # model_type == 'infer'
            predicted_ids = decoded
            self.per = self._compute_per(predicted_ids, labels_out, labels_eos_id, phn_61, phn_39, mapping)
            
    def _get_sparse_tensor(self, dense, default):
        indices = tf.to_int64(tf.where(tf.not_equal(dense, default)))
        vals = tf.to_int32(tf.gather_nd(dense, indices))
        shape = tf.to_int64(tf.shape(dense))
        return tf.SparseTensor(indices, vals, shape)
        
    def _get_iterator(self, batch_size, feats_dim, labels_sos_id, labels_eos_id, file_type, model_type):
        dataset = tf.contrib.data.TFRecordDataset(file_type)
        context_features = {'feats_seq_len': tf.FixedLenFeature([], dtype=tf.int64),
                           'labels_seq_len': tf.FixedLenFeature([], dtype=tf.int64)}
        sequence_features = {'features': tf.FixedLenSequenceFeature([feats_dim], dtype=tf.float32),
                            'labels': tf.FixedLenSequenceFeature([], dtype=tf.int64)}
        dataset = dataset.map(lambda serialized_example: tf.parse_single_sequence_example(serialized_example,
                                                                                         context_features=context_features,
                                                                                         sequence_features=sequence_features))
        dataset = dataset.map(lambda context, sequence: (sequence['features'], sequence['labels'], 
                                                         context['feats_seq_len'], context['labels_seq_len']))
        dataset = dataset.map(lambda features, labels, feats_seq_len, labels_seq_len: (features, 
                                                        tf.concat(([labels_sos_id], labels),0),
                                                        tf.concat((labels, [labels_eos_id]), 0),
                                                                feats_seq_len, labels_seq_len))
        dataset = dataset.map(lambda features, labels_in, labels_out, feats_seq_len, labels_seq_len: 
                                          (features, labels_in, labels_out, feats_seq_len, tf.size(labels_in, out_type=tf.int64)))
        def batching_func(x):
            return x.padded_batch(batch_size,
                                 padded_shapes=(tf.TensorShape([None, feats_dim]),
                                               tf.TensorShape([None]),
                                               tf.TensorShape([None]),
                                               tf.TensorShape([]),
                                               tf.TensorShape([])),
                                 padding_values=(tf.cast(0, tf.float32),
                                                tf.cast(labels_eos_id, tf.int64),
                                                tf.cast(labels_eos_id, tf.int64),
                                                tf.cast(0, tf.int64),
                                                tf.cast(0, tf.int64)))
        def key_func(features, labels_in, labels_out, feats_seq_len, labels_in_seq_len):
            f0 = lambda: tf.constant(0, tf.int64)
            f1 = lambda: tf.constant(1, tf.int64)
            f2 = lambda: tf.constant(2, tf.int64)
            f3 = lambda: tf.constant(3, tf.int64)
            f4 = lambda: tf.constant(4, tf.int64)
            f5 = lambda: tf.constant(5, tf.int64)
            f6 = lambda: tf.constant(6, tf.int64)
            
            return tf.case([(tf.less_equal(feats_seq_len, 200), f0),
                   (tf.less_equal(feats_seq_len, 250), f1),
                   (tf.less_equal(feats_seq_len, 300), f2),
                   (tf.less_equal(feats_seq_len, 350), f3),
                   (tf.less_equal(feats_seq_len, 400), f4),
                   (tf.less_equal(feats_seq_len, 500), f5)], default=f6)
        
        def reduce_func(bucket_id, windowed_data):
            return batching_func(windowed_data)
        
        if model_type=='train':
            dataset = dataset.shuffle(10000)
            batched_dataset = dataset.group_by_window(key_func=key_func, reduce_func=reduce_func, window_size=batch_size)
            batched_dataset = batched_dataset.shuffle(10000)
        else:
            batched_dataset = batching_func(dataset)
            
        return batched_dataset.make_initializable_iterator()
    
    def _compute_encoder_outputs(self, features, feats_seq_len, num_unit_encoder, n_hidden_layer):
        self.keep_prob = tf.placeholder(tf.float32)
        self.training = tf.placeholder(tf.bool)
        
        def residual_block(inp, out_channels):
            inp_channels = inp.get_shape().as_list()[3]
            
            out = tf.layers.conv2d(inp, filters=out_channels, kernel_size=(3,3), strides=(1,1), padding='same')
            out = tf.layers.batch_normalization(out, training=self.training) 
            out = tf.nn.relu(out)
            out = tf.layers.dropout(out, rate=1-self.keep_prob, training=self.training)
            out = tf.layers.conv2d(out, filters=out_channels, kernel_size=(3,3), strides=(1,1), padding='same')
            out = tf.layers.batch_normalization(out, training=self.training) 
            out = tf.nn.relu(out)
            out = tf.layers.dropout(out, rate=1-self.keep_prob, training=self.training)
        
            if inp_channels != out_channels:
                inp = tf.layers.conv2d(inp, filters=out_channels, kernel_size=(1,1), strides=(1,1), padding='same')

            return out + inp
        
        def xception_block(inp, out_channels):
            inp_channels = inp.get_shape().as_list()[3]
            
            out = tf.layers.separable_conv2d(inp, filters=out_channels, kernel_size=(3,3), strides=(1,1), padding='same')
            out = tf.layers.batch_normalization(out, training=self.training) 
            out = tf.nn.relu(out)
            out = tf.layers.dropout(out, rate=1-self.keep_prob, training=self.training)
            out = tf.layers.separable_conv2d(out, filters=out_channels, kernel_size=(3,3), strides=(1,1), padding='same')
            out = tf.layers.batch_normalization(out, training=self.training) 
            out = tf.nn.relu(out)
            out = tf.layers.dropout(out, rate=1-self.keep_prob, training=self.training)
        
            if inp_channels != out_channels:
                inp = tf.layers.conv2d(inp, filters=out_channels, kernel_size=(1,1), strides=(1,1), padding='same')

            return out + inp
        
        features = tf.stack(tf.split(features, num_or_size_splits=3, axis=2), axis=3)
        features = tf.transpose(features, [0,2,1,3]) # shape = [batch, feats_dim/3, max_time, channels]
        
        conv = tf.layers.conv2d(features, filters=128, kernel_size=(3,3), strides=(3,3), padding='same') # 41 -> 14, time -> time/3
        conv = tf.layers.batch_normalization(conv, training=self.training) 
        conv = tf.nn.relu(conv)
        conv = tf.layers.dropout(conv, rate=1-self.keep_prob, training=self.training)
        res1 = residual_block(conv,128)
        res2 = residual_block(res1, 128)
        res3 = residual_block(res2, 256)
        res4 = residual_block(res3, 256)
        #res5 = residual_block(res4, 256)
        
        conv = tf.layers.conv2d(res4, filters=512, kernel_size=(14,1), strides=(14,1), padding='same')
        conv = tf.layers.batch_normalization(conv, training=self.training) 
        conv = tf.nn.relu(conv)
        conv = tf.layers.dropout(conv, rate=1-self.keep_prob, training=self.training)
        flattend = tf.transpose(conv, [0,2,1,3])
        flattend = tf.reshape(flattend, [tf.shape(flattend)[0], tf.shape(flattend)[1], 512])
        
        #fc = tf.layers.dense(flattend, 512)
        #fc = tf.layers.batch_normalization(fc, training=self.training) 
        #fc = tf.nn.relu(fc)
        #fc = tf.layers.dropout(fc, rate=1-self.keep_prob, training=self.training)
        
        #cell_fw = tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.LSTMCell(num_unit_encoder, use_peepholes=True), output_keep_prob=self.keep_prob)
        #cell_bw = tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.LSTMCell(num_unit_encoder, use_peepholes=True), output_keep_prob=self.keep_prob)
        #outputs, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, flattend, sequence_length=feats_seq_len//3, dtype=tf.float32)
        #outputs = tf.concat(outputs, 2)
        #cells_fw = [tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.LSTMCell(num_unit_encoder, use_peepholes=True,
        #                                        initializer=tf.random_uniform_initializer(minval=-0.01, maxval=0.01)), output_keep_prob=self.keep_prob) for _ in range(3)]
        #cells_bw = [tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.LSTMCell(num_unit_encoder, use_peepholes=True,
        #                                        initializer=tf.random_uniform_initializer(minval=-0.01, maxval=0.01)), output_keep_prob=self.keep_prob) for _ in range(3)]
        cells_fw = [tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.LSTMCell(num_unit_encoder, use_peepholes=True),
                                                  output_keep_prob=self.keep_prob) for _ in range(3)]
        cells_bw = [tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.LSTMCell(num_unit_encoder, use_peepholes=True),
                                                  output_keep_prob=self.keep_prob) for _ in range(3)]
        
        #outputs, _, _ = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(cells_fw, cells_bw, flattend,
        #                                                    dtype=tf.float32, sequence_length=feats_seq_len//3)
        
        prev_layer = flattend
        prev_seq_len = feats_seq_len//3
        for i, (cell_fw, cell_bw) in enumerate(zip(cells_fw, cells_bw)):
            with tf.variable_scope("cell_%d" % i):
                if i == 2:
                    prev_layer = prev_layer[:,::2,:]
                    prev_seq_len = prev_seq_len // 2
                    
                outputs, (state_fw, state_bw) = tf.nn.bidirectional_dynamic_rnn(cell_fw, cell_bw, prev_layer,
                                                                               sequence_length=prev_seq_len, dtype=tf.float32)
                prev_layer = tf.concat(outputs, 2)
                
        
        return prev_layer, prev_seq_len
        
        #return outputs, feats_seq_len//3
    
    def _get_decoder_cell_and_init_state(self, mem_seq_len, num_unit_decoder, n_hidden_layer, memory, beam_width, model_type):
        
        batch_size = tf.shape(memory)[0]
        #decoder_cell = tf.nn.rnn_cell.MultiRNNCell([tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.LSTMCell(num_unit_decoder,
        #                                use_peepholes=True), output_keep_prob=self.keep_prob) for _ in range(3)])
        
        decoder_cell = tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.LSTMCell(num_unit_decoder, use_peepholes=True),
                                                     output_keep_prob=self.keep_prob)

        if model_type == 'infer':
            memory = tf.contrib.seq2seq.tile_batch(memory, multiplier=beam_width)
            mem_seq_len = tf.contrib.seq2seq.tile_batch(mem_seq_len, multiplier=beam_width)
            batch_size = batch_size * beam_width

        attention_mechanism = tf.contrib.seq2seq.LuongAttention(num_unit_decoder, memory, mem_seq_len, scale=False)
        #attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_unit_decoder, memory, mem_seq_len)
        #decoder_cell = tf.contrib.seq2seq.AttentionWrapper(decoder_cell, attention_mechanism,
        #                                            attention_layer_size=num_unit_decoder, output_attention=True)
        decoder_cell = MyAttentionWrapper(decoder_cell, attention_mechanism, self.keep_prob, self.training,
                                          attention_layer_size=num_unit_decoder, output_attention=True)
        decoder_initial_state = decoder_cell.zero_state(batch_size, tf.float32)
        
        return(decoder_cell, decoder_initial_state)
    
    def _compute_dynamic_decode(self, features, labels_in, feats_seq_len, labels_in_seq_len, num_unit_encoder, num_unit_decoder,
                        n_hidden_layer, num_classes, beam_width, model_type):
        embedding_decoder = tf.Variable(np.identity(num_classes), dtype=tf.float32, trainable=False)
        #embedding_decoder = tf.Variable(tf.random_uniform([num_classes, 30], minval=-0.1, maxval=0.1, dtype=tf.float32))
        decoder_emb_inp = tf.nn.embedding_lookup(embedding_decoder, labels_in)
    
        memory, mem_seq_len = self._compute_encoder_outputs(features, feats_seq_len, num_unit_encoder, n_hidden_layer)
        with tf.variable_scope('decoder') as decoder_scope:
            decoder_cell, decoder_initial_state = self._get_decoder_cell_and_init_state(mem_seq_len, num_unit_decoder, n_hidden_layer,
                                                                                        memory, beam_width, model_type)
        
            output_layer = tf.contrib.keras.layers.Dense(num_classes, use_bias=False)

            if model_type in ['train', 'eval']:
                helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, labels_in_seq_len, time_major=False)
                decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper, initial_state=decoder_initial_state)
                outputs, state, _ = tf.contrib.seq2seq.dynamic_decode(decoder, output_time_major=False, scope=decoder_scope)
                return output_layer(outputs.rnn_output)
                
            else: # model_type == 'infer'
                beam_decoder = tf.contrib.seq2seq.BeamSearchDecoder(decoder_cell, embedding_decoder, 
                                                    tf.fill([tf.shape(features)[0]], labels_sos_id), labels_eos_id,
                                                            decoder_initial_state, beam_width, output_layer=output_layer)
                decoded, state, final_seq_len = tf.contrib.seq2seq.dynamic_decode(beam_decoder, maximum_iterations=100,
                                                                              output_time_major=False, scope=decoder_scope)
                return decoded.predicted_ids # shape = [batch, max_time, beam_width]
    
    def _compute_loss(self, labels_out, labels_in_seq_len, logits):
        max_time = tf.shape(labels_out)[1]
        target_weights = tf.sequence_mask(labels_in_seq_len, max_time, dtype=logits.dtype)
        crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=labels_out, logits=logits)
        return tf.reduce_sum(crossent * target_weights) / tf.to_float(tf.shape(logits)[0])
    
    def _get_update_step(self, loss, learning_rate, optimizer):
        update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
        with tf.control_dependencies(update_ops):
            params = tf.trainable_variables()
            gradients = tf.gradients(loss, params)
            clipped_gradients = gradients
            #clipped_gradients, _ = tf.clip_by_global_norm(gradients, 1.0)
            if optimizer == 'adam':
                opt = tf.train.AdamOptimizer(learning_rate)
            elif optimizer == 'sgd':
                opt = tf.train.GradientDescentOptimizer(learning_rate)
            elif optimizer =='momentum':
                opt = tf.train.MomentumOptimizer(learning_rate, momentum=0.9, use_nesterov=True)
            update_step = opt.apply_gradients(zip(clipped_gradients, params))
        return update_step
    
    def _compute_per(self, predicted_ids, labels_out, labels_eos_id, phn_61, phn_39, mapping):
        
        phn_61_tensor = tf.constant(phn_61, dtype=tf.string)
        phn_39_tensor = tf.constant(phn_39, dtype=tf.string)
        mapping_table = tf.contrib.lookup.HashTable(tf.contrib.lookup.KeyValueTensorInitializer(list(mapping.keys()), list(mapping.values())), default_value='')
        self.mapping_table_init = mapping_table.init
        
        def map_to_reduced_phn(p):
            val = mapping_table.lookup(phn_61_tensor[p])
            f1 = lambda: tf.to_int32(tf.reduce_min(tf.where(tf.equal(val, phn_39_tensor))))
            f2 = lambda: tf.to_int32(tf.reduce_min(tf.where(tf.equal(phn_61_tensor[p], phn_39_tensor))))
            return tf.cond(tf.not_equal(val, ''), f1, f2)
        
        indices = tf.to_int64(tf.where(tf.logical_and(tf.not_equal(predicted_ids[:,:,0], -1), tf.not_equal(predicted_ids[:,:,0], labels_eos_id))))
        vals = tf.to_int32(tf.gather_nd(predicted_ids[:,:,0], indices))
        shape = tf.to_int64(tf.shape(predicted_ids[:,:,0]))
        decoded_sparse = tf.SparseTensor(indices, vals, shape)
        labels_out_sparse = self._get_sparse_tensor(labels_out, labels_eos_id)
        
        decoded_reduced = tf.SparseTensor(decoded_sparse.indices, tf.map_fn(map_to_reduced_phn, decoded_sparse.values), decoded_sparse.dense_shape)
        labels_out_reduced = tf.SparseTensor(labels_out_sparse.indices, tf.map_fn(map_to_reduced_phn, labels_out_sparse.values), labels_out_sparse.dense_shape)
        
        return tf.reduce_sum(tf.edit_distance(decoded_reduced, labels_out_reduced, normalize=False)) , tf.to_float(tf.size(labels_out_reduced.values))

def optimistic_restore(session, save_file):
    reader = tf.train.NewCheckpointReader(save_file)
    saved_shapes = reader.get_variable_to_shape_map()
    var_names = sorted([(var.name, var.name.split(':')[0]) for var in tf.global_variables()
            if var.name.split(':')[0] in saved_shapes])
    restore_vars = []
    name2var = dict(zip(map(lambda x:x.name.split(':')[0], tf.global_variables()), tf.global_variables()))
    with tf.variable_scope('', reuse=True):
        for var_name, saved_var_name in var_names:
            curr_var = name2var[saved_var_name]
            var_shape = curr_var.get_shape().as_list()
            if var_shape == saved_shapes[saved_var_name]:
                restore_vars.append(curr_var)
    saver = tf.train.Saver(restore_vars)
    saver.restore(session, save_file)

In [None]:
train_graph = tf.Graph()
dev_graph = tf.Graph()
test_graph = tf.Graph()

with train_graph.as_default():
    train_model = Model(batch_size, num_unit_encoder, num_unit_decoder, n_hidden_layer, feats_dim, num_classes, labels_sos_id, labels_eos_id,
                 learning_rate, optimizer='adam', file_type=TRAIN_FILE, model_type='train')
    initializer = tf.global_variables_initializer()

with dev_graph.as_default():
    dev_model = Model(batch_size, num_unit_encoder, num_unit_decoder, n_hidden_layer, feats_dim, num_classes, labels_sos_id, labels_eos_id,
                 file_type=DEV_FILE, model_type='eval')
with test_graph.as_default():
    test_model = Model(batch_size, num_unit_encoder, num_unit_decoder, n_hidden_layer, feats_dim, num_classes, labels_sos_id, labels_eos_id,
                 beam_width=10, phn_61=phn_61, phn_39=phn_39, mapping=mapping, file_type=TEST_FILE, model_type='infer')
    
train_sess = tf.Session(graph=train_graph)
train_sess.run(initializer)

dev_sess = tf.Session(graph=dev_graph)

test_sess = tf.Session(graph=test_graph)
test_sess.run(test_model.mapping_table_init)

#with train_graph.as_default():
#    optimistic_restore(train_sess, './model/seq2seq_debug/ckpt-19')
#train_model.saver.restore(train_sess, './model/seq2seq_debug/ckpt-32')

for epoch in range(epochs):
#for epoch in range(25,100):
    train_sess.run(train_model.iterator_initializer)
    train_loss = []
    start = time.time()
    while True:
        try:
            _, cost = train_sess.run([train_model.update_step, train_model.loss], 
                                     feed_dict={train_model.keep_prob: 0.6, train_model.training: True})
            #train_model.update_step.run(session=train_sess, feed_dict={train_model.keep_prob: 1.0})
            train_loss.append(cost)
        except tf.errors.OutOfRangeError:
            end = time.time()
            log = "Epoch {}/{}, \ntrain_loss={:.3f}, time={:.0f}s"
            print(log.format(epoch+1, epochs, np.mean(train_loss), end-start))
            checkpoint_path = train_model.saver.save(train_sess, checkpoints_path, global_step=epoch+1)
            
            dev_model.saver.restore(dev_sess, checkpoint_path)
            dev_sess.run(dev_model.iterator_initializer)
            dev_loss = []
            start = time.time()
            while True:
                try:
                    cost = dev_sess.run(dev_model.loss, feed_dict={dev_model.keep_prob: 1.0, dev_model.training: False})
                    dev_loss.append(cost)
                except tf.errors.OutOfRangeError:
                    end = time.time()
                    log = "\tdev_loss={:.3f}, time={:.0f}s"
                    print(log.format(np.mean(dev_loss), end-start))
                    break
            
            if epoch < 20:
                break
            test_model.saver.restore(test_sess, checkpoint_path)
            test_sess.run(test_model.iterator_initializer)
            test_per = []
            test_seq_len = []
            start = time.time()
            while True:
                try:
                    _per, seq_len = test_sess.run(test_model.per, feed_dict={test_model.keep_prob: 1.0, test_model.training: False})
                    test_per.append(_per)
                    test_seq_len.append(seq_len)
                except tf.errors.OutOfRangeError:
                    end = time.time()
                    log = '\t\ttest_per={:.3f}, time={:.0f}s'
                    print(log.format(sum(test_per)/sum(test_seq_len), end-start))
                    break
            
            break
            
train_sess.close()
dev_sess.close()
test_sess.close()

In [5]:
restored_ckpt_path = './model/seq2seq_debug/ckpt-46'
ft_learning_rate = 0.0001
ft_batch_size = 32
ft_epochs = 100

In [6]:
ft_train_graph = tf.Graph()
ft_dev_graph = tf.Graph()
ft_test_graph = tf.Graph()

with ft_train_graph.as_default():
    ft_train_model = Model(ft_batch_size, num_unit_encoder, num_unit_decoder, n_hidden_layer, feats_dim, num_classes, labels_sos_id,
                           labels_eos_id, ft_learning_rate, optimizer='adam', file_type=TRAIN_FILE, model_type='train')
    initializer = tf.global_variables_initializer()
    
with ft_dev_graph.as_default():
    ft_dev_model = Model(ft_batch_size, num_unit_encoder, num_unit_decoder, n_hidden_layer, feats_dim, num_classes, labels_sos_id, 
                         labels_eos_id, file_type=DEV_FILE, model_type='eval')

with ft_test_graph.as_default():
    ft_test_model = Model(ft_batch_size, num_unit_encoder, num_unit_decoder, n_hidden_layer, feats_dim, num_classes, labels_sos_id,
                          labels_eos_id, beam_width=10, phn_61=phn_61, phn_39=phn_39, mapping=mapping, file_type=TEST_FILE, model_type='infer')    
    

ft_train_sess = tf.Session(graph=ft_train_graph)
ft_train_sess.run(initializer)
ft_dev_sess = tf.Session(graph=ft_dev_graph)
ft_test_sess = tf.Session(graph=ft_test_graph)
ft_test_sess.run(ft_test_model.mapping_table_init)

ft_train_model.saver.restore(ft_train_sess, restored_ckpt_path)

for epoch in range(ft_epochs):
    ft_train_sess.run(ft_train_model.iterator_initializer)
    ft_train_loss  = []
    start = time.time()
    while True:
        try:
            '''
            for v in tf.trainable_variables():
                if 'cell' in v.name and 'kernel' in v.name:
                    v += tf.random_normal(shape=v.shape, stddev=0.1)
                    #v += tf.random_normal(shape=v.shape, stddev=0.075)
            '''
            _, cost = ft_train_sess.run([ft_train_model.update_step, ft_train_model.loss], feed_dict={ft_train_model.keep_prob: 0.6,
                                                                                                     ft_train_model.training: True})
            ft_train_loss.append(cost)
        except tf.errors.OutOfRangeError:
            end = time.time()
            log = "Epoch {}/{}: \ntrain_loss={:.3f}, time = {:.0f}s"
            print(log.format(epoch+1, ft_epochs, np.mean(ft_train_loss), end-start))
            ft_checkpoint_path = ft_train_model.saver.save(ft_train_sess, ft_checkpoints_path, global_step=epoch+1)
            
            ft_dev_model.saver.restore(ft_dev_sess, ft_checkpoint_path)
            ft_dev_sess.run(ft_dev_model.iterator_initializer)
            ft_dev_loss = []
            start = time.time()
            while True:
                try:
                    cost = ft_dev_sess.run(ft_dev_model.loss, feed_dict={ft_dev_model.keep_prob: 1.0, ft_dev_model.training: False})
                    ft_dev_loss.append(cost)
                except tf.errors.OutOfRangeError:
                    end = time.time()
                    log = "\tdev_loss={:.3f}, time = {:.0f}s"
                    print(log.format(np.mean(ft_dev_loss), end-start))
                    break
                
            ft_test_model.saver.restore(ft_test_sess, ft_checkpoint_path)
            ft_test_sess.run(ft_test_model.iterator_initializer)
            ft_test_per = []
            ft_test_seq_len = []
            start = time.time()
            while True:
                try:
                    _per, seq_len = ft_test_sess.run(ft_test_model.per, feed_dict={ft_test_model.keep_prob: 1.0, ft_test_model.training: False})
                    ft_test_per.append(_per)
                    ft_test_seq_len.append(seq_len)
                except tf.errors.OutOfRangeError:
                    end = time.time()
                    log = '\t\ttest_per={:.3f}, time={:.0f}s'
                    print(log.format(sum(ft_test_per)/sum(ft_test_seq_len), end-start))
                    break
            # go to netxt epoch
            break
        
        
ft_train_sess.close()
ft_dev_sess.close()
ft_test_sess.close()

INFO:tensorflow:Restoring parameters from ./model/seq2seq_debug/ckpt-46
Epoch 1/100: 
train_loss=18.329, time = 133s
INFO:tensorflow:Restoring parameters from ./model/seq2seq_debug/finetunning/ckpt-1
	dev_loss=23.792, time = 9s
INFO:tensorflow:Restoring parameters from ./model/seq2seq_debug/finetunning/ckpt-1
		test_per=0.171, time=12s
Epoch 2/100: 
train_loss=17.779, time = 124s
INFO:tensorflow:Restoring parameters from ./model/seq2seq_debug/finetunning/ckpt-2
	dev_loss=23.831, time = 7s
INFO:tensorflow:Restoring parameters from ./model/seq2seq_debug/finetunning/ckpt-2
		test_per=0.169, time=11s
Epoch 3/100: 
train_loss=17.436, time = 124s
INFO:tensorflow:Restoring parameters from ./model/seq2seq_debug/finetunning/ckpt-3
	dev_loss=24.088, time = 7s
INFO:tensorflow:Restoring parameters from ./model/seq2seq_debug/finetunning/ckpt-3
		test_per=0.172, time=11s
Epoch 4/100: 
train_loss=17.189, time = 123s
INFO:tensorflow:Restoring parameters from ./model/seq2seq_debug/finetunning/ckpt-4
	d

KeyboardInterrupt: 