In [1]:
import librosa
import os
import tensorflow as tf
import numpy as np
from tqdm import tqdm

In [2]:
wav_files = [f for f in os.listdir('./data') if f.endswith('.wav')]
text_files = [f for f in os.listdir('./data') if f.endswith('.txt')]

In [3]:
inputs, targets = [], []
for (wav_file, text_file) in tqdm(zip(wav_files, text_files), total = len(wav_files),ncols=80):
    path = './data/' + wav_file
    try:
        y, sr = librosa.load(path, sr = None)
    except:
        continue
    inputs.append(
        librosa.feature.mfcc(
            y = y, sr = sr, n_mfcc = 40, hop_length = int(0.05 * sr)
        ).T
    )
    with open('./data/' + text_file) as f:
        targets.append(f.read())


100%|██████████████████████████████████████| 2800/2800 [00:12<00:00, 223.24it/s]


In [4]:
inputs = tf.keras.preprocessing.sequence.pad_sequences(
    inputs, dtype = 'float32', padding = 'post'
)

chars = list(set([c for target in targets for c in target]))
num_classes = len(chars) + 2

idx2char = {idx + 1: char for idx, char in enumerate(chars)}
idx2char[0] = '<PAD>'
char2idx = {char: idx for idx, char in idx2char.items()}

targets = [[char2idx[c] for c in target] for target in targets]

In [5]:
def pad_sentence_batch(sentence_batch, pad_int):
    padded_seqs = []
    seq_lens = []
    max_sentence_len = max([len(sentence) for sentence in sentence_batch])
    for sentence in sentence_batch:
        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))
        seq_lens.append(len(sentence))
    return padded_seqs, seq_lens

def sparse_tuple_from(sequences, dtype=np.int32):
    indices = []
    values = []

    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)

    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)

    return indices, values, shape

In [6]:
def attention(inputs, attention_size):
    hidden_size = inputs.shape[2].value
    w_omega = tf.Variable(
        tf.random_normal([hidden_size, attention_size], stddev = 0.1)
    )
    b_omega = tf.Variable(tf.random_normal([attention_size], stddev = 0.1))
    u_omega = tf.Variable(tf.random_normal([attention_size], stddev = 0.1))
    with tf.name_scope('v'):
        v = tf.tanh(tf.tensordot(inputs, w_omega, axes = 1) + b_omega)
    vu = tf.tensordot(v, u_omega, axes = 1, name = 'vu')
    alphas = tf.nn.softmax(vu, name = 'alphas')
    output = inputs * tf.expand_dims(alphas, -1)
    return output, alphas

def pad_second_dim(x, desired_size):
    padding = tf.tile([[0]], tf.stack([tf.shape(x)[0], desired_size - tf.shape(x)[1]], 0))
    return tf.concat([x, padding], 1)

