In [1]:
import librosa
import os
import tensorflow as tf
import numpy as np
from tqdm import tqdm

In [2]:
wav_files = [f for f in os.listdir('./data') if f.endswith('.wav')]
text_files = [f for f in os.listdir('./data') if f.endswith('.txt')]

In [3]:
inputs, targets = [], []
for (wav_file, text_file) in tqdm(zip(wav_files, text_files), total = len(wav_files),ncols=80):
    path = './data/' + wav_file
    try:
        y, sr = librosa.load(path, sr = None)
    except:
        continue
    inputs.append(
        librosa.feature.mfcc(
            y = y, sr = sr, n_mfcc = 40, hop_length = int(0.05 * sr)
        ).T
    )
    with open('./data/' + text_file) as f:
        targets.append(f.read())

100%|███████████████████████████████████████| 2800/2800 [00:46<00:00, 59.81it/s]


In [4]:
inputs[0].shape

(43, 40)

In [5]:
inputs = tf.keras.preprocessing.sequence.pad_sequences(
    inputs, dtype = 'float32', padding = 'post'
)

chars = list(set([c for target in targets for c in target]))
num_classes = len(chars) + 2

idx2char = {idx + 1: char for idx, char in enumerate(chars)}
idx2char[0] = '<PAD>'
char2idx = {char: idx for idx, char in idx2char.items()}

targets = [[char2idx[c] for c in target] for target in targets]

In [6]:
num_classes

27

