In [1]:
# https://www.tensorflow.org/tutorials/text/nmt_with_attention
import tensorflow as tf
# physical_devices = tf.config.experimental.list_physical_devices('GPU')
# tf.config.experimental.set_memory_growth(physical_devices[0], True)

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from sklearn.model_selection import train_test_split

import io
import time

In [2]:
import os
import librosa # for audio processing
import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.io import wavfile # for audio processing
import warnings
warnings.filterwarnings("ignore")

In [3]:
class WaveReader:
    """Load one-second waves from ``<path>/<label>/*.wav`` for a set of labels.

    Every wave is loaded with librosa at ``sample_rate`` and forced to exactly
    ``sample_rate`` samples. Longer clips are truncated. Shorter clips are
    padded according to ``padding_type`` or skipped.

    Args:
        path: root directory with one sub-directory per label.
        sample_rate: rate passed to ``librosa.load``; also the number of
            samples stored per wave.
        padding_type: ``None`` skips short waves, ``"white_noise"`` pads them
            with Gaussian noise (std 0.02), and ``"zero"`` pads them with zeros.
        read_size: maximum number of waves saved per label, or ``None`` for all.
    """

    def __init__(self, path, sample_rate, padding_type, read_size):
        self.path = path
        self.sample_rate = sample_rate
        self.padding_type = padding_type
        self.read_size = read_size

    def read(self, labels=None):
        """Read every requested label directory and stack the saved waves.

        Args:
            labels: ``None`` to use every sub-directory of ``self.path``, a
                single label name, or an iterable of label names.

        Returns:
            ``(samples, sample_labels, total_wave_count, total_loss_count)``:
            - ``samples`` has shape ``(n_saved, sample_rate)``.
            - ``sample_labels`` holds one label per row of ``samples``.
            - The two counts are per label (scalars when ``labels`` is a
              single string).
        """
        print("LABEL\tTOTAL\tREAD\tSAVED\t<1s COUNT")
        print("-----\t-----\t----\t-----\t---------")

        if labels is None:
            # The original listed a global ``path`` joined with a hard-coded
            # Windows separator; use this reader's own root instead.
            labels = sorted(f for f in os.listdir(self.path)
                            if os.path.isdir(os.path.join(self.path, f)))

        elif isinstance(labels, str):
            samples, total_wave_count, total_wave_read, total_loss_count = self.read_dir(dir_name=labels)
            sample_labels = np.repeat(labels, total_wave_read)

            print("\n\nMISSION COMPELTE!!!")
            return samples, sample_labels, total_wave_count, total_loss_count

        labels = list(labels)
        label_len = len(labels)
        total_wave_count = np.zeros(label_len, dtype=np.int32)
        total_wave_read = np.zeros(label_len, dtype=np.int32)
        total_loss_count = np.zeros(label_len, dtype=np.int32)

        sample_chunks = []
        label_chunks = []
        for i, lab in enumerate(labels):
            samp, total_wave_count[i], total_wave_read[i], total_loss_count[i] = self.read_dir(dir_name=lab)
            sample_chunks.append(samp)
            label_chunks.append(np.repeat(lab, total_wave_read[i]))

        # Concatenate once at the end instead of re-allocating for every label.
        if sample_chunks:
            samples = np.concatenate(sample_chunks, axis=0)
            sample_labels = np.concatenate(label_chunks, axis=None)
        else:
            samples = np.zeros((0, self.sample_rate))
            sample_labels = np.array([], dtype=str)

        print("\n\nMISSION COMPELTE!!!")
        return samples, sample_labels, total_wave_count, total_loss_count

    def read_dir(self, dir_name):
        """Load, truncate/pad and store the ``.wav`` files of one label directory.

        Args:
            dir_name: sub-directory of ``self.path`` to read.

        Returns:
            ``(samples, total_wave_files, saved, less_than_1s_count)``:
            - ``samples`` has shape ``(saved, sample_rate)``.
            - ``total_wave_files`` is the number of ``.wav`` files in the
              directory.
            - ``saved`` is how many waves were actually stored.
            - ``less_than_1s_count`` is how many files were shorter than
              ``sample_rate`` samples.
        """
        dir_path = os.path.join(self.path, dir_name)
        # sorted() makes the file order reproducible across platforms. That
        # order decides which files are kept when read_size truncates.
        wave_files = sorted(f for f in os.listdir(dir_path) if f.endswith('.wav'))
        total_wave_files = len(wave_files)

        if self.read_size is None:
            wave_files_read = total_wave_files
        else:
            wave_files_read = min(self.read_size, total_wave_files)

        samples = np.zeros((wave_files_read, self.sample_rate))
        less_than_1s_count = 0
        num_of_file_read = 0
        for i, wav_file in enumerate(wave_files):
            if num_of_file_read >= wave_files_read:
                break

            wave_file_path = os.path.join(dir_path, wav_file)
            samp, _ = librosa.load(wave_file_path, sr=self.sample_rate)
            # Clips longer than one second are cut so every row has the same width.
            samp = samp[:self.sample_rate]

            pad_size = self.sample_rate - len(samp)
            if pad_size > 0:
                less_than_1s_count += 1
                if self.padding_type == "white_noise":
                    padding = np.random.normal(0, 0.02, pad_size)
                elif self.padding_type == "zero":
                    padding = np.zeros(pad_size)
                else:
                    # None (or an unknown type): skip this short wave file.
                    continue
                samp = np.concatenate((samp, padding), axis=None)

            # The original never stored full-length waves and indexed rows by
            # the file index; write every saved wave at the next free row.
            samples[num_of_file_read, :] = samp
            num_of_file_read += 1

            print("{}\t{}\t{}\t{}\t{}".format(dir_name,
                                              total_wave_files,
                                              i+1,
                                              num_of_file_read,
                                              less_than_1s_count), end="\r")

        print()

        return samples[:num_of_file_read], total_wave_files, num_of_file_read, less_than_1s_count

