In [1]:
import librosa
import os
import tensorflow as tf
import numpy as np
from tqdm import tqdm

In [2]:
wav_files = [f for f in os.listdir('./data') if f.endswith('.wav')]
text_files = [f for f in os.listdir('./data') if f.endswith('.txt')]

In [3]:
inputs, targets = [], []
for (wav_file, text_file) in tqdm(zip(wav_files, text_files), total = len(wav_files),ncols=80):
    path = './data/' + wav_file
    try:
        y, sr = librosa.load(path, sr = None)
    except:
        continue
    inputs.append(
        librosa.feature.mfcc(
            y = y, sr = sr, n_mfcc = 100, hop_length = int(0.01 * sr)
        ).T
    )
    with open('./data/' + text_file) as f:
        targets.append(f.read())

100%|███████████████████████████████████████| 2800/2800 [00:54<00:00, 51.13it/s]


In [4]:
inputs = tf.keras.preprocessing.sequence.pad_sequences(
    inputs, dtype = 'float32', padding = 'post'
)
print(inputs.shape)

chars = list(set([c for target in targets for c in target]))
num_classes = len(chars) + 2

idx2char = {idx + 1: char for idx, char in enumerate(chars)}
idx2char[0] = '<PAD>'
char2idx = {char: idx for idx, char in idx2char.items()}

targets = [[char2idx[c] for c in target] for target in targets]

(2800, 299, 100)


In [5]:
def layer_norm(inputs, epsilon=1e-8):
    mean, variance = tf.nn.moments(inputs, [-1], keep_dims=True)
    normalized = (inputs - mean) / (tf.sqrt(variance + epsilon))
    params_shape = inputs.get_shape()[-1:]
    gamma = tf.get_variable('gamma', params_shape, tf.float32, tf.ones_initializer())
    beta = tf.get_variable('beta', params_shape, tf.float32, tf.zeros_initializer())
    return gamma * normalized + beta


def cnn_block(x, dilation_rate, hidden_dim, kernel_size):
    x = layer_norm(x)
    x =  tf.layers.conv1d(inputs = x,
                          filters = hidden_dim,
                          kernel_size = kernel_size,
                          dilation_rate = dilation_rate)
    x = tf.nn.relu(x)
    return x

def pad_second_dim(x, desired_size):
    padding = tf.tile([[0]], tf.stack([tf.shape(x)[0], desired_size - tf.shape(x)[1]], 0))
    return tf.concat([x, padding], 1)

class Model:
    def __init__(
        self,
        num_layers,
        size_layers,
        learning_rate,
        num_features,
        kernel_size = 5
    ):
        self.X = tf.placeholder(tf.float32, [None, None, num_features])
        self.label = tf.placeholder(tf.int32, [None, None])
        self.Y_seq_len = tf.placeholder(tf.int32, [None])
        self.Y = tf.sparse_placeholder(tf.int32)
        batch_size = tf.shape(self.X)[0]
        
        x = tf.layers.conv1d(self.X, size_layers, 1)
        for i in range(num_layers):
            dilation_rate = 3 ** i
            pad_sz = (kernel_size - 1) * dilation_rate
            with tf.variable_scope('block_%d'%i):
                x = cnn_block(x, dilation_rate, size_layers, kernel_size)
                print(x)
                
        seq_lens = tf.count_nonzero(
            tf.reduce_sum(x, -1), 1, dtype = tf.int32
        )
        
        logits = tf.layers.dense(x, num_classes)
        time_major = tf.transpose(logits, [1, 0, 2])
        decoded, log_prob = tf.nn.ctc_beam_search_decoder(time_major, seq_lens)
        self.dense_decoded = tf.sparse_tensor_to_dense(decoded[0], default_value=-1)
        decoded = tf.to_int32(decoded[0])
        self.preds = tf.sparse.to_dense(decoded)
        self.cost = tf.reduce_mean(
            tf.nn.ctc_loss(
                self.Y,
                time_major,
                seq_lens
            )
        )
        self.optimizer = tf.train.AdamOptimizer(
            learning_rate = learning_rate
        ).minimize(self.cost)
        
        preds = self.preds[:, :tf.reduce_max(self.Y_seq_len)]
        masks = tf.sequence_mask(self.Y_seq_len, tf.reduce_max(self.Y_seq_len), dtype=tf.float32)
        preds = pad_second_dim(preds, tf.reduce_max(self.Y_seq_len))
        y_t = tf.cast(preds, tf.int32)
        self.prediction = tf.boolean_mask(y_t, masks)
        mask_label = tf.boolean_mask(self.label, masks)
        self.mask_label = mask_label
        correct_pred = tf.equal(self.prediction, mask_label)
        correct_index = tf.cast(correct_pred, tf.float32)
        self.accuracy = tf.reduce_mean(tf.cast(correct_pred, tf.float32))

In [6]:
tf.reset_default_graph()
sess = tf.InteractiveSession()