In [7]:
def encoder_block(inp, n_hidden, filter_size):
    inp = tf.expand_dims(inp, 2)
    inp = tf.pad(inp, [[0, 0], [(filter_size[0]-1)//2, (filter_size[0]-1)//2], [0, 0], [0, 0]])
    conv = tf.layers.conv2d(inp, n_hidden, filter_size, padding="VALID", activation=None)
    conv = tf.squeeze(conv, 2)
    return conv

def glu(x):
    return tf.multiply(x[:, :, :tf.shape(x)[2]//2], tf.sigmoid(x[:, :, tf.shape(x)[2]//2:]))

def layer(inp, conv_block, kernel_width, n_hidden, residual=None):
    z = conv_block(inp, n_hidden, (kernel_width, 1))
    return glu(z) + (residual if residual is not None else 0)

def pad_second_dim(x, desired_size):
    padding = tf.tile([[0]], tf.stack([tf.shape(x)[0], desired_size - tf.shape(x)[1]], 0))
    return tf.concat([x, padding], 1)

class Model:
    def __init__(
        self,
        num_layers,
        size_layers,
        learning_rate,
        num_features,
        dropout = 1.0,
    ):
        self.X = tf.placeholder(tf.float32, [None, None, num_features])
        self.label = tf.placeholder(tf.int32, [None, None])
        self.Y_seq_len = tf.placeholder(tf.int32, [None])
        self.Y = tf.sparse_placeholder(tf.int32)
        seq_lens = tf.count_nonzero(
            tf.reduce_sum(self.X, -1), 1, dtype = tf.int32
        )
        batch_size = tf.shape(self.X)[0]
        
        def cells(reuse = False):
            return tf.contrib.rnn.DropoutWrapper(
                tf.nn.rnn_cell.LSTMCell(
                    size_layers,
                    initializer = tf.orthogonal_initializer(),
                    reuse = reuse,
                ),
                state_keep_prob = dropout,
                output_keep_prob = dropout,
            )
        def attention(encoder_out, seq_len, reuse=False):
            attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(num_units = size_layers, 
                                                                    memory = encoder_out,
                                                                    memory_sequence_length = seq_len)
            return tf.contrib.seq2seq.AttentionWrapper(
            cell = tf.nn.rnn_cell.MultiRNNCell([cells(reuse) for _ in range(num_layers)]), 
                attention_mechanism = attention_mechanism,
                attention_layer_size = size_layers)
        
        encoder_embedded = self.X
        encoder_embedded = tf.layers.conv1d(encoder_embedded, size_layers, 1)
        e = tf.identity(encoder_embedded)
        for i in range(num_layers * 2):
            z = layer(encoder_embedded, encoder_block, 3, size_layers * 2, encoder_embedded)
            encoder_embedded = z
        
        encoder_output, output_memory = z, z + e
        print(encoder_output, output_memory)
        
        init_state = tf.reduce_mean(output_memory,axis=1)
        encoder_state = tuple(tf.nn.rnn_cell.LSTMStateTuple(c=init_state, h=init_state) for _ in range(num_layers))
        main = tf.strided_slice(self.X, [0, 0, 0], [batch_size, -1, num_features], [1, 1, 1])
        decoder_input = tf.concat([tf.fill([batch_size, 1, num_features], 0.0), main], 1)
        decoder_cell = attention(encoder_output, seq_lens)
        dense_layer = tf.layers.Dense(num_classes)
        
        training_helper = tf.contrib.seq2seq.TrainingHelper(
                inputs = decoder_input,
                sequence_length = seq_lens,
                time_major = False)
        training_decoder = tf.contrib.seq2seq.BasicDecoder(
                cell = decoder_cell,
                helper = training_helper,
                initial_state = decoder_cell.zero_state(batch_size, tf.float32).clone(cell_state=encoder_state),
                output_layer = dense_layer)
        training_decoder_output, _, _ = tf.contrib.seq2seq.dynamic_decode(
                decoder = training_decoder,
                impute_finished = True,
                maximum_iterations = tf.reduce_max(seq_lens))
        self.seq_lens = seq_lens
        
        logits = training_decoder_output.rnn_output
        time_major = tf.transpose(logits, [1, 0, 2])
        self.time_major = time_major
        decoded, log_prob = tf.nn.ctc_beam_search_decoder(time_major, seq_lens)
        decoded = tf.to_int32(decoded[0])
        self.preds = tf.sparse.to_dense(decoded)
        self.cost = tf.reduce_mean(
            tf.nn.ctc_loss(
                self.Y,
                time_major,
                seq_lens
            )
        )
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate = learning_rate
        ).minimize(self.cost)
        
        preds = self.preds[:, :tf.reduce_max(self.Y_seq_len)]
        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)
        preds = pad_second_dim(preds, tf.reduce_max(self.Y_seq_len))
        y_t = tf.cast(preds, tf.int32)
        self.prediction = tf.boolean_mask(y_t, masks)
        mask_label = tf.boolean_mask(self.label, masks)
        self.mask_label = mask_label
        correct_pred = tf.equal(self.prediction, mask_label)
        correct_index = tf.cast(correct_pred, tf.float32)
        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

In [8]:
tf.reset_default_graph()
sess = tf.InteractiveSession()

size_layers = 128
learning_rate = 1e-4
num_layers = 2
batch_size = 32
epoch = 50

model = Model(num_layers, size_layers, learning_rate, inputs.shape[2])
sess.run(tf.global_variables_initializer())

Tensor("add_3:0", shape=(?, ?, 128), dtype=float32) Tensor("add_4:0", shape=(?, ?, 128), dtype=float32)
Instructions for updating:
Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.


In [9]:
def pad_sentence_batch(sentence_batch, pad_int):
    padded_seqs = []
    seq_lens = []
    max_sentence_len = max([len(sentence) for sentence in sentence_batch])
    for sentence in sentence_batch:
        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))
        seq_lens.append(len(sentence))
    return padded_seqs, seq_lens

def sparse_tuple_from(sequences, dtype=np.int32):
    indices = []
    values = []

    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)

    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)

    return indices, values, shape

In [10]:
for e in range(epoch):
    pbar = tqdm(
        range(0, len(inputs), batch_size), desc = 'minibatch loop')
    total_cost, total_accuracy = 0, 0
    for i in pbar:
        batch_x = inputs[i : min(i + batch_size, len(inputs))]
        y = targets[i : min(i + batch_size, len(inputs))]
        batch_y = sparse_tuple_from(y)
        batch_label, batch_len = pad_sentence_batch(y, 0)
        _, cost, accuracy = sess.run(
            [model.optimizer, model.cost, model.accuracy],
            feed_dict = {model.X: batch_x, model.Y: batch_y, 
                         model.label: batch_label, model.Y_seq_len: batch_len},
        )
        total_cost += cost
        total_accuracy += accuracy
        pbar.set_postfix(cost = cost, accuracy = accuracy)
    total_cost /= (len(inputs) / batch_size)
    total_accuracy /= (len(inputs) / batch_size)
    print('epoch %d, average cost %f, average accuracy %f'%(e + 1, total_cost, total_accuracy))

minibatch loop: 100%|██████████| 88/88 [00:38<00:00,  2.55it/s, accuracy=0.0296, cost=239] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 1, average cost inf, average accuracy 0.028383