In [4]:
# Waves are resampled to 8 kHz; WaveReader stores SAMPLE_RATE samples per wave.
SAMPLE_RATE = 8000

# Word recordings live one level above the notebook: ../data/train/audio/<word>/*.wav
train_audio_path = os.path.join(os.path.dirname(os.getcwd()), "data", "train", "audio")
# Word -> phoneme table kept next to the notebook.
phoneme_path = os.path.join(os.getcwd(), "Phonemes")
phoneme_dataframe = pd.read_csv(os.path.join(phoneme_path, "phonemes.csv"))

reader = WaveReader(
    path=train_audio_path,
    sample_rate=SAMPLE_RATE,
    padding_type="white_noise",
    read_size=2000,
)
wav_array, label_array, total, loss = reader.read(labels=phoneme_dataframe.words)

LABEL	TOTAL	READ	SAVED	<1s COUNT
-----	-----	----	-----	---------
zero	2376	2000	2000	144
one	2370	2000	2000	224
two	2373	2000	2000	193
three	2356	2000	2000	191
four	2372	2000	2000	182
five	2357	2000	2000	169
six	2369	2000	2000	158
seven	2377	2000	2000	169
eight	2352	2000	2000	203
nine	2364	2000	2000	170


MISSION COMPELTE!!!


In [5]:
print("Check if there is any NaN or Inf number exist so that we can avoid problems while training")
# Count invalid values once per kind of check.
for name, check in (("NaN", np.isnan), ("Inf", np.isinf)):
    print(f"{name} Number: {np.sum(check(wav_array))}")

Check if there is any NaN or Inf number exist so that we can avoid problems while training
NaN Number: 0
Inf Number: 0


In [6]:
class Preprocesser:
    """Simulate multi-word utterances by concatenating single-word waves.

    For each of ``create_size`` utterances, between ``min_sz`` and ``max_sz``
    distinct waves are picked at random. The same picks (``self.pickers``)
    drive the wave, label and phoneme simulations, so the three stay aligned.

    Args:
        waves: 2-D array with one single-word wave per row.
        create_size: number of utterances to simulate.
        min_sz: minimum number of words per utterance (inclusive).
        max_sz: maximum number of words per utterance (inclusive).
        padding_type: ``"zero"`` or ``"white_noise"``. Used to pad utterances
            that have fewer than ``max_sz`` words.
    """

    def __init__(self, waves, create_size, min_sz=6, max_sz=8, padding_type="zero"):
        self.waves = waves
        self.wave_shape = waves.shape
        self.create_size = create_size
        self.min_sz = min_sz
        self.max_sz = max_sz
        self.padding_type = padding_type

        # get picker for combining waves and labels(phonemes)
        self.pickers = self.get_picker()

    def get_picker(self):
        """Draw the wave indices of every utterance.

        Returns:
            Object array of length ``create_size``. Element ``i`` is an int
            array of ``min_sz..max_sz`` distinct row indices into ``waves``.
        """
        sizes = np.random.randint(low=self.min_sz,
                                  high=self.max_sz + 1,
                                  size=self.create_size)

        # Plain ``object`` dtype: ``np.object`` was removed in NumPy 1.24.
        picker = np.empty(self.create_size, dtype=object)
        for i, s in enumerate(sizes):
            # Sample from every row. The original used ``n - 1`` and could
            # never pick the last wave.
            picker[i] = np.random.choice(self.wave_shape[0], size=s, replace=False)

        return picker

    def simulate_wave(self):
        """Concatenate the picked waves and pad each utterance to ``max_sz`` waves.

        Returns:
            Float array of shape ``(create_size, wave_length * max_sz)``.

        Raises:
            ValueError: if padding is needed and ``padding_type`` is unknown.
        """
        binded_length = self.wave_shape[1] * self.max_sz
        simu_wave = np.zeros((self.create_size, binded_length))

        for i, picker in enumerate(self.pickers):
            tmp_simu_wave = self.waves[np.asarray(picker, dtype=int)].reshape(-1)

            pad_size = binded_length - tmp_simu_wave.size
            if pad_size > 0:
                if self.padding_type == "white_noise":
                    padding = np.random.normal(0, 0.02, size=pad_size)
                elif self.padding_type == "zero":
                    padding = np.zeros(pad_size)
                else:
                    # Without this, ``padding`` would be unbound or stale from
                    # a previous iteration.
                    raise ValueError(f"Unknown padding_type: {self.padding_type!r}")
                simu_wave[i] = np.concatenate((tmp_simu_wave, padding), axis=None)
            else:
                simu_wave[i] = tmp_simu_wave

        print("Wave Data Simulation ... Done")
        return simu_wave

    def simulate_label(self, labels):
        """Return, per utterance, the array of word labels of its picked waves.

        Args:
            labels: indexable with one label per row of ``waves``.
        """
        simu_label = np.empty(self.create_size, dtype=object)

        for i, picker in enumerate(self.pickers):
            simu_label[i] = np.array([labels[p] for p in picker])

        print("Label Simulation ... Done")
        return simu_label

    def simulate_phoneme(self, labels, label_dict, phoneme_dict):
        """Translate word-label sequences into ``"<start> ... <end>"`` phoneme strings.

        Args:
            labels: iterable of word-label sequences (e.g. from ``simulate_label``).
            label_dict: word labels, aligned position by position with ``phoneme_dict``.
            phoneme_dict: phoneme transcription of each word in ``label_dict``.

        Returns:
            Object array with one string per entry of ``labels``.

        Raises:
            KeyError: if a label has no transcription.
        """
        self.label_dict = label_dict
        self.phoneme_dict = phoneme_dict

        # Sized by the input, so any number of label sequences can be translated.
        simu_phoneme = np.empty(len(labels), dtype=object)
        for i, label in enumerate(labels):
            words = " ".join(self.phoneme_translator(lab) for lab in label)
            simu_phoneme[i] = "<start> " + words + " <end>"

        print("Phoneme Simulation... Done")
        return simu_phoneme

    def phoneme_translator(self, input_label):
        """Return the phoneme string for ``input_label``.

        Raises:
            KeyError: if ``input_label`` is not in ``self.label_dict``. The
                original returned ``None``, which failed later in ``join``.
        """
        for label, phoneme in zip(self.label_dict, self.phoneme_dict):
            if input_label == label:
                return phoneme
        raise KeyError(f"No phoneme transcription for label {input_label!r}")

    def tokenize(self, phoneme):
        """Fit a whitespace Keras tokenizer on ``phoneme``.

        Returns:
            The post-padded id tensor and the fitted tokenizer.
        """
        tokenizer = tf.keras.preprocessing.text.Tokenizer(filters='')
        tokenizer.fit_on_texts(phoneme)
        tensor = tokenizer.texts_to_sequences(phoneme)
        tensor = tf.keras.preprocessing.sequence.pad_sequences(tensor, padding='post')

        return tensor, tokenizer

    def show_convert(self, tensor, tokenizer):
        """Print each non-padding token id of ``tensor`` next to its word."""
        print("\nTOKEN\t--->\tWORDS")
        print("-----------------------")
        for t in tensor:
            if t != 0:
                print("{}\t--->\t{}".format(t, tokenizer.index_word[t]))

