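* GRUCell, 1-encoder-layer: attention-based seq2seq phone recognizer — bidirectional GRU encoder (1 layer), GRU decoder with Luong attention, beam-search decoding at test time; test PER is computed after folding the 61 phones to the 39-phone set.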

In [1]:
import tensorflow as tf
from tensorflow.python.layers import core as layers_core
import numpy as np
import os, time

In [2]:
# Full 61-phone label inventory (label id == position in this list).
phn_61 = ('aa ae ah ao aw ax ax-h axr ay b bcl ch d dcl dh dx eh el em en eng epi er ey '
          'f g gcl h# hh hv ih ix iy jh k kcl l m n ng nx ow oy p pau pcl q r s sh t tcl '
          'th uh uw ux v w y z zh').split()
# Reduced 39-phone set used when scoring phone error rate.
phn_39 = ('ae ao aw ax ay b ch d dh dx eh er ey f g h# hh ix iy jh k l m n ng ow oy p r s '
          't th uh uw v w y z zh').split()
# Folding of phn_61 -> phn_39; phones absent from this dict map to themselves.
mapping = {'ah': 'ax', 'ax-h': 'ax', 'ux': 'uw', 'aa': 'ao', 'ih': 'ix', 'axr': 'er',
           'el': 'l', 'em': 'm', 'en': 'n', 'nx': 'n', 'eng': 'ng', 'sh': 'zh', 'hv': 'hh'}
# Closures, 'q', 'epi' and 'pau' all collapse into the 'h#' symbol.
mapping.update(dict.fromkeys(['bcl', 'pcl', 'dcl', 'tcl', 'gcl', 'kcl', 'q', 'epi', 'pau'], 'h#'))

# Data and checkpoint locations (checkpoints_path is a file prefix, not a directory).
TRAIN_FILE = '../data/fbank/train.tfrecords'
DEV_FILE = '../data/fbank/dev.tfrecords'
TEST_FILE = '../data/fbank/test.tfrecords'
checkpoints_path = '../model/seq2seq_basic/ckpt'

# Per-frame feature width depends on the front end.
feat_type = 'fbank'
feats_dim = {'mfcc': 39}.get(feat_type, 123)

# Two extra decoder symbols are appended after the 61 phones.
labels_sos_id = len(phn_61)
labels_eos_id = labels_sos_id + 1
num_classes = labels_eos_id + 1

# Model / optimisation hyper-parameters.
num_encoder = 256
num_decoder = 256
n_encoder_layer = 1
learning_rate = 0.001
beam_width = 10
batch_size = 32
epochs = 100

In [3]:
class Model(object):
    """Attention-based seq2seq phone recognizer built into the current default graph.

    Constructing an instance wires up, for one TFRecord file:
      * an initializable input iterator (length-bucketed and shuffled for 'train'),
      * a stacked bidirectional GRU encoder (n_encoder_layer layers),
      * a GRU decoder with Luong attention (teacher forcing for 'train'/'eval',
        beam search for 'infer'),
      * the mode-specific outputs listed below, plus a `saver`.

    Args:
        data_path: TFRecord file of serialized SequenceExamples with context
            features 'feats_seq_len', 'labels_seq_len' and sequence features
            'features' (feats_dim floats per frame) and 'labels' (int64 ids).
        model_type: 'train' -> logits, loss, update_step;
                    'eval'  -> logits, loss;
                    'infer' -> predicted_ids, per, mapping_table_init.

    Raises:
        Exception: if model_type is not 'train', 'eval' or 'infer'.
    """

    _MODEL_TYPES = ('train', 'eval', 'infer')
    # Inclusive upper bounds on feats_seq_len for the training length buckets;
    # utterances longer than the last bound fall into one extra, final bucket.
    _BUCKET_BOUNDARIES = (200, 250, 300, 350, 400, 500)
    # Set to a float to clip gradients by global norm; None disables clipping.
    max_gradient_norm = None

    def __init__(self, data_path, model_type):
        # Validate first so a typo fails fast instead of after building a
        # (beam-search) graph for an unknown mode.
        if model_type not in self._MODEL_TYPES:
            raise Exception('invalid model_type')
        self.data_path = data_path
        self.model_type = model_type

        self._get_data()
        self._build()

    def _get_data(self):
        """Create the input iterator and expose one batch's tensors as attributes."""
        iterator = self._get_iterator()
        # Must be run before every pass over the data.
        self.iterator_initializer = iterator.initializer

        batched_data = iterator.get_next()
        self.features = batched_data[0]       # [batch, max_frames, feats_dim], zero padded
        self.labels_in = batched_data[1]      # <sos> + labels, padded with <eos>
        self.labels_out = batched_data[2]     # labels + <eos>, padded with <eos>
        self.feats_seq_len = tf.to_int32(batched_data[3])
        self.labels_in_seq_len = tf.to_int32(batched_data[4])

    def _build(self):
        """Build the decoder, the outputs required by self.model_type, and the saver."""
        # Dropout keep probability for every GRU output; must be fed on each run.
        self.keep_prob = tf.placeholder(tf.float32)
        decoded = self._compute_dynamic_decode()

        if self.model_type in ('train', 'eval'):
            self.logits = decoded
            self.loss = self._compute_loss()
            if self.model_type == 'train':
                self.update_step = self._get_update_step()
        elif self.model_type == 'infer':
            self.predicted_ids = decoded
            # (summed unnormalized edit distance, summed reference length)
            self.per = self._compute_per()
        else:
            raise Exception('invalid model_type')

        self.saver = tf.train.Saver(max_to_keep=10)

    def _get_iterator(self):
        """Build the tf.data pipeline and return an initializable iterator.

        Each element is (features, labels_in, labels_out, feats_seq_len,
        labels_in_seq_len). Training data is shuffled, grouped into
        length buckets of up to batch_size utterances, and the batches are
        shuffled again; eval/infer data is batched in file order.
        """
        context_features = {'feats_seq_len': tf.FixedLenFeature([], dtype=tf.int64),
                            'labels_seq_len': tf.FixedLenFeature([], dtype=tf.int64)}
        sequence_features = {'features': tf.FixedLenSequenceFeature([feats_dim], dtype=tf.float32),
                             'labels': tf.FixedLenSequenceFeature([], dtype=tf.int64)}

        def parse_example(serialized_example):
            # Decode one SequenceExample into the 5-tuple consumed downstream.
            context, sequence = tf.parse_single_sequence_example(
                serialized_example,
                context_features=context_features,
                sequence_features=sequence_features)
            labels = sequence['labels']
            labels_in = tf.concat(([labels_sos_id], labels), 0)    # decoder input: <sos> + labels
            labels_out = tf.concat((labels, [labels_eos_id]), 0)   # decoder target: labels + <eos>
            # The stored 'labels_seq_len' is not used; the length of labels_in
            # (equal to the length of labels_out) is recomputed instead.
            return (sequence['features'], labels_in, labels_out,
                    context['feats_seq_len'], tf.size(labels_in, out_type=tf.int64))

        def batching_func(dataset):
            # Pad features with 0.0, both label sequences with <eos>.
            return dataset.padded_batch(
                batch_size,
                padded_shapes=(tf.TensorShape([None, feats_dim]),
                               tf.TensorShape([None]),
                               tf.TensorShape([None]),
                               tf.TensorShape([]),
                               tf.TensorShape([])),
                padding_values=(tf.cast(0, tf.float32),
                                tf.cast(labels_eos_id, tf.int64),
                                tf.cast(labels_eos_id, tf.int64),
                                tf.cast(0, tf.int64),
                                tf.cast(0, tf.int64)))

        def key_func(features, labels_in, labels_out, feats_seq_len, labels_in_seq_len):
            # Bucket id = number of boundaries that feats_seq_len exceeds:
            # 0 for <= 200 frames, 1 for (200, 250], ..., 6 for > 500.
            boundaries = tf.constant(self._BUCKET_BOUNDARIES, dtype=tf.int64)
            return tf.reduce_sum(tf.to_int64(tf.greater(feats_seq_len, boundaries)))

        def reduce_func(bucket_id, windowed_data):
            return batching_func(windowed_data)

        dataset = tf.contrib.data.TFRecordDataset(self.data_path).map(parse_example)
        if self.model_type == 'train':
            dataset = dataset.shuffle(10000)
            batched_dataset = dataset.group_by_window(key_func=key_func, reduce_func=reduce_func,
                                                      window_size=batch_size)
            batched_dataset = batched_dataset.shuffle(10000)
        else:  # eval or infer: keep file order, no shuffling
            batched_dataset = batching_func(dataset)

        return batched_dataset.make_initializable_iterator()

    def _compute_encoder_outputs(self):
        """Run the stacked bidirectional GRU encoder over self.features.

        Returns:
            (outputs, seq_len): per-frame encoder outputs with forward and
            backward activations concatenated on the last axis, and the
            (unchanged) per-utterance feature lengths.
        """
        def make_cell():
            return tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.GRUCell(num_encoder),
                                                 output_keep_prob=self.keep_prob)

        cells_fw = [make_cell() for _ in range(n_encoder_layer)]
        cells_bw = [make_cell() for _ in range(n_encoder_layer)]
        # Final forward/backward states are not needed: the decoder starts
        # from a zero state and reads the encoder only through attention.
        outputs, _, _ = tf.contrib.rnn.stack_bidirectional_dynamic_rnn(
            cells_fw, cells_bw, inputs=self.features,
            sequence_length=self.feats_seq_len, dtype=tf.float32)
        return outputs, self.feats_seq_len

    def _get_decoder_cell_and_init_state(self):
        """Build the attention-wrapped GRU decoder cell and its all-zero initial state.

        In 'infer' mode self.memory and self.mem_seq_len are replaced by
        beam_width-tiled copies so that every beam attends over the
        encoder outputs of its own utterance.

        Returns:
            (decoder_cell, decoder_initial_state)
        """
        num_batch = tf.shape(self.memory)[0]  # taken before any tiling
        decoder_cell = tf.nn.rnn_cell.DropoutWrapper(tf.nn.rnn_cell.GRUCell(num_decoder),
                                                     output_keep_prob=self.keep_prob)

        if self.model_type == 'infer':
            self.memory = tf.contrib.seq2seq.tile_batch(self.memory, multiplier=beam_width)
            self.mem_seq_len = tf.contrib.seq2seq.tile_batch(self.mem_seq_len, multiplier=beam_width)
            num_batch = num_batch * beam_width

        attention_mechanism = tf.contrib.seq2seq.LuongAttention(
            num_decoder, self.memory, self.mem_seq_len, scale=False)
        decoder_cell = tf.contrib.seq2seq.AttentionWrapper(
            decoder_cell, attention_mechanism, attention_layer_size=num_decoder)
        decoder_initial_state = decoder_cell.zero_state(num_batch, tf.float32)
        return decoder_cell, decoder_initial_state

    def _compute_dynamic_decode(self):
        """Encode self.features and decode them.

        Returns:
            'train'/'eval': logits of shape [batch, max_label_len, num_classes]
                from teacher-forced decoding on self.labels_in.
            'infer': beam-search ids of shape [batch, max_time, beam_width],
                decoding for at most 100 steps.
        """
        # Frozen identity matrix => decoder inputs are one-hot label vectors.
        embedding_decoder = tf.Variable(np.identity(num_classes), dtype=tf.float32, trainable=False)
        decoder_emb_inp = tf.nn.embedding_lookup(embedding_decoder, self.labels_in)

        self.memory, self.mem_seq_len = self._compute_encoder_outputs()
        with tf.variable_scope('decoder') as decoder_scope:
            decoder_cell, decoder_initial_state = self._get_decoder_cell_and_init_state()

            output_layer = layers_core.Dense(num_classes, use_bias=False)

            if self.model_type in ['train', 'eval']:
                helper = tf.contrib.seq2seq.TrainingHelper(decoder_emb_inp, self.labels_in_seq_len,
                                                           time_major=False)
                decoder = tf.contrib.seq2seq.BasicDecoder(decoder_cell, helper,
                                                          initial_state=decoder_initial_state)
                outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(decoder, output_time_major=False,
                                                                  scope=decoder_scope)
                return output_layer(outputs.rnn_output)

            # infer
            start_tokens = tf.fill([tf.shape(self.features)[0]], labels_sos_id)
            beam_decoder = tf.contrib.seq2seq.BeamSearchDecoder(
                decoder_cell, embedding_decoder, start_tokens, labels_eos_id,
                decoder_initial_state, beam_width, output_layer=output_layer)
            decoded, _, _ = tf.contrib.seq2seq.dynamic_decode(
                beam_decoder, maximum_iterations=100, output_time_major=False, scope=decoder_scope)
            return decoded.predicted_ids  # shape = [batch, max_time, beam_width]

    def _compute_loss(self):
        """Masked cross-entropy summed over tokens and averaged over the batch.

        Positions beyond each utterance's labels_in_seq_len (equal to the
        length of its labels_out) are masked out.
        """
        max_time = tf.shape(self.labels_out)[1]
        target_weights = tf.sequence_mask(self.labels_in_seq_len, max_time, dtype=self.logits.dtype)
        crossent = tf.nn.sparse_softmax_cross_entropy_with_logits(labels=self.labels_out,
                                                                  logits=self.logits)
        return tf.reduce_sum(crossent * target_weights) / tf.to_float(tf.shape(self.logits)[0])

    def _get_update_step(self):
        """Return the Adam update op; gradients are clipped only if max_gradient_norm is set."""
        params = tf.trainable_variables()
        gradients = tf.gradients(self.loss, params)
        if self.max_gradient_norm is not None:
            gradients, _ = tf.clip_by_global_norm(gradients, self.max_gradient_norm)
        self.optimizer = tf.train.AdamOptimizer(learning_rate)
        return self.optimizer.apply_gradients(zip(gradients, params))

    def _get_sparse_tensor(self, dense, default):
        """Convert a dense id matrix to an int32 SparseTensor, dropping filler values.

        Args:
            dense: integer tensor of ids.
            default: a single value, or a list/tuple of values, whose
                positions are left out of the sparse tensor.
        """
        ignore = tuple(default) if isinstance(default, (list, tuple)) else (default,)
        keep = tf.not_equal(dense, ignore[0])
        for value in ignore[1:]:
            keep = tf.logical_and(keep, tf.not_equal(dense, value))
        indices = tf.to_int64(tf.where(keep))
        vals = tf.to_int32(tf.gather_nd(dense, indices))
        shape = tf.to_int64(tf.shape(dense))
        return tf.SparseTensor(indices, vals, shape)

    def _compute_per(self):
        '''
        Return tuple: (unnormalized edit distance, sequence length),
        each summed over the batch.

        Only beam 0 of self.predicted_ids is scored. Hypotheses and
        references are folded from phn_61 to phn_39 before the edit
        distance is taken; dividing the totals accumulated over a dataset
        gives the phone error rate.
        '''
        # Static 61 -> 39 folding resolved once in Python: folded[i] is the
        # phn_39 index of phn_61[i] after applying `mapping`.
        folded = tf.constant([phn_39.index(mapping.get(p, p)) for p in phn_61], dtype=tf.int32)
        # Kept so callers that run it before inference keep working; the
        # constant folding table needs no initialization.
        self.mapping_table_init = tf.no_op(name='mapping_table_init')

        # Drop -1 (presumably beam-search padding), <eos> and <sos>: none of
        # them is a phone, and ids >= len(phn_61) would index past the table.
        decoded_sparse = self._get_sparse_tensor(self.predicted_ids[:, :, 0],
                                                 (-1, labels_eos_id, labels_sos_id))
        labels_out_sparse = self._get_sparse_tensor(self.labels_out, labels_eos_id)

        decoded_reduced = tf.SparseTensor(decoded_sparse.indices,
                                          tf.gather(folded, decoded_sparse.values),
                                          decoded_sparse.dense_shape)
        labels_out_reduced = tf.SparseTensor(labels_out_sparse.indices,
                                             tf.gather(folded, labels_out_sparse.values),
                                             labels_out_sparse.dense_shape)

        edit_distance = tf.edit_distance(decoded_reduced, labels_out_reduced, normalize=False)
        return tf.reduce_sum(edit_distance), tf.to_float(tf.size(labels_out_reduced.values))

In [4]:
# One graph per role; the dev/test graphs receive the trained weights via checkpoints.
train_graph = tf.Graph()
dev_graph = tf.Graph()
test_graph = tf.Graph()

with train_graph.as_default():
    train_model = Model(data_path=TRAIN_FILE, model_type='train')
    initializer = tf.global_variables_initializer()
with dev_graph.as_default():
    dev_model = Model(data_path=DEV_FILE, model_type='eval')
with test_graph.as_default():
    test_model = Model(data_path=TEST_FILE, model_type='infer')

# Beam-search PER on the test set is slow, so it is skipped for (0-based)
# epoch indices below this value.
FIRST_PER_EPOCH = 18
# Optional checkpoint to restore training weights from, e.g. '../model/seq2seq_basic/ckpt-23'.
RESTORE_CKPT = None


def run_until_exhausted(sess, fetches, feed_dict):
    """Run `fetches` in `sess` until the input iterator raises OutOfRangeError.

    Returns:
        (results, elapsed_seconds): one `sess.run` result per batch, and wall time.
    """
    results = []
    start = time.time()
    while True:
        try:
            results.append(sess.run(fetches, feed_dict=feed_dict))
        except tf.errors.OutOfRangeError:
            return results, time.time() - start


# saver.save() treats checkpoints_path as a file prefix ('<dir>/ckpt-<step>.*'),
# so it is the parent directory that has to exist.
ckpt_dir = os.path.dirname(checkpoints_path)
if ckpt_dir and not os.path.isdir(ckpt_dir):
    os.makedirs(ckpt_dir)

train_sess = tf.Session(graph=train_graph)
dev_sess = tf.Session(graph=dev_graph)
test_sess = tf.Session(graph=test_graph)
try:
    train_sess.run(initializer)
    if RESTORE_CKPT:
        train_model.saver.restore(train_sess, RESTORE_CKPT)
    test_sess.run(test_model.mapping_table_init)

    for epoch in range(epochs):
        # --- train one epoch ---
        train_sess.run(train_model.iterator_initializer)
        train_results, train_time = run_until_exhausted(
            train_sess, [train_model.update_step, train_model.loss],
            {train_model.keep_prob: 0.6})
        train_loss = [loss for _, loss in train_results]
        log = "Epoch {}/{}, \ntrain_loss={:.3f}, time={:.0f}s"
        print(log.format(epoch+1, epochs, np.mean(train_loss), train_time))
        saved_ckpt_path = train_model.saver.save(train_sess, checkpoints_path, global_step=epoch+1)

        # --- dev loss with the freshly saved weights ---
        dev_model.saver.restore(dev_sess, saved_ckpt_path)
        dev_sess.run(dev_model.iterator_initializer)
        dev_loss, dev_time = run_until_exhausted(dev_sess, dev_model.loss,
                                                 {dev_model.keep_prob: 1.0})
        log = "\tdev_loss={:.3f}, time={:.0f}s"
        print(log.format(np.mean(dev_loss), dev_time))

        if epoch < FIRST_PER_EPOCH:  # skip test per calculation to save time
            continue

        # --- test phone error rate (beam search) ---
        test_model.saver.restore(test_sess, saved_ckpt_path)
        test_sess.run(test_model.iterator_initializer)
        test_results, test_time = run_until_exhausted(test_sess, test_model.per,
                                                      {test_model.keep_prob: 1.0})
        total_edit_dist = sum(dist for dist, _ in test_results)
        total_len = sum(length for _, length in test_results)
        test_per = total_edit_dist / total_len if total_len else float('nan')
        log = '\t\ttest_per={:.3f}, time={:.0f}s'
        print(log.format(test_per, test_time))
finally:
    # Release the sessions even when training is interrupted (e.g. KeyboardInterrupt).
    train_sess.close()
    dev_sess.close()
    test_sess.close()

Epoch 1/100, 
train_loss=144.029, time=91s
INFO:tensorflow:Restoring parameters from ../model/seq2seq_basic/ckpt-1
	dev_loss=119.542, time=6s
Epoch 2/100, 
train_loss=114.122, time=92s
INFO:tensorflow:Restoring parameters from ../model/seq2seq_basic/ckpt-2
	dev_loss=105.913, time=6s
Epoch 3/100, 
train_loss=104.014, time=92s
INFO:tensorflow:Restoring parameters from ../model/seq2seq_basic/ckpt-3
	dev_loss=92.095, time=6s
Epoch 4/100, 
train_loss=92.753, time=92s
INFO:tensorflow:Restoring parameters from ../model/seq2seq_basic/ckpt-4
	dev_loss=77.014, time=6s
Epoch 5/100, 
train_loss=80.691, time=93s
INFO:tensorflow:Restoring parameters from ../model/seq2seq_basic/ckpt-5
	dev_loss=63.020, time=6s
Epoch 6/100, 
train_loss=71.102, time=91s
INFO:tensorflow:Restoring parameters from ../model/seq2seq_basic/ckpt-6
	dev_loss=56.260, time=6s
Epoch 7/100, 
train_loss=66.692, time=92s
INFO:tensorflow:Restoring parameters from ../model/seq2seq_basic/ckpt-7
	dev_loss=63.845, time=6s
Epoch 8/100, 
t

KeyboardInterrupt: 