minibatch loop: 100%|██████████| 88/88 [00:38<00:00,  2.64it/s, accuracy=0.0778, cost=147]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 2, average cost 182.874638, average accuracy 0.062438


minibatch loop: 100%|██████████| 88/88 [00:39<00:00,  2.37it/s, accuracy=0.0889, cost=110]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 3, average cost 125.555172, average accuracy 0.103079


minibatch loop: 100%|██████████| 88/88 [00:42<00:00,  2.20it/s, accuracy=0.167, cost=82.8] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 4, average cost 98.401824, average accuracy 0.131852


minibatch loop: 100%|██████████| 88/88 [00:43<00:00,  2.29it/s, accuracy=0.189, cost=66.5] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 5, average cost 79.149195, average accuracy 0.149367


minibatch loop: 100%|██████████| 88/88 [00:42<00:00,  2.34it/s, accuracy=0.196, cost=61]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 6, average cost 66.437343, average accuracy 0.168419


minibatch loop: 100%|██████████| 88/88 [00:44<00:00,  2.22it/s, accuracy=0.259, cost=43.9]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 7, average cost 54.870988, average accuracy 0.197650


minibatch loop: 100%|██████████| 88/88 [00:45<00:00,  2.31it/s, accuracy=0.337, cost=36.4]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 8, average cost 45.923787, average accuracy 0.245022


minibatch loop: 100%|██████████| 88/88 [00:45<00:00,  2.19it/s, accuracy=0.448, cost=33.7]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 9, average cost 40.820036, average accuracy 0.292798


minibatch loop: 100%|██████████| 88/88 [00:46<00:00,  2.03it/s, accuracy=0.393, cost=29.6]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 10, average cost 36.860355, average accuracy 0.332306


minibatch loop: 100%|██████████| 88/88 [00:46<00:00,  2.15it/s, accuracy=0.559, cost=25.4]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 11, average cost 34.004653, average accuracy 0.387382


minibatch loop: 100%|██████████| 88/88 [00:47<00:00,  2.17it/s, accuracy=0.7, cost=25.1]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 12, average cost 30.691413, average accuracy 0.464070


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  2.04it/s, accuracy=0.696, cost=24]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 13, average cost 29.090096, average accuracy 0.509517


minibatch loop: 100%|██████████| 88/88 [00:47<00:00,  2.01it/s, accuracy=0.763, cost=23.1]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 14, average cost 27.296883, average accuracy 0.554261


minibatch loop: 100%|██████████| 88/88 [00:47<00:00,  1.98it/s, accuracy=0.726, cost=22.2]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 15, average cost 26.197914, average accuracy 0.592831


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  1.91it/s, accuracy=0.763, cost=20.5]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 16, average cost 25.091822, average accuracy 0.612823


minibatch loop: 100%|██████████| 88/88 [00:47<00:00,  1.93it/s, accuracy=0.77, cost=19.6] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 17, average cost 24.267122, average accuracy 0.639275


minibatch loop: 100%|██████████| 88/88 [00:47<00:00,  2.14it/s, accuracy=0.767, cost=19.3]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 18, average cost 23.328095, average accuracy 0.665143


minibatch loop: 100%|██████████| 88/88 [00:49<00:00,  1.86it/s, accuracy=0.781, cost=18.7]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 19, average cost 22.586759, average accuracy 0.681174


minibatch loop: 100%|██████████| 88/88 [00:47<00:00,  2.11it/s, accuracy=0.781, cost=19]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 20, average cost 21.991165, average accuracy 0.696126


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  1.95it/s, accuracy=0.767, cost=19.2]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 21, average cost 21.561563, average accuracy 0.702606


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  2.09it/s, accuracy=0.741, cost=18.4]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 22, average cost 20.938929, average accuracy 0.715464


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  2.07it/s, accuracy=0.781, cost=17.9]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 23, average cost 20.196120, average accuracy 0.726658


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  2.03it/s, accuracy=0.778, cost=17.7]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 24, average cost 19.716430, average accuracy 0.732824


minibatch loop: 100%|██████████| 88/88 [00:49<00:00,  2.03it/s, accuracy=0.774, cost=17.3]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 25, average cost 19.475602, average accuracy 0.737113


minibatch loop: 100%|██████████| 88/88 [00:49<00:00,  1.89it/s, accuracy=0.785, cost=17.7]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 26, average cost 19.194826, average accuracy 0.745778