size_layers = 128
learning_rate = 1e-4
num_layers = 4
batch_size = 32
epoch = 50

model = Model(num_layers, size_layers, learning_rate, inputs.shape[2])
sess.run(tf.global_variables_initializer())

Tensor("block_0/Relu:0", shape=(?, ?, 128), dtype=float32)
Tensor("block_1/Relu:0", shape=(?, ?, 128), dtype=float32)
Tensor("block_2/Relu:0", shape=(?, ?, 128), dtype=float32)
Tensor("block_3/Relu:0", shape=(?, ?, 128), dtype=float32)
Instructions for updating:
Create a `tf.sparse.SparseTensor` and use `tf.sparse.to_dense` instead.


In [7]:
def pad_sentence_batch(sentence_batch, pad_int):
    padded_seqs = []
    seq_lens = []
    max_sentence_len = max([len(sentence) for sentence in sentence_batch])
    for sentence in sentence_batch:
        padded_seqs.append(sentence + [pad_int] * (max_sentence_len - len(sentence)))
        seq_lens.append(len(sentence))
    return padded_seqs, seq_lens

def sparse_tuple_from(sequences, dtype=np.int32):
    indices = []
    values = []

    for n, seq in enumerate(sequences):
        indices.extend(zip([n] * len(seq), range(len(seq))))
        values.extend(seq)

    indices = np.asarray(indices, dtype=np.int64)
    values = np.asarray(values, dtype=dtype)
    shape = np.asarray([len(sequences), np.asarray(indices).max(0)[1] + 1], dtype=np.int64)

    return indices, values, shape

In [8]:
for e in range(epoch):
    pbar = tqdm(
        range(0, len(inputs), batch_size), desc = 'minibatch loop')
    total_cost, total_accuracy = 0, 0
    for i in pbar:
        batch_x = inputs[i : min(i + batch_size, len(inputs))]
        y = targets[i : min(i + batch_size, len(inputs))]
        batch_y = sparse_tuple_from(y)
        batch_label, batch_len = pad_sentence_batch(y, 0)
        _, cost, accuracy = sess.run(
            [model.optimizer, model.cost, model.accuracy],
            feed_dict = {model.X: batch_x, model.Y: batch_y, 
                         model.label: batch_label, model.Y_seq_len: batch_len},
        )
        total_cost += cost
        total_accuracy += accuracy
        pbar.set_postfix(cost = cost, accuracy = accuracy)
    total_cost /= (len(inputs) / batch_size)
    total_accuracy /= (len(inputs) / batch_size)
    print('epoch %d, average cost %f, average accuracy %f'%(e + 1, total_cost, total_accuracy))

minibatch loop: 100%|██████████| 88/88 [00:56<00:00,  1.59it/s, accuracy=0, cost=45.3]    
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 1, average cost 57.503664, average accuracy 0.000187


minibatch loop: 100%|██████████| 88/88 [00:54<00:00,  1.86it/s, accuracy=0.0444, cost=38.9] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 2, average cost 42.992088, average accuracy 0.015695


minibatch loop: 100%|██████████| 88/88 [00:55<00:00,  1.83it/s, accuracy=0.037, cost=34.4] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 3, average cost 37.210113, average accuracy 0.038870


minibatch loop: 100%|██████████| 88/88 [00:57<00:00,  1.75it/s, accuracy=0.104, cost=29.1] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 4, average cost 32.799288, average accuracy 0.063512


minibatch loop: 100%|██████████| 88/88 [01:00<00:00,  1.66it/s, accuracy=0.144, cost=25.2] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 5, average cost 27.758218, average accuracy 0.106361


minibatch loop: 100%|██████████| 88/88 [01:02<00:00,  1.66it/s, accuracy=0.2, cost=23.8]   
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 6, average cost 25.133538, average accuracy 0.128984


minibatch loop: 100%|██████████| 88/88 [01:02<00:00,  1.64it/s, accuracy=0.352, cost=22.7] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 7, average cost 23.822403, average accuracy 0.158349


minibatch loop: 100%|██████████| 88/88 [01:03<00:00,  1.62it/s, accuracy=0.289, cost=21.6] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 8, average cost 22.656406, average accuracy 0.203660


minibatch loop: 100%|██████████| 88/88 [01:03<00:00,  1.62it/s, accuracy=0.226, cost=20.8] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 9, average cost 21.676073, average accuracy 0.242636


minibatch loop: 100%|██████████| 88/88 [01:04<00:00,  1.58it/s, accuracy=0.359, cost=19.3]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 10, average cost 20.541362, average accuracy 0.343313