class Model:
    def __init__(
        self,
        num_layers,
        size_layers,
        learning_rate,
        num_features,
        dropout = 1.0,
    ):
        self.X = tf.placeholder(tf.float32, [None, None, num_features])
        self.Y = tf.sparse_placeholder(tf.int32)
        seq_lens = tf.count_nonzero(
            tf.reduce_sum(self.X, -1), 1, dtype = tf.int32
        )
        self.label = tf.placeholder(tf.int32, [None, None])
        self.Y_seq_len = tf.placeholder(tf.int32, [None])

        def cells(size, reuse = False):
            return tf.contrib.rnn.DropoutWrapper(
                tf.nn.rnn_cell.LSTMCell(
                    size,
                    initializer = tf.orthogonal_initializer(),
                    reuse = reuse,
                ),
                state_keep_prob = dropout,
                output_keep_prob = dropout,
            )

        features = self.X
        for n in range(num_layers):
            (out_fw, out_bw), (
                state_fw,
                state_bw,
            ) = tf.nn.bidirectional_dynamic_rnn(
                cell_fw = cells(size_layers),
                cell_bw = cells(size_layers),
                inputs = features,
                sequence_length = seq_lens,
                dtype = tf.float32,
                scope = 'bidirectional_rnn_%d' % (n),
            )
            features = tf.concat((out_fw, out_bw), 2)
        
        features, _ = attention(features, size_layers)
        logits = tf.layers.dense(features, num_classes)
        time_major = tf.transpose(logits, [1, 0, 2])
        decoded, log_prob = tf.nn.ctc_beam_search_decoder(time_major, seq_lens)
        decoded = tf.to_int32(decoded[0])
        self.preds = tf.sparse.to_dense(decoded)
        self.cost = tf.reduce_mean(
            tf.nn.ctc_loss(
                self.Y,
                time_major,
                seq_lens,
                ignore_longer_outputs_than_inputs = True,
            )
        )
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate = learning_rate
        ).minimize(self.cost)
        
        preds = self.preds[:, :tf.reduce_max(self.Y_seq_len)]
        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)
        preds = pad_second_dim(preds, tf.reduce_max(self.Y_seq_len))
        y_t = tf.cast(preds, tf.int32)
        self.prediction = tf.boolean_mask(y_t, masks)
        mask_label = tf.boolean_mask(self.label, masks)
        self.mask_label = mask_label
        correct_pred = tf.equal(self.prediction, mask_label)
        correct_index = tf.cast(correct_pred, tf.float32)
        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

In [7]:
tf.reset_default_graph()
sess = tf.InteractiveSession()

size_layers = 128
learning_rate = 1e-3
num_layers = 2
batch_size = 32
epoch = 50

model = Model(num_layers, size_layers, learning_rate, inputs.shape[2])
sess.run(tf.global_variables_initializer())

Instructions for updating:
Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.


In [8]:
for e in range(epoch):
    pbar = tqdm(
        range(0, len(inputs), batch_size), desc = 'minibatch loop')
    total_cost, total_accuracy = 0, 0
    for i in pbar:
        batch_x = inputs[i : min(i + batch_size, len(inputs))]
        y = targets[i : min(i + batch_size, len(inputs))]
        batch_y = sparse_tuple_from(y)
        batch_label, batch_len = pad_sentence_batch(y, 0)
        _, cost, accuracy = sess.run(
            [model.optimizer, model.cost, model.accuracy],
            feed_dict = {model.X: batch_x, model.Y: batch_y, 
                         model.label: batch_label, model.Y_seq_len: batch_len},
        )
        total_cost += cost
        total_accuracy += accuracy
        pbar.set_postfix(cost = cost, accuracy = accuracy)
    total_cost /= (len(inputs) / batch_size)
    total_accuracy /= (len(inputs) / batch_size)
    print('epoch %d, average cost %f, average accuracy %f'%(e + 1, total_cost, total_accuracy))

minibatch loop: 100%|██████████| 88/88 [00:21<00:00,  4.95it/s, accuracy=0.113, cost=71.3] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 1, average cost 84.087347, average accuracy 0.128065