minibatch loop: 100%|██████████| 88/88 [00:49<00:00,  1.94it/s, accuracy=0.781, cost=16.2]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 27, average cost 18.849482, average accuracy 0.751265


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  2.12it/s, accuracy=0.778, cost=16.1]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 28, average cost 18.642612, average accuracy 0.751138


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  2.10it/s, accuracy=0.781, cost=16]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 29, average cost 18.420204, average accuracy 0.753915


minibatch loop: 100%|██████████| 88/88 [00:49<00:00,  2.00it/s, accuracy=0.778, cost=15.9]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 30, average cost 18.102007, average accuracy 0.759883


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  1.90it/s, accuracy=0.785, cost=14.9]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 31, average cost 18.056934, average accuracy 0.760919


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  2.06it/s, accuracy=0.778, cost=16]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 32, average cost 17.869992, average accuracy 0.763024


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  2.07it/s, accuracy=0.774, cost=14.6]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 33, average cost 17.679259, average accuracy 0.761561


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  2.00it/s, accuracy=0.793, cost=15.3]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 34, average cost 17.424022, average accuracy 0.764196


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  2.08it/s, accuracy=0.789, cost=15.2]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 35, average cost 17.180866, average accuracy 0.764687


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  1.92it/s, accuracy=0.789, cost=15.3]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 36, average cost 17.245120, average accuracy 0.766043


minibatch loop: 100%|██████████| 88/88 [00:49<00:00,  2.07it/s, accuracy=0.796, cost=13.9]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 37, average cost 16.944607, average accuracy 0.769750


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  2.05it/s, accuracy=0.789, cost=14.7]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 38, average cost 16.648799, average accuracy 0.771317


minibatch loop: 100%|██████████| 88/88 [00:49<00:00,  1.98it/s, accuracy=0.785, cost=15]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 39, average cost 16.680096, average accuracy 0.770642


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  1.92it/s, accuracy=0.8, cost=13.9]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 40, average cost 16.401477, average accuracy 0.772111


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  2.07it/s, accuracy=0.781, cost=14.2]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 41, average cost 16.116820, average accuracy 0.773811


minibatch loop: 100%|██████████| 88/88 [00:49<00:00,  1.91it/s, accuracy=0.8, cost=13.6]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 42, average cost 16.076215, average accuracy 0.772236


minibatch loop: 100%|██████████| 88/88 [00:48<00:00,  2.12it/s, accuracy=0.793, cost=12.8]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 43, average cost 15.954416, average accuracy 0.773870


minibatch loop: 100%|██████████| 88/88 [00:40<00:00,  2.36it/s, accuracy=0.781, cost=13.1]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 44, average cost 15.959549, average accuracy 0.774471


minibatch loop: 100%|██████████| 88/88 [00:40<00:00,  2.42it/s, accuracy=0.796, cost=12.9]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 45, average cost 15.813611, average accuracy 0.775345


minibatch loop: 100%|██████████| 88/88 [00:40<00:00,  2.36it/s, accuracy=0.781, cost=12.7]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 46, average cost 15.528131, average accuracy 0.777428


minibatch loop: 100%|██████████| 88/88 [00:40<00:00,  2.38it/s, accuracy=0.793, cost=12.2]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 47, average cost 15.357984, average accuracy 0.778509


minibatch loop: 100%|██████████| 88/88 [00:40<00:00,  2.44it/s, accuracy=0.796, cost=12.2]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 48, average cost 15.313599, average accuracy 0.778225


minibatch loop: 100%|██████████| 88/88 [00:40<00:00,  2.41it/s, accuracy=0.778, cost=13.4]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 49, average cost 15.281128, average accuracy 0.777808


minibatch loop: 100%|██████████| 88/88 [00:40<00:00,  2.41it/s, accuracy=0.789, cost=13.4]

epoch 50, average cost 15.874901, average accuracy 0.773099





In [11]:
import random

random_index = random.randint(0, len(targets) - 1)
batch_x = inputs[random_index : random_index + 1]
print(
    'real:',
    ''.join(
        [idx2char[no] for no in targets[random_index : random_index + 1][0]]
    ),
)
batch_y = sparse_tuple_from(targets[random_index : random_index + 1])
pred = sess.run(model.preds, feed_dict = {model.X: batch_x})[0]
print('predicted:', ''.join([idx2char[no] for no in pred]))

real: say the word laud
predicted: say the word al