minibatch loop: 100%|██████████| 88/88 [01:06<00:00,  1.51it/s, accuracy=0.533, cost=17.4]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 11, average cost 18.753993, average accuracy 0.499660


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.50it/s, accuracy=0.615, cost=16.7]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 12, average cost 17.549636, average accuracy 0.539520


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.50it/s, accuracy=0.578, cost=16.2]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 13, average cost 16.954394, average accuracy 0.554477


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.51it/s, accuracy=0.578, cost=15.9]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 14, average cost 16.549977, average accuracy 0.569731


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.50it/s, accuracy=0.389, cost=15.5]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 15, average cost 16.244339, average accuracy 0.575546


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.50it/s, accuracy=0.685, cost=15.2]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 16, average cost 15.935439, average accuracy 0.614464


minibatch loop: 100%|██████████| 88/88 [01:07<00:00,  1.51it/s, accuracy=0.7, cost=14.5]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 17, average cost 15.319269, average accuracy 0.725087


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.51it/s, accuracy=0.704, cost=14.1]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 18, average cost 14.990856, average accuracy 0.726967


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.737, cost=13.8]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 19, average cost 14.728409, average accuracy 0.728325


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.796, cost=13.4]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 20, average cost 14.433386, average accuracy 0.731996


minibatch loop: 100%|██████████| 88/88 [01:09<00:00,  1.48it/s, accuracy=0.763, cost=13.2]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 21, average cost 14.181820, average accuracy 0.732905


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.733, cost=13.1]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 22, average cost 13.975610, average accuracy 0.734363


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.756, cost=12.9]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 23, average cost 13.794073, average accuracy 0.735945


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.50it/s, accuracy=0.737, cost=12.7]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 24, average cost 13.632085, average accuracy 0.736533


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.50it/s, accuracy=0.73, cost=12.4] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 25, average cost 13.481259, average accuracy 0.736574


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.767, cost=12]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 26, average cost 13.343412, average accuracy 0.738414


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.77, cost=11.7] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 27, average cost 13.216582, average accuracy 0.737665


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.77, cost=11.5] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 28, average cost 13.101136, average accuracy 0.737896


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.50it/s, accuracy=0.77, cost=11.3] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 29, average cost 12.983406, average accuracy 0.739130


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.767, cost=11.1]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 30, average cost 12.857233, average accuracy 0.739508


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.77, cost=10.9] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 31, average cost 12.737795, average accuracy 0.741820


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.77, cost=10.7] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 32, average cost 12.625510, average accuracy 0.742181


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.50it/s, accuracy=0.778, cost=10.6]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 33, average cost 12.506044, average accuracy 0.743291


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.50it/s, accuracy=0.778, cost=10.4]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 34, average cost 12.397188, average accuracy 0.744147


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.50it/s, accuracy=0.785, cost=10.3]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 35, average cost 12.282721, average accuracy 0.746353


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.785, cost=10.1]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 36, average cost 12.175921, average accuracy 0.747302


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.50it/s, accuracy=0.793, cost=10]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 37, average cost 12.071008, average accuracy 0.749272


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.789, cost=9.91]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 38, average cost 11.962304, average accuracy 0.750446


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.796, cost=9.77]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 39, average cost 11.849359, average accuracy 0.751637


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.796, cost=9.76]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 40, average cost 11.758160, average accuracy 0.753051


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.804, cost=9.6] 
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 41, average cost 11.646122, average accuracy 0.754155


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.48it/s, accuracy=0.8, cost=9.45]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 42, average cost 11.548793, average accuracy 0.755320


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.8, cost=9.32]  
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 43, average cost 11.447508, average accuracy 0.757809


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.48it/s, accuracy=0.807, cost=9.24]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 44, average cost 11.367624, average accuracy 0.758501


minibatch loop: 100%|██████████| 88/88 [01:09<00:00,  1.48it/s, accuracy=0.811, cost=9.05]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 45, average cost 11.289521, average accuracy 0.760059


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.815, cost=8.86]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 46, average cost 11.201544, average accuracy 0.761459


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.49it/s, accuracy=0.833, cost=8.78]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 47, average cost 11.117919, average accuracy 0.764021


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.48it/s, accuracy=0.833, cost=8.64]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 48, average cost 11.018193, average accuracy 0.764925


minibatch loop: 100%|██████████| 88/88 [01:09<00:00,  1.48it/s, accuracy=0.867, cost=8.39]
minibatch loop:   0%|          | 0/88 [00:00<?, ?it/s]

epoch 49, average cost 10.939112, average accuracy 0.767058


minibatch loop: 100%|██████████| 88/88 [01:08<00:00,  1.48it/s, accuracy=0.87, cost=8.22] 

epoch 50, average cost 10.840128, average accuracy 0.768328





In [9]:
import random

random_index = random.randint(0, len(targets) - 1)
batch_x = inputs[random_index : random_index + 1]
print(
    'real:',
    ''.join(
        [idx2char[no] for no in targets[random_index : random_index + 1][0]]
    ),
)
batch_y = sparse_tuple_from(targets[random_index : random_index + 1])
pred = sess.run(model.preds, feed_dict = {model.X: batch_x})[0]
print('predicted:', ''.join([idx2char[no] for no in pred]))

real: say the word turn
predicted: say the word tut