minibatch loop: 100%|██████████| 88/88 [00:18<00:00,  5.52it/s, accuracy=0.0657, cost=57.3]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 2, average cost 64.124234, average accuracy 0.076903


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.61it/s, accuracy=0, cost=51.5]      
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 3, average cost 55.271279, average accuracy 0.018553


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.71it/s, accuracy=0, cost=48.9]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 4, average cost 51.333191, average accuracy 0.000000


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.64it/s, accuracy=0, cost=47.4]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 5, average cost 49.166376, average accuracy 0.000000


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.64it/s, accuracy=0, cost=46.5]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 6, average cost 47.793606, average accuracy 0.000000


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.59it/s, accuracy=0, cost=45.7]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 7, average cost 46.821985, average accuracy 0.000000


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.63it/s, accuracy=0, cost=45.2]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 8, average cost 46.077298, average accuracy 0.000000


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.57it/s, accuracy=0, cost=44.8]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 9, average cost 45.482462, average accuracy 0.000000


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.55it/s, accuracy=0, cost=44.4]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 10, average cost 44.985425, average accuracy 0.000000


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.55it/s, accuracy=0, cost=44]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 11, average cost 44.547218, average accuracy 0.000000


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.53it/s, accuracy=0, cost=43.7]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 12, average cost 44.142802, average accuracy 0.000000


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.77it/s, accuracy=0.0438, cost=42.5] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 13, average cost 43.500052, average accuracy 0.020594


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.60it/s, accuracy=0.0292, cost=41.8]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 14, average cost 42.603858, average accuracy 0.043385


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.62it/s, accuracy=0.0328, cost=41.1]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 15, average cost 41.920044, average accuracy 0.029245


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.62it/s, accuracy=0.0292, cost=40.4] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 16, average cost 41.259851, average accuracy 0.028186


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.66it/s, accuracy=0.0401, cost=39.8]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 17, average cost 40.573918, average accuracy 0.034699


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.78it/s, accuracy=0.0657, cost=39.1]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s, accuracy=0.0494, cost=39.9]

epoch 18, average cost 39.834039, average accuracy 0.051539


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.70it/s, accuracy=0.109, cost=38.3] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 19, average cost 39.098471, average accuracy 0.084964


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.72it/s, accuracy=0.0912, cost=37.7]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 20, average cost 38.377175, average accuracy 0.095664


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.86it/s, accuracy=0.113, cost=37]   
minibatch loop:   1%|          | 1/88 [00:00<00:17,  5.03it/s, accuracy=0.0987, cost=37.7]

epoch 21, average cost 37.766140, average accuracy 0.099332


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.89it/s, accuracy=0.113, cost=36.2] 
minibatch loop:   1%|          | 1/88 [00:00<00:17,  5.04it/s, accuracy=0.0987, cost=37]

epoch 22, average cost 37.012584, average accuracy 0.109027


minibatch loop: 100%|██████████| 88/88 [00:16<00:00,  5.86it/s, accuracy=0.117, cost=35.5] 
minibatch loop:   1%|          | 1/88 [00:00<00:17,  5.05it/s, accuracy=0.112, cost=36.4]

epoch 23, average cost 36.350182, average accuracy 0.110087


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.58it/s, accuracy=0.106, cost=35]   
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 24, average cost 35.740776, average accuracy 0.109507


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.57it/s, accuracy=0.113, cost=34.2] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 25, average cost 35.091595, average accuracy 0.111049


minibatch loop: 100%|██████████| 88/88 [00:17<00:00,  5.53it/s, accuracy=0.106, cost=33.5] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 26, average cost 34.281996, average accuracy 0.109363


minibatch loop: 100%|██████████| 88/88 [00:18<00:00,  5.26it/s, accuracy=0.00365, cost=32.9]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 27, average cost 33.556701, average accuracy 0.055271


minibatch loop: 100%|██████████| 88/88 [00:19<00:00,  5.17it/s, accuracy=0.0073, cost=32.2] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 28, average cost 32.868566, average accuracy 0.007412


minibatch loop: 100%|██████████| 88/88 [00:19<00:00,  5.09it/s, accuracy=0.0219, cost=31.7] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 29, average cost 32.228317, average accuracy 0.012191


minibatch loop: 100%|██████████| 88/88 [00:19<00:00,  5.08it/s, accuracy=0.0146, cost=31.1] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 30, average cost 31.603533, average accuracy 0.016900


minibatch loop: 100%|██████████| 88/88 [00:20<00:00,  4.91it/s, accuracy=0.0109, cost=30.6] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 31, average cost 31.031603, average accuracy 0.012558


minibatch loop: 100%|██████████| 88/88 [00:20<00:00,  4.86it/s, accuracy=0.00365, cost=30]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 32, average cost 30.460658, average accuracy 0.008797


minibatch loop: 100%|██████████| 88/88 [00:19<00:00,  5.16it/s, accuracy=0.0109, cost=29.4] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 33, average cost 29.887438, average accuracy 0.009084