In [7]:
CREATE_SIZE = 20000
# Seed numpy's global RNG. Preprocesser draws the word picks and any noise
# padding from it, so the simulated corpus is reproducible across
# Restart & Run All.
SEED = 42
np.random.seed(SEED)

preprocesser = Preprocesser(waves=wav_array,
                            create_size=CREATE_SIZE,
                            min_sz=2,
                            max_sz=3,
                            padding_type="zero")

simu_wave = preprocesser.simulate_wave()
simu_label = preprocesser.simulate_label(label_array)
simu_phoneme = preprocesser.simulate_phoneme(labels=simu_label,
                                             label_dict=phoneme_dataframe.words.values,
                                             phoneme_dict=phoneme_dataframe.phonemes.values)

print(f"\nExample Label Display: {simu_label[0]}")
print(f"Example Phoneme Display: {simu_phoneme[0]}")

Wave Data Simulation ... Done
Label Simulation ... Done
Phoneme Simulation... Done

Example Label Display: ['two' 'four' 'two']
Example Phoneme Display: <start> T UW F AO R T UW <end>


In [11]:
class MFCC:
    """Frame-level log-mel filter-bank features plus a log-energy column.

    Pipeline: pre-emphasis -> overlapping frames -> Hamming window ->
    power spectrum -> triangular mel filters (in dB) -> log energy.

    NOTE(review): despite the name, no DCT is applied. The output is
    therefore a log-mel spectrogram, not cepstral coefficients. The energy
    column is also computed from the dB filter-bank values rather than from
    the raw frame power. Confirm both are intended.

    Args:
        alpha: pre-emphasis coefficient.
        frame_size: frame length; multiplied by ``sample_rate`` to get samples
            per frame.
        frame_stride: hop length; multiplied by ``sample_rate`` to get the
            sample stride.
        n_fft: FFT length per frame.
        n_filter: number of mel filters.
    """

    def __init__(self, alpha, frame_size, frame_stride, n_fft, n_filter):
        self.alpha = alpha
        self.frame_size = frame_size
        self.frame_stride = frame_stride
        self.n_fft = n_fft
        self.n_filter = n_filter

    def mfcc(self, samples, sample_rate):
        """Return a ``(frames, n_filter + 1)`` array: ``[log energy, dB mel bands]``."""
        samples_emphasized = self.pre_emphasis(samples)
        frames, total_samples_in_frame = self.framing(samples_emphasized, sample_rate)
        frames = self.hamming_window(frames, total_samples_in_frame)
        power_spectrum = self.stft(frames)
        fbank = self.filter_bank(power_spectrum, sample_rate)
        energy = self.log_energy(fbank)

        return np.column_stack((energy, fbank))

    def pre_emphasis(self, samples):
        """y[0] = x[0]; y[n] = x[n] - alpha * x[n-1]."""
        return np.append(samples[0], samples[1:] - self.alpha*samples[:-1])

    def framing(self, samples, sample_rate):
        """Split ``samples`` into overlapping, zero-padded frames.

        Returns:
            ``(frames, samples_in_frame)``. ``frames`` has shape
            ``(frame_num, samples_in_frame)``, and ``frames[k]`` starts at
            sample ``k * stride``.
        """
        samples_in_frame = int(np.ceil(self.frame_size*sample_rate))
        sample_stride = int(np.ceil(self.frame_stride*sample_rate))
        # At least one frame, even when the signal is shorter than a frame
        # (the original produced zero frames in that case).
        frame_num = 1 + max(0, int(np.ceil((len(samples) - samples_in_frame) / sample_stride)))

        # Zero-pad so that the last frame is complete.
        padding_num = (frame_num - 1)*sample_stride + samples_in_frame - len(samples)
        samples_padded = np.append(samples, np.zeros(padding_num))

        # index[k, j] = k * stride + j. A single fancy-indexing call replaces
        # the per-frame ``samples_padded[[i]]``, which relied on NumPy's
        # removed list-of-sequence indexing and now yields an extra axis.
        index = (np.arange(frame_num) * sample_stride)[:, np.newaxis] + np.arange(samples_in_frame)

        return samples_padded[index], samples_in_frame

    def hamming_window(self, frames, samples_in_frame):
        """Apply a Hamming window to every frame (returns a new array)."""
        return frames * np.hamming(samples_in_frame)

    def stft(self, frames):
        """Power spectrum ``|rfft(frame, n_fft)|^2 / n_fft`` of every frame."""
        magnitude = np.abs(np.fft.rfft(frames, n=self.n_fft))
        return (1.0/self.n_fft) * magnitude**2

    def filter_bank(self, frames, sample_rate):
        """Apply the triangular mel filters to power spectra and convert to dB.

        Exact zeros are replaced by machine epsilon before ``log10``.
        """
        fbank = self._mel_filters(sample_rate)

        mel_fbanks = np.dot(frames, fbank.T)
        mel_fbanks = np.where(mel_fbanks == 0, np.finfo(float).eps, mel_fbanks)
        return 20 * np.log10(mel_fbanks)

    def _mel_filters(self, sample_rate):
        """Return the ``(n_filter, n_fft // 2 + 1)`` triangular mel filter matrix.

        The matrix only depends on ``(sample_rate, n_fft, n_filter)``, so it is
        cached per instance rather than rebuilt for every wave.
        """
        cache = self.__dict__.setdefault("_fbank_cache", {})
        key = (sample_rate, self.n_fft, self.n_filter)
        if key not in cache:
            low_freq_mel = 0
            high_freq_mel = self.hz2mel(sample_rate/2)                     # Nyquist in mel
            mel_points = np.linspace(low_freq_mel, high_freq_mel, self.n_filter+2)
            bins = np.floor((self.n_fft+1) * self.mel2hz(mel_points) / sample_rate)

            fbank = np.zeros((self.n_filter, self.n_fft//2 + 1))
            for j in range(self.n_filter):
                for i in range(int(bins[j]), int(bins[j+1])):
                    fbank[j, i] = (i - bins[j]) / (bins[j+1] - bins[j])
                for i in range(int(bins[j+1]), int(bins[j+2])):
                    fbank[j, i] = (bins[j+2] - i) / (bins[j+2] - bins[j+1])
            cache[key] = fbank
        return cache[key]

    def log_energy(self, mel_fbanks):
        """Natural log of the per-frame sum of squared (dB) filter-bank values."""
        return np.log(np.sum(mel_fbanks**2, axis=1))

    def hz2mel(self, hz):
        return 2595 * np.log10(1 + hz/700)  # Convert Hz to Mel

    def mel2hz(self, mel):
        return 700 * (10**(mel/2595.0) - 1) # Convert Mel to Hz

In [12]:
class MFCCApplier:
    """Apply ``MFCC`` to every row of a 2-D wave array and flatten the result.

    Each row's ``(time steps, n_filter + 1)`` feature matrix is flattened in C
    order and zero-padded. The padded length is a multiple of
    ``(n_filter + 1) * decide_size``.
    """

    def __init__(self, alpha, frame_size, frame_stride, n_fft, n_filter, decide_size):
        self.alpha = alpha
        self.frame_size = frame_size
        self.frame_stride = frame_stride
        self.n_fft = n_fft
        self.n_filter = n_filter
        self.decide_size = decide_size

        # Use the constructor arguments. The original silently read the
        # notebook globals ALPHA, FRAME_SIZE, ... instead.
        self.mfcc = MFCC(alpha=self.alpha,
                         frame_size=self.frame_size,
                         frame_stride=self.frame_stride,
                         n_fft=self.n_fft,
                         n_filter=self.n_filter)

    def apply(self, inputs, sample_rate=None):
        """Return the flattened, zero-padded MFCC features of every row.

        Args:
            inputs: 2-D array with one wave per row.
            sample_rate: sampling rate handed to ``MFCC``. ``None`` keeps the
                legacy behaviour of using ``inputs.shape[1]`` (samples per
                row), which equals the true rate only for one-second rows.

        Returns:
            Float array of shape ``(n_rows, new_size)``.

        Raises:
            ValueError: if ``inputs`` has no rows.
        """
        inputs = np.asarray(inputs)
        input_shape = inputs.shape
        if input_shape[0] == 0:
            raise ValueError("inputs must contain at least one wave")
        if sample_rate is None:
            sample_rate = input_shape[1]

        # The first row fixes the feature shape; it is reused below rather than recomputed.
        first_features = self.mfcc.mfcc(samples=inputs[0, :], sample_rate=sample_rate)
        sample_shape = first_features.shape
        # n_filter + 1 is the final output of MFCC
        # 1 stands for log energy
        print("Shape of inputs after MFCC: (time step, number of filters + 1) {}\n".format(sample_shape))

        old_size = sample_shape[0]*sample_shape[1]
        print("Size before reshape: {}".format(old_size))

        divider = (self.n_filter + 1) * self.decide_size
        # Rounds up to the next multiple; an exact multiple still gets one
        # extra block, matching the original sizing.
        new_size = int((old_size//divider + 1)*divider)
        print("Size after reshape: {}\n".format(new_size))

        outputs = np.zeros((input_shape[0], new_size))
        for i in range(input_shape[0]):
            if i == 0:
                features = first_features
            else:
                features = self.mfcc.mfcc(inputs[i, :], sample_rate)
            # Rows start as zeros, so writing the prefix is the zero padding.
            outputs[i, :old_size] = features.flatten(order="C")

            print(f"Applying MFCC and reshaping to {i+1}th case", end="\r")
        print()

        return outputs

In [13]:
# Feature-extraction hyper-parameters
ALPHA = 0.95          # pre-emphasis coefficient
FRAME_SIZE = 0.025    # frame length (multiplied by the sample rate given to MFCC)
FRAME_STRIDE = 0.01   # hop length (multiplied by the sample rate given to MFCC)
N_FFT = 512           # FFT points per frame
N_FILTER = 12         # number of mel filters
DECIDE_SIZE = 64      # flattened features are padded to a multiple of (N_FILTER + 1) * DECIDE_SIZE

mfcc_applier = MFCCApplier(
    alpha=ALPHA,
    frame_size=FRAME_SIZE,
    frame_stride=FRAME_STRIDE,
    n_fft=N_FFT,
    n_filter=N_FILTER,
    decide_size=DECIDE_SIZE,
)

# NOTE(review): MFCCApplier.apply hands inputs.shape[1] (samples per row, i.e.
# 3 * SAMPLE_RATE for max_sz=3) to MFCC as the sample rate, not SAMPLE_RATE.
# Confirm this is intended.
mfcced_simu_wave = mfcc_applier.apply(simu_wave)

Shape of inputs after MFCC: (time step, number of filters + 1) (99, 13)

Size before reshape: 1287
Size after reshape: 1664

Applying MFCC and reshaping to 20000th case

In [14]:
# Targets: phoneme strings -> padded token ids. Inputs: MFCC rows as float32.
phoneme_tensor, phoneme_tokenizer = preprocesser.tokenize(simu_phoneme)
wav_tensor = tf.convert_to_tensor(mfcced_simu_wave, dtype=tf.float32)

print(f"Output Shape: {phoneme_tensor.shape}")
print(f"Input Shape: {wav_tensor.shape}")

# Decode the first target back into phonemes as a sanity check.
for tensor in phoneme_tensor[:1]:
    preprocesser.show_convert(tensor, phoneme_tokenizer)
    print()
    print()

Output Shape: (20000, 17)
Input Shape: (20000, 1664)

TOKEN	--->	WORDS
-----------------------
2	--->	<start>
10	--->	t
19	--->	uw
7	--->	f
15	--->	ao
5	--->	r
10	--->	t
19	--->	uw
3	--->	<end>



In [15]:
# Model / training hyper-parameters
BATCH_SIZE = 1
LSTM_UNITS = 256
FINAL_TIMESTEP = 64     # the encoder pyramid stops once its time axis is <= this
EMBEDDING_DIM = 128
WAV_SIZE = len(wav_tensor)                              # number of training examples
PHONEME_SIZE = len(phoneme_tokenizer.word_index) + 1    # +1 for padding id 0
STEP_PER_EPOCH = WAV_SIZE // BATCH_SIZE

# Shuffle over the whole set, then batch; a final partial batch is dropped.
dataset = (
    tf.data.Dataset.from_tensor_slices((wav_tensor, phoneme_tensor))
    .shuffle(WAV_SIZE)
    .batch(BATCH_SIZE, drop_remainder=True)
)

In [16]:
example_input_batch, example_target_batch = next(iter(dataset))

for title, batch_tensor in (("Input", example_input_batch), ("Output", example_target_batch)):
    print(f"Original {title} Shape: {batch_tensor.shape}")

# The encoder's Conv1D expects (batch, steps, channels): add a channel axis.
example_input_batch = tf.expand_dims(example_input_batch, axis=2)
print(f"Reshaped Input Shape: {example_input_batch.shape}")

Original Input Shape: (1, 1664)
Original Output Shape: (1, 17)
Reshaped Input Shape: (1, 1664, 1)


In [17]:
class Encoder(tf.keras.Model):
    """Convolutional front-end followed by a pyramidal bidirectional LSTM.

    Args:
        lstm_units: units of every forward/backward LSTM.
        final_units: the pyramid keeps halving the time axis while it is
            longer than this value.
        batch_sz: stored for reference; not used by ``call``.
        conv_filters: filters of the first Conv1D (the second uses half).
        mfcc_dims: kernel size and stride of the first Conv1D. Presumably one
            MFCC frame (``n_filter + 1`` values of a C-order flattened row)
            per output step; verify against the feature layout.
    """

    def __init__(self, lstm_units, final_units, batch_sz, conv_filters, mfcc_dims):
        super(Encoder, self).__init__()
        self.lstm_units = lstm_units
        self.final_units = final_units
        self.batch_sz = batch_sz
        self.conv_filters = conv_filters
        self.mfcc_dims = mfcc_dims

        # Convolution layer to extract feature after MFCC
        self.conv_feat = tf.keras.layers.Conv1D(filters=self.conv_filters,
                                                kernel_size=self.mfcc_dims,
                                                padding='valid',
                                                activation='relu',
                                                strides=self.mfcc_dims)

        self.conv1 = tf.keras.layers.Conv1D(filters=self.conv_filters // 2,
                                            kernel_size=5,
                                            padding='same',
                                            activation='relu',
                                            strides=1)

        # One (forward, backward) LSTM pair per pyramid level. The depth
        # depends on the input length, so the pairs are created on first use
        # and reused afterwards. Keeping them on the model lets Keras track,
        # train and checkpoint their weights. The original built fresh,
        # untracked LSTMs on every call.
        self.fw_lstms = []
        self.bw_lstms = []

    def call(self, x):
        '''
        Run the pyramidal bidirectional LSTM encoder.

        Returns:
            (outputs, (fw_state_h, fw_state_c), (bw_state_h, bw_state_c)).
            ``outputs`` has at most ``final_units`` time steps; the states
            come from the last pyramid level.

        Raises:
            ValueError: if the convolved sequence is not longer than
                ``final_units``, or its length is not statically known.
        '''
        # Convolution Feature Extraction
        x = self.conv_feat(x)
        x = self.conv1(x)

        if x.shape[1] is None or x.shape[1] <= self.final_units:
            raise ValueError(
                f"Encoder needs more than {self.final_units} time steps after "
                f"the convolutions, got {x.shape[1]}")

        # States are chained from one pyramid level to the next.
        initial_state_fw = None
        initial_state_bw = None

        level = 0
        while x.shape[1] > self.final_units:
            fw_lstm, bw_lstm = self._get_lstm_pair(level)
            fw_output, fw_state_h, fw_state_c = fw_lstm(x, initial_state=initial_state_fw)
            bw_output, bw_state_h, bw_state_c = bw_lstm(x, initial_state=initial_state_bw)

            # go_backwards=True emits outputs in reversed time order. Flip
            # them back so step t of both directions lines up before concat.
            bw_output = tf.reverse(bw_output, axis=[1])

            x = tf.concat([fw_output, bw_output], -1)
            x = self.reshape_pyramidal(x)

            initial_state_fw = [fw_state_h, fw_state_c]
            initial_state_bw = [bw_state_h, bw_state_c]
            level += 1

        return x, (fw_state_h, fw_state_c), (bw_state_h, bw_state_c)

    def _get_lstm_pair(self, level):
        """Return the (forward, backward) LSTMs of a pyramid level, creating them once."""
        while len(self.fw_lstms) <= level:
            self.fw_lstms.append(self.build_lstm(back=False))
            self.bw_lstms.append(self.build_lstm(back=True))
        return self.fw_lstms[level], self.bw_lstms[level]

    def build_lstm(self, back=True):
        '''
        Create an LSTM returning its full sequence and final states.

        Args:
            back: when True the layer reads the sequence in reverse
                (``go_backwards``).
        '''
        return tf.keras.layers.LSTM(units=self.lstm_units,
                                    return_sequences=True,
                                    return_state=True,
                                    go_backwards=back)

    def reshape_pyramidal(self, outputs):
        '''
        Halve the time axis by concatenating each pair of adjacent steps.

        (batch, T, units) -> (batch, T // 2, 2 * units). A trailing odd step
        is dropped; the original reshape failed on odd T.
        '''
        time_steps = outputs.shape[1]
        num_units = outputs.shape[-1]
        if time_steps % 2:
            outputs = outputs[:, :time_steps - 1, :]

        return tf.reshape(outputs, (tf.shape(outputs)[0], -1, num_units * 2))

In [18]:
encoder = Encoder(lstm_units=LSTM_UNITS,
                  final_units=FINAL_TIMESTEP,
                  batch_sz=BATCH_SIZE,
                  conv_filters=32,
                  mfcc_dims=N_FILTER+1)

# Per the original author, a batch size greater than 4 exhausts GPU memory here.
sample_output, (fw_sample_state_h, fw_sample_state_c), bw_sample_state = encoder(example_input_batch)
print('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))
print('Encoder forward state h shape: (batch size, units) {}'.format(fw_sample_state_h.shape))
# The original printed the h state here; report the c state as labelled.
print('Encoder forward state c shape: (batch size, units) {}'.format(fw_sample_state_c.shape))
print('Encoder backward state h shape: (batch size, units) {}'.format(bw_sample_state[0].shape))

Encoder output shape: (batch size, sequence length, units) (1, 64, 1024)
Encoder forward state h shape: (batch size, units) (1, 256)
Encoder forward state c shape: (batch size, units) (1, 256)
Encoder backward state h shape: (batch size, units) (1, 256)


In [19]:
class BahdanauAttention(tf.keras.layers.Layer):
    """Additive (Bahdanau) attention: score = V(tanh(W1(query) + W2(values)))."""

    def __init__(self, units):
        super().__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        """Return ``(context_vector, attention_weights)``.

        query: (batch, hidden). values: (batch, max_len, hidden).
        context_vector: (batch, hidden). attention_weights: (batch, max_len, 1).
        """
        # Add a time axis so the projected query broadcasts over every step.
        query_over_time = tf.expand_dims(query, 1)

        projected = tf.nn.tanh(self.W1(query_over_time) + self.W2(values))
        score = self.V(projected)                       # (batch, max_len, 1)

        attention_weights = tf.nn.softmax(score, axis=1)
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)
        return context_vector, attention_weights

In [20]:
# Smoke test: a 10-unit scorer over the sample encoder output.
attention_layer = BahdanauAttention(10)
attention_result, attention_weights = attention_layer(fw_sample_state_h, sample_output)

print(f"Attention result shape: (batch size, units) {attention_result.shape}")
print(f"Attention weights shape: (batch_size, sequence_length, 1) {attention_weights.shape}")

Attention result shape: (batch size, units) (1, 1024)
Attention weights shape: (batch_size, sequence_length, 1) (1, 64, 1)


In [21]:
class Decoder(tf.keras.Model):
    """One-step attention LSTM decoder.

    Each call does three things:
    - attends over the encoder outputs using the previous hidden state;
    - embeds the current token;
    - advances the LSTM one step from the previous (hidden, cell) state.

    Args:
        target_sz: vocabulary size (embedding rows and output logits).
        embedding_dim: token embedding width.
        decoder_units: LSTM and attention units. Must equal the size of the
            hidden/cell states passed to ``call`` (the encoder LSTM units on
            the first step).
        batch_sz: stored for reference; not used by ``call``.
    """

    def __init__(self, target_sz, embedding_dim, decoder_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.decoder_units = decoder_units
        self.embedding = tf.keras.layers.Embedding(target_sz, embedding_dim)
        self.attention = BahdanauAttention(self.decoder_units)
        self.lstm = tf.keras.layers.LSTM(units=self.decoder_units, return_sequences=True, return_state=True)
        self.fc = tf.keras.layers.Dense(target_sz)

    def call(self, inputs, enc_hidden_h, enc_hidden_c, enc_output):
        '''
        Run one decoding step.

        Args:
            inputs: (batch, 1) token ids.
            enc_hidden_h, enc_hidden_c: previous LSTM hidden / cell state,
                each (batch, decoder_units).
            enc_output: (batch, max_length, hidden_size) encoder outputs.

        Returns:
            (logits of shape (batch, vocab), (state_h, state_c), attention_weights)
        '''
        context_vector, attention_weights = self.attention(enc_hidden_h, enc_output)

        # x shape after passing through embedding == (batch_size, 1, embedding_dim)
        x = self.embedding(inputs)

        # x shape after concatenation == (batch_size, 1, embedding_dim + hidden_size)
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # Continue from the previous state. The original called
        # self.lstm(x) with no initial state, so every step restarted from
        # zeros and the decoder had no memory across steps.
        output, state_h, state_c = self.lstm(x, initial_state=[enc_hidden_h, enc_hidden_c])

        # output shape == (batch_size * 1, hidden_size)
        output = tf.reshape(output, (-1, output.shape[-1]))

        # output shape == (batch_size, vocab)
        x = self.fc(output)

        return x, (state_h, state_c), attention_weights

In [22]:
decoder = Decoder(
    target_sz=PHONEME_SIZE,
    embedding_dim=EMBEDDING_DIM,
    decoder_units=LSTM_UNITS,
    batch_sz=BATCH_SIZE,
)

# Dummy (batch, 1) input of random floats in [0, 1); only its shape matters
# for this smoke test.
sample_target_size = tf.random.uniform((BATCH_SIZE, 1))
sample_decoder_output, sample_decoder_hidden, attention_weights = decoder(
    inputs=sample_target_size,
    enc_hidden_h=fw_sample_state_h,
    enc_hidden_c=fw_sample_state_c,
    enc_output=sample_output,
)

print('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))

Decoder output shape: (batch_size, vocab size) (1, 22)


In [23]:
optimizer = tf.keras.optimizers.Adam(learning_rate=0.001)
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
    """Sparse cross-entropy averaged over non-padding targets.

    Args:
        real: target token ids, shape (batch,); id 0 is padding.
        pred: logits, shape (batch, vocab).

    Returns:
        Scalar mean loss over the positions where ``real != 0``, or 0 when
        every position is padding.
    """
    mask = tf.math.not_equal(real, 0)
    loss_ = loss_object(real, pred)

    mask = tf.cast(mask, dtype=loss_.dtype)
    # Divide by the number of real tokens rather than the batch size, so that
    # padded positions do not dilute the loss. divide_no_nan returns 0 for an
    # all-padding batch. With BATCH_SIZE == 1 this equals the original value.
    return tf.math.divide_no_nan(tf.reduce_sum(loss_ * mask), tf.reduce_sum(mask))

In [24]:
# Checkpoints hold the optimizer state and both sub-models; they are written
# by the training loop and restored further down.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(
    optimizer=optimizer,
    encoder=encoder,
    decoder=decoder,
)

In [25]:
@tf.function
def train_step(inp, targ, targ_tokenizer, enc_hidden):
    """Run one teacher-forced training step and apply the gradients.

    Args:
        inp: encoder input batch, (batch, steps, 1).
        targ: padded target token ids, (batch, T). Column 0 holds the
            '<start>' id.
        targ_tokenizer: fitted tokenizer, used for the '<start>' id.
        enc_hidden: unused (the encoder builds its own states); kept so
            existing callers keep working.

    Returns:
        The summed loss divided by the T - 1 predicted positions.

    Raises:
        ValueError: if ``targ`` has fewer than two columns (nothing to predict).
    """
    num_steps = int(targ.shape[1]) - 1
    if num_steps < 1:
        raise ValueError("targ needs at least a <start> token and one target token")

    loss = 0

    with tf.GradientTape() as tape:
        # forward algorithm
        enc_output, (dec_hidden_h, dec_hidden_c), _ = encoder(inp)
        # One <start> token per row. The batch size is taken from the data,
        # not from the BATCH_SIZE global.
        dec_input = tf.fill([tf.shape(targ)[0], 1], targ_tokenizer.word_index['<start>'])

        # Teacher forcing - feeding the target as the next input
        for t in range(1, num_steps + 1):
            predictions, (dec_hidden_h, dec_hidden_c), _ = decoder(dec_input, dec_hidden_h, dec_hidden_c, enc_output)
            loss += loss_function(targ[:, t], predictions)

            dec_input = tf.expand_dims(targ[:, t], 1)

    # Average over the steps actually predicted (the original divided by T).
    batch_loss = loss / num_steps

    # backward algorithm
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(loss, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return batch_loss

In [None]:
# Run every tf.function eagerly (step by step, easier to debug). This also
# sidesteps "ValueError: tf.function-decorated function tried to create
# variables on non-first call."
tf.config.experimental_run_functions_eagerly(True)

EPOCHS = 10
for epoch in range(EPOCHS):
    start = time.time()
    total_loss = 0
    enc_hidden = None  # train_step accepts this argument but never reads it

    for batch, (inp, targ) in enumerate(dataset.take(STEP_PER_EPOCH)):
        inp = tf.expand_dims(inp, 2)  # (batch, steps) -> (batch, steps, 1) for the Conv1D front-end
        batch_loss = train_step(inp, targ, phoneme_tokenizer, enc_hidden)
        total_loss += batch_loss

        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, batch, batch_loss.numpy()))

    # One checkpoint per epoch.
    checkpoint.save(file_prefix=checkpoint_prefix)

    print('Epoch {} Loss {:.4f}'.format(epoch + 1, total_loss / STEP_PER_EPOCH))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

In [None]:
def predict(wave, max_input_len, max_output_len, tokenizer=None):
    """Greedy-decode one input with the trained ``encoder`` and ``decoder``.

    Args:
        wave: one encoder input row, shape ``(steps,)`` or ``(1, steps)``.
            Presumably an MFCC feature row as produced by
            ``MFCCApplier.apply``; verify against the training data.
        max_input_len: number of attention columns to record.
        max_output_len: maximum number of decoding steps.
        tokenizer: the fitted phoneme tokenizer (required).

    Returns:
        ``(result, wave, attention_plot)``:
        - ``result`` is the space-separated predicted tokens (with a
          trailing space).
        - ``wave`` is the untouched input.
        - ``attention_plot`` is a ``(max_output_len, max_input_len)`` array
          of attention weights.

    Raises:
        ValueError: if ``tokenizer`` is None.
    """
    if tokenizer is None:
        raise ValueError("predict() needs the fitted phoneme tokenizer")

    attention_plot = np.zeros((max_output_len, max_input_len))

    # (batch=1, steps, channels=1), the layout the encoder's Conv1D expects.
    inputs = tf.reshape(tf.convert_to_tensor(wave, dtype=tf.float32), (1, -1, 1))
    result = ''

    # The encoder takes only the input and returns three values. The decoder
    # takes the hidden and cell states as separate arguments. The original
    # calls matched neither signature.
    enc_out, (dec_hidden_h, dec_hidden_c), _ = encoder(inputs)
    dec_input = tf.expand_dims([tokenizer.word_index['<start>']], 0)

    for t in range(max_output_len):
        predictions, (dec_hidden_h, dec_hidden_c), attention_weights = decoder(
            dec_input, dec_hidden_h, dec_hidden_c, enc_out)

        # Store as many weights as fit; the encoder output length need not
        # equal max_input_len.
        weights = tf.reshape(attention_weights, (-1,)).numpy()
        n_cols = min(max_input_len, weights.shape[0])
        attention_plot[t, :n_cols] = weights[:n_cols]

        predicted_id = int(tf.argmax(predictions[0]).numpy())
        word = tokenizer.index_word.get(predicted_id)
        if word is None:
            # Id 0 is padding (never in index_word), so there is nothing
            # further to decode.
            break

        result += word + ' '
        if word == '<end>':
            break

        # the predicted ID is fed back into the model
        dec_input = tf.expand_dims([predicted_id], 0)

    return result, wave, attention_plot

In [None]:
# function for plotting the attention weights
def plot_attention(attention, input_wav, output_phoneme):
    """Show an attention matrix with the predicted phonemes on the y axis.

    Args:
        attention: 2-D array; rows are decoder steps, columns are encoder steps.
        input_wav: currently unused (x ticks keep their defaults); kept for
            interface compatibility.
        output_phoneme: one label per row of ``attention`` (list or array).
    """
    output_phoneme = list(output_phoneme)

    fig, ax = plt.subplots(figsize=(10, 10))
    ax.matshow(attention, cmap='viridis')

    fontdict = {'fontsize': 14}

    # Fix one tick per row before labelling. This replaces the fragile
    # "[''] + labels" offset trick, which also failed for array inputs.
    ax.set_yticks(np.arange(len(output_phoneme)))
    ax.set_yticklabels(output_phoneme, fontdict=fontdict)
    ax.set_xlabel("encoder time step")
    ax.set_ylabel("predicted phoneme")

    plt.show()

In [None]:
def translate(wave, max_in, max_out, tokenizer, max_plot_cols=50):
    """Decode ``wave`` with ``predict``, print the result and plot its attention.

    Args:
        wave: one encoder input row (see ``predict``).
        max_in: attention columns recorded by ``predict``.
        max_out: maximum number of decoding steps.
        tokenizer: the fitted phoneme tokenizer.
        max_plot_cols: maximum number of attention columns to plot.
    """
    result, _, attention_plot = predict(wave, max_in, max_out, tokenizer)
    # split() drops the trailing empty token that split(' ') produced.
    predicted_tokens = result.split()

    print(f'Original Input Length: {len(wave)}')
    print(f'Predicted translation: {result}')

    if not predicted_tokens:
        print('Nothing decoded; skipping attention plot.')
        return

    attention_plot = attention_plot[:len(predicted_tokens), :max_plot_cols]
    plot_attention(attention_plot, np.arange(len(wave)), predicted_tokens)

In [None]:
# Load the newest checkpoint that the training loop saved under checkpoint_dir.
checkpoint.restore(
    tf.train.latest_checkpoint(checkpoint_dir)
)

In [None]:
# `create_dataset` and `phonemes_array` are not defined anywhere in the
# notebook. Build one utterance with the same pipeline as the training data
# instead:
# - same wave pool;
# - max_sz=3, so the wave length and MFCC width match training;
# - converted to the MFCC feature row the encoder was trained on.
# NOTE(review): the words come from wav_array, i.e. the training pool, not a held-out set.
test_preprocesser = Preprocesser(waves=wav_array,
                                 create_size=1,
                                 min_sz=2,
                                 max_sz=3,
                                 padding_type="zero")
test_labels = test_preprocesser.simulate_label(label_array)
test_wave = mfcc_applier.apply(test_preprocesser.simulate_wave())[0]
test_label = test_labels[0]
test_phoneme = test_preprocesser.simulate_phoneme(labels=test_labels,
                                                  label_dict=phoneme_dataframe.words.values,
                                                  phoneme_dict=phoneme_dataframe.phonemes.values)[0]

In [None]:
# Attention width: the encoder output has at most FINAL_TIMESTEP steps.
# Decode at most as many tokens as the longest training target. The original
# passed 8 and the vocabulary size here.
print(f"Reference phonemes: {test_phoneme}")
translate(test_wave, FINAL_TIMESTEP, phoneme_tensor.shape[1], phoneme_tokenizer)