minibatch loop: 100%|██████████| 88/88 [00:19<00:00,  5.14it/s, accuracy=0.0182, cost=28.7] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 34, average cost 29.306733, average accuracy 0.009947


minibatch loop: 100%|██████████| 88/88 [00:19<00:00,  5.04it/s, accuracy=0.0073, cost=28.2] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 35, average cost 28.700645, average accuracy 0.006643


minibatch loop: 100%|██████████| 88/88 [00:19<00:00,  5.01it/s, accuracy=0, cost=27.7]      
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 36, average cost 28.119770, average accuracy 0.004074


minibatch loop: 100%|██████████| 88/88 [00:19<00:00,  4.98it/s, accuracy=0.00365, cost=27.1]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 37, average cost 27.585168, average accuracy 0.003216


minibatch loop: 100%|██████████| 88/88 [00:20<00:00,  4.80it/s, accuracy=0.00365, cost=26.4]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 38, average cost 26.975692, average accuracy 0.003697


minibatch loop: 100%|██████████| 88/88 [00:20<00:00,  4.77it/s, accuracy=0.00365, cost=26]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 39, average cost 26.376474, average accuracy 0.003614


minibatch loop: 100%|██████████| 88/88 [00:21<00:00,  4.62it/s, accuracy=0.00365, cost=25.4]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 40, average cost 25.845143, average accuracy 0.002630


minibatch loop: 100%|██████████| 88/88 [00:21<00:00,  4.62it/s, accuracy=0.00365, cost=24.9]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 41, average cost 25.368515, average accuracy 0.002047


minibatch loop: 100%|██████████| 88/88 [00:21<00:00,  4.57it/s, accuracy=0.00365, cost=24.4]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 42, average cost 24.884350, average accuracy 0.001880


minibatch loop: 100%|██████████| 88/88 [00:22<00:00,  4.50it/s, accuracy=0, cost=24]        
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 43, average cost 24.401837, average accuracy 0.001232


minibatch loop: 100%|██████████| 88/88 [00:22<00:00,  4.48it/s, accuracy=0.00365, cost=23.6]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 44, average cost 24.006603, average accuracy 0.001274


minibatch loop: 100%|██████████| 88/88 [00:22<00:00,  4.43it/s, accuracy=0, cost=23.1]      
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 45, average cost 23.583812, average accuracy 0.000501


minibatch loop: 100%|██████████| 88/88 [00:22<00:00,  4.34it/s, accuracy=0, cost=22.8]      
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 46, average cost 23.236513, average accuracy 0.000021


minibatch loop: 100%|██████████| 88/88 [00:23<00:00,  4.31it/s, accuracy=0, cost=22.5]      
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 47, average cost 22.882915, average accuracy 0.000042


minibatch loop: 100%|██████████| 88/88 [00:23<00:00,  4.30it/s, accuracy=0, cost=22.3]      
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 48, average cost 22.563259, average accuracy 0.000021


minibatch loop: 100%|██████████| 88/88 [00:23<00:00,  4.29it/s, accuracy=0, cost=22]        
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 49, average cost 22.258199, average accuracy 0.000042


minibatch loop: 100%|██████████| 88/88 [00:23<00:00,  4.33it/s, accuracy=0, cost=22.4]      

epoch 50, average cost 22.021709, average accuracy 0.000042





In [9]:
import random

random_index = random.randint(0, len(targets) - 1)
batch_x = inputs[random_index : random_index + 1]
print(
    'real:',
    ''.join(
        [idx2char[no] for no in targets[random_index : random_index + 1][0]]
    ),
)
batch_y = sparse_tuple_from(targets[random_index : random_index + 1])
pred = sess.run(model.preds, feed_dict = {model.X: batch_x})[0]
print('predicted:', ''.join([idx2char[no] for no in pred]))

real: say the word hate
predicted: dy he word te
