In [1]:
# https://www.tensorflow.org/tutorials/text/nmt_with_attention
import tensorflow as tf
# physical_devices = tf.config.experimental.list_physical_devices('GPU')
# tf.config.experimental.set_memory_growth(physical_devices[0], True)

import matplotlib.pyplot as plt
import matplotlib.ticker as ticker
from sklearn.model_selection import train_test_split

import io
import time

In [2]:
import os
import librosa # for audio processing
import IPython.display as ipd
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.io import wavfile # for audio processing
import warnings
warnings.filterwarnings("ignore")

In [3]:
def read_wav(PATH, LABELS=None, SAMPLE=None, default_rate=8000):
    """Load the ``.wav`` clips of every label directory under ``PATH``.

    Each clip is read with ``librosa.load(..., sr=default_rate)`` and forced to
    exactly ``default_rate`` samples: shorter clips are padded at the end with
    Gaussian noise (sigma 0.02), longer clips are truncated.

    Parameters
    ----------
    PATH : str
        Directory that contains one sub-directory per label.
    LABELS : sequence of str, optional
        Label (sub-directory) names, in the order that defines the integer
        class ids. When omitted, every sub-directory of ``PATH`` is used, in
        sorted order.
    SAMPLE : int, optional
        When given, ``SAMPLE`` clips are drawn per label with
        ``np.random.randint`` (i.e. *with replacement*, duplicates possible).
        When omitted, every clip is read.
    default_rate : int
        Target sampling rate and therefore the number of samples per row.

    Returns
    -------
    wavData : np.ndarray, shape (n_clips, default_rate)
    wavLabels : np.ndarray of float, shape (n_clips,)
        Index into ``LABELS`` for each row of ``wavData``.
    wave_count : np.ndarray of int32
        Number of ``.wav`` files found per label.
    loss_count : np.ndarray of int32
        Number of read clips per label that were shorter than one second.
    """
    if LABELS is None:
        # Discover labels under PATH itself (not the working directory) and
        # join paths portably instead of hard-coding a Windows separator.
        LABELS = sorted(f for f in os.listdir(PATH)
                        if os.path.isdir(os.path.join(PATH, f)))
    LABELS = list(LABELS)

    label_len = len(LABELS)

    # per-label bookkeeping
    wave_count = np.zeros(label_len, dtype=np.int32)
    loss_count = np.zeros(label_len, dtype=np.int32)

    # per-label arrays are collected here and concatenated once at the end
    data_chunks = []
    label_chunks = []

    print("LABEL\tTOTAL\tREAD\t<1s COUNT")
    print("-----\t-----\t----\t---------")
    for i, label in enumerate(LABELS):
        files = os.listdir(os.path.join(PATH, label))
        waves = [f for f in files if f.endswith('.wav')]
        wave_len = len(waves)
        wave_count[i] = wave_len

        # An empty directory cannot be sampled from; it simply contributes no rows.
        if SAMPLE is not None and wave_len > 0:
            waves = [waves[sample] for sample in np.random.randint(wave_len, size=SAMPLE)]
            wave_len = SAMPLE

        tmp_wavData = np.zeros((wave_len, default_rate))
        tmp_wavLabels = np.full(wave_len, i, dtype=np.float64)

        less_than_1s_count = 0
        for j, wav in enumerate(waves):
            path = os.path.join(PATH, label, wav)
            samples, sample_rate = librosa.load(path, sr=default_rate)

            if len(samples) < default_rate:
                # pad short clips with low-level white noise up to one second
                less_than_1s_count += 1
                white_noise = np.random.normal(0, 0.02, default_rate - len(samples))
                samples = np.concatenate((samples, white_noise), axis=None)

            # Truncate so that a clip longer than one second still fits the row.
            tmp_wavData[j, :] = samples[:default_rate]

            # progress: overwrite the line every ten files, finish it on the last file
            if j+1 == wave_len:
                print("{}\t{}\t{}\t{}".format(label, wave_count[i], j+1, less_than_1s_count), end="\n")
            elif j % 10 == 9:
                print("{}\t{}\t{}\t{}".format(label, wave_count[i], j+1, less_than_1s_count), end="\r")

        loss_count[i] = less_than_1s_count
        data_chunks.append(tmp_wavData)
        label_chunks.append(tmp_wavLabels)

    if data_chunks:
        wavData = np.concatenate(data_chunks, axis=0)
        wavLabels = np.concatenate(label_chunks, axis=None)
    else:
        # no labels at all: return well-formed empty arrays instead of failing
        wavData = np.zeros((0, default_rate))
        wavLabels = np.zeros(0)

    print()
    print("MISSION COMPLETE!!!")

    return wavData, wavLabels, wave_count.astype(np.int32), loss_count.astype(np.int32)

In [4]:
# Sampling rate (Hz) passed to read_wav as the target rate for every clip.
SAMPLE_RATE = 8000

# The phoneme table is expected under ./Phonemes and the audio under ../data/train/audio,
# both resolved relative to the current working directory.
phoneme_path = os.path.join(os.getcwd(), "Phonemes")
train_audio_path = os.path.join(os.path.dirname(os.getcwd()), "data", "train", "audio")

# phonemes.csv supplies the `words` column (used here as label directory names)
# and the `phonemes` column (used by a later cell).
phoneme_dataframe = pd.read_csv(os.path.join(phoneme_path, "phonemes.csv"))
# Read 100 clips per word; read_wav draws them with np.random.randint, i.e. with replacement.
wav_array, label_array, total, loss = read_wav(PATH=train_audio_path, 
                                               LABELS=phoneme_dataframe.words, 
                                               SAMPLE=100,
                                               default_rate=SAMPLE_RATE)

LABEL	TOTAL	READ	<1s COUNT
-----	-----	----	---------
zero	2376	100	8
one	2370	100	9
two	2373	100	4
three	2356	100	6
four	2372	100	6
five	2357	100	10
six	2369	100	3
seven	2377	100	8
eight	2352	100	10
nine	2364	100	8

MISSION COMPLETE!!!


In [5]:
# Sanity check on the loaded waveforms: count the NaN and Inf entries in wav_array.
print("Check if there is any NaN or Inf number exist so that we can avoid problems while training")
print(f"NaN Number: {np.sum(np.isnan(wav_array))}")
print(f"Inf Number: {np.sum(np.isinf(wav_array))}")

Check if there is any NaN or Inf number exist so that we can avoid problems while training
NaN Number: 0
Inf Number: 0


In [6]:
def create_dataset(waves, create_size, labels=None, phonemes=None, min_sz=6, max_sz=10, padding=True):
    """Simulate utterances by concatenating randomly chosen clips.

    Parameters
    ----------
    waves : array-like, shape (n_clips, clip_len)
        Pool of fixed-length clips to draw from (with replacement).
    create_size : int
        Number of simulated utterances to build.
    labels : array-like, optional
        ``labels[i]`` is the class id of ``waves[i]``.
    phonemes : array-like of str, optional
        ``phonemes[k]`` is the phoneme string of class id ``k``. Requires
        ``labels``.
    min_sz, max_sz : int
        Inclusive range for the number of clips per utterance.
    padding : bool
        When True, every utterance is extended with Gaussian noise
        (sigma 0.02) to the length of ``max_sz`` clips, so all utterances have
        the same length.

    Returns
    -------
    wav_simu, label_simu, phone_simu : np.ndarray of dtype object, shape (create_size,)
        Per-utterance waveform, list of class ids, and
        ``"<start> ... <end>"`` phoneme string. Entries that were not requested
        (no ``labels`` / no ``phonemes``) are left as ``0``.

    Raises
    ------
    ValueError
        If ``phonemes`` is given without ``labels``.
    """
    if phonemes is not None and labels is None:
        # the phoneme string is looked up through the class ids, so labels are mandatory
        raise ValueError("`labels` must be provided when `phonemes` is given")

    waves = np.asarray(waves)
    bind_size = np.random.randint(low=min_sz, high=max_sz+1, size=create_size)

    # `np.object` was removed from NumPy; the builtin `object` is the supported spelling.
    wav_simu = np.zeros(create_size, dtype=object)
    phone_simu = np.zeros(create_size, dtype=object)
    label_simu = np.zeros(create_size, dtype=object)

    for count, b_sz in enumerate(bind_size):
        index = np.random.randint(len(waves), size=b_sz)

        # concatenate the chosen clips back to back
        utterance = waves[index].reshape(-1)
        if padding:
            # fill the remaining (max_sz - b_sz) clip slots with white noise
            pad_sz = (max_sz - b_sz)*waves.shape[1]
            white_noise = np.random.normal(0, 0.02, size=pad_sz)
            utterance = np.concatenate((utterance, white_noise), axis=None)
        wav_simu[count] = utterance

        if labels is not None:
            label_simu[count] = [int(labels[i]) for i in index]

        if phonemes is not None:
            phone_simu[count] = "<start> " + " ".join([phonemes[i] for i in label_simu[count]]) + " <end>"

        if count % 10 == 9:
            print(f"Simulating {count+1}th wave ", end="\r")
    print("\n\nSIMULATION COMPLETE!!!")

    # wav_simu is an object array of shape (create_size,), not a 2-D array, so it
    # cannot be passed to tf.convert_to_tensor directly; this keeps variable-length
    # utterances possible when padding is disabled.
    return wav_simu, label_simu, phone_simu


def tokenize(phone):
    """Fit a Keras tokenizer on ``phone`` and encode it as a padded id matrix.

    No characters are filtered, so markers such as ``<start>`` and ``<end>``
    survive as tokens. Sequences are zero-padded at the end ('post').

    Returns the padded integer array and the fitted tokenizer.
    """
    preprocessing = tf.keras.preprocessing

    tokenizer = preprocessing.text.Tokenizer(filters='')
    tokenizer.fit_on_texts(phone)

    sequences = tokenizer.texts_to_sequences(phone)
    padded = preprocessing.sequence.pad_sequences(sequences, padding='post')
    return padded, tokenizer


def convert(tokenizer, tensor):
    """Print a token-id to word table for one encoded sequence.

    Ids equal to 0 (the padding value) are skipped; every other id is looked
    up in ``tokenizer.index_word``.
    """
    print("\nTOKEN\t--->\tWORDS")
    print("-----------------------")
    for token_id in tensor:
        if token_id == 0:
            continue
        word = tokenizer.index_word[token_id]
        print("{}\t--->\t{}".format(token_id, word))

In [7]:
# Number of simulated utterances to build.
CREATE_SIZE = 1000
# Phoneme strings from the `phonemes` column of the phoneme table.
phonemes_array = phoneme_dataframe.phonemes.values
# Each utterance concatenates 6 to 8 randomly drawn clips; create_dataset's
# default padding=True noise-pads every utterance to the length of max_sz clips.
simu_wave, simu_label, simu_phoneme = create_dataset(waves=wav_array, 
                                                     create_size=CREATE_SIZE, 
                                                     labels=label_array, 
                                                     phonemes=phonemes_array, 
                                                     min_sz=6, 
                                                     max_sz=8)
# create_dataset returns an object array of per-utterance arrays; stack them into one 2-D array.
simu_wave = np.vstack(simu_wave)

print(f"\nExample Phoneme Display: {simu_phoneme[0]}")

Simulating 1000th wave 

SIMULATION COMPLETE!!!

Example Phoneme Display: <start> F AO R Z IY R OW S IH K S F AY V Z IY R OW TH R IY F AO R <end>


In [8]:
class MFCC:
    """Frame-level log mel filter-bank features with one extra energy column.

    The pipeline is: pre-emphasis -> framing -> Hamming window -> power
    spectrum -> triangular mel filter bank (in dB) -> energy term. No DCT is
    applied, so despite the class name the output is filter-bank energies
    rather than cepstral coefficients.

    Parameters
    ----------
    alpha : float
        Pre-emphasis coefficient.
    frame_size, frame_stride : float
        Frame length and hop, in seconds.
    n_fft : int
        FFT length used for the power spectrum.
    n_filter : int
        Number of triangular mel filters.
    """

    def __init__(self, alpha, frame_size, frame_stride, n_fft, n_filter):
        self.alpha = alpha
        self.frame_size = frame_size
        self.frame_stride = frame_stride
        self.n_fft = n_fft
        self.n_filter = n_filter

    def mfcc(self, samples, sample_rate):
        """Return an array of shape (n_frames, n_filter + 1).

        Column 0 is the value from ``log_energy``; the remaining columns are
        the mel filter-bank outputs in dB.
        """
        samples_emphasized = self.pre_emphasis(samples)
        frames, total_samples_in_frame = self.framing(samples_emphasized, sample_rate)
        frames = self.hamming_window(frames, total_samples_in_frame)
        power_spectrum = self.stft(frames)
        fbank = self.filter_bank(power_spectrum, sample_rate)
        energy = self.log_energy(fbank)

        return np.column_stack((energy, fbank))

    def pre_emphasis(self, samples):
        """Apply y[0] = x[0], y[t] = x[t] - alpha * x[t-1]."""
        samples = np.asarray(samples)
        return np.append(samples[0], samples[1:] - self.alpha*samples[:-1])

    def framing(self, samples, sample_rate):
        """Cut ``samples`` into overlapping frames.

        The signal is zero-padded at the end so that the last frame is full.
        A signal shorter than one frame yields a single zero-padded frame.

        Returns ``(frames, samples_in_frame)`` where ``frames`` is a 2-D array
        of shape (n_frames, samples_in_frame).
        """
        samples = np.asarray(samples)
        samples_in_frame = int(np.ceil(self.frame_size*sample_rate))       # samples per frame
        sample_stride = int(np.ceil(self.frame_stride*sample_rate))        # hop between frame starts

        # At least one frame, so a short signal cannot produce a zero or
        # negative frame count (and thus a negative padding length).
        frame_num = max(1, int(np.ceil((len(samples) - samples_in_frame)/sample_stride)) + 1)

        padding_num = (frame_num-1)*sample_stride + samples_in_frame - len(samples)
        samples_padded = np.append(samples, np.zeros(padding_num))

        # Exact integer frame starts (no float linspace that is then truncated),
        # broadcast against the in-frame offsets to get an (n_frames, samples_in_frame)
        # index matrix.
        frame_starts = sample_stride * np.arange(frame_num)
        index = frame_starts[:, None] + np.arange(samples_in_frame)[None, :]

        # Plain integer-array indexing keeps the result 2-D on every NumPy
        # version; indexing with a list that wraps an array does not.
        return samples_padded[index], samples_in_frame

    def hamming_window(self, frames, samples_in_frame):
        """Return the frames multiplied by a Hamming window (input left untouched)."""
        return frames * np.hamming(samples_in_frame)

    def stft(self, frames):
        """Return the per-frame power spectrum, shape (n_frames, n_fft//2 + 1)."""
        magnitude = np.abs(np.fft.rfft(frames, n=self.n_fft))              # magnitude of the FFT
        return (1.0/self.n_fft) * magnitude**2                             # power spectrum

    def filter_bank(self, frames, sample_rate):
        """Apply ``n_filter`` triangular mel filters to a power spectrum; result in dB."""
        low_freq_mel = 0
        high_freq_mel = self.hz2mel(sample_rate/2)                         # Nyquist frequency in mel
        mel_points = np.linspace(low_freq_mel, high_freq_mel, self.n_filter+2)   # equally spaced in mel
        bins = np.floor((self.n_fft+1) * self.mel2hz(mel_points) / sample_rate)  # FFT bin of each point

        fbank = np.zeros((self.n_filter, self.n_fft//2 + 1))
        for j in range(self.n_filter):
            left, centre, right = bins[j], bins[j+1], bins[j+2]
            # rising edge of the triangle
            for i in range(int(left), int(centre)):
                fbank[j, i] = (i - left) / (centre - left)
            # falling edge of the triangle
            for i in range(int(centre), int(right)):
                fbank[j, i] = (right - i) / (right - centre)

        mel_fbanks = np.dot(frames, fbank.T)
        # avoid log10(0) for silent bands
        mel_fbanks = np.where(mel_fbanks == 0, np.finfo(float).eps, mel_fbanks)
        mel_fbanks = 20 * np.log10(mel_fbanks)                             # dB

        return mel_fbanks

    def log_energy(self, mel_fbanks):
        """Return log(sum of squares) per frame.

        NOTE(review): the caller passes the dB-scaled filter-bank values, so
        this is the log of the summed squared dB values, not the log of the
        frame's signal energy. Kept as is so existing features do not change.
        """
        return np.log(np.sum(mel_fbanks**2, axis=1))

    def hz2mel(self, hz):
        return 2595 * np.log10(1 + hz/700)  # Convert Hz to Mel

    def mel2hz(self, mel):
        return 700 * (10**(mel/2595.0) - 1) # Convert Mel to Hz

In [9]:
# Quick smoke test: run the extractor on the first simulated utterance.
# NOTE(review): the following cell builds `mfcc` and `sample_mfcc` again from the same settings.
mfcc = MFCC(alpha=0.95, frame_size=0.025, frame_stride=0.01, n_fft=512, n_filter=12)
sample_mfcc = mfcc.mfcc(samples=simu_wave[0], sample_rate=SAMPLE_RATE)

In [10]:
# Feature-extraction settings passed to MFCC below.
ALPHA = 0.95
FRAME_SIZE = 0.025
FRAME_STRIDE = 0.01
N_FFT = 512
N_FILTER = 12

mfcc = MFCC(alpha=ALPHA, frame_size=FRAME_SIZE, frame_stride=FRAME_STRIDE, n_fft=N_FFT, n_filter=N_FILTER)
# Features of the first utterance, flattened row by row; its length sizes the padded matrix below.
sample_mfcc = mfcc.mfcc(simu_wave[0], SAMPLE_RATE).flatten(order="C")


# adjust size so that it can fit in pBLSTM model:
# round the flattened length up to the next multiple of (N_FILTER + 1) * 256
divider = (N_FILTER + 1)*256
new_size = int(((len(sample_mfcc)//divider + 1) * divider))
simu_wave_mfcc = np.zeros((CREATE_SIZE, new_size))

# The same zero tail is appended to every row; it is sized from the first
# utterance, so every utterance must flatten to len(sample_mfcc) values.
zero_padding = np.zeros(shape=(new_size - len(sample_mfcc)))
for i in np.arange(CREATE_SIZE):
    simu_wave_mfcc[i, :] = np.concatenate((mfcc.mfcc(simu_wave[i], 
                                                     SAMPLE_RATE).flatten(order="C"), 
                                           zero_padding))
    print(f"{i+1}th iteration", end="\r")

print("\n\nMISSION COMPELTE!!!")

1000th iteration

MISSION COMPELTE!!!


In [11]:
# Encode the phoneme strings as zero-padded id sequences and convert the features to a float32 tensor.
phoneme_tensor, phoneme_tokenizer = tokenize(simu_phoneme)
wav_tensor = tf.convert_to_tensor(simu_wave_mfcc, dtype=tf.float32)

print("Output Shape: {}".format(phoneme_tensor.shape))
print("Input Shape: {}".format(wav_tensor.shape))

# Show the id -> token mapping for the first encoded sequence only.
for tensor in phoneme_tensor[:1]:
    convert(phoneme_tokenizer, tensor)
    print()

Output Shape: (1000, 34)
Input Shape: (1000, 13312)

TOKEN	--->	WORDS
-----------------------
10	--->	<start>
5	--->	f
19	--->	ao
3	--->	r
15	--->	z
6	--->	iy
3	--->	r
16	--->	ow
2	--->	s
13	--->	ih
14	--->	k
2	--->	s
5	--->	f
7	--->	ay
4	--->	v
15	--->	z
6	--->	iy
3	--->	r
16	--->	ow
17	--->	th
3	--->	r
6	--->	iy
5	--->	f
19	--->	ao
3	--->	r
11	--->	<end>



In [12]:
# Model and training hyper-parameters.
BATCH_SIZE = 1
LSTM_UNITS = 256
# The encoder keeps halving the time axis while it is longer than this value.
FINAL_TIMESTEP = 64
EMBEDDING_DIM = 128
WAV_SIZE = len(wav_tensor)
# +1 because id 0 is used for padding and is not in word_index.
PHONEME_SIZE = len(phoneme_tokenizer.word_index) + 1
STEP_PER_EPOCH = WAV_SIZE // BATCH_SIZE

# Shuffle over the whole dataset, then batch; incomplete final batches are dropped.
dataset = tf.data.Dataset.from_tensor_slices((wav_tensor, phoneme_tensor)).shuffle(WAV_SIZE)
dataset = dataset.batch(BATCH_SIZE, drop_remainder=True)

In [13]:
# Pull one batch to inspect shapes and to drive the demo calls in later cells.
example_input_batch, example_target_batch = next(iter(dataset))

print(f"Original Input Shape: {example_input_batch.shape}")
print(f"Original Output Shape: {example_target_batch.shape}")

# Add a trailing channel axis of size 1, the layout fed to the encoder's Conv1D.
example_input_batch = tf.expand_dims(example_input_batch, 2)
print(f"Reshaped Input Shape: {example_input_batch.shape}")

Original Input Shape: (1, 13312)
Original Output Shape: (1, 34)
Reshaped Input Shape: (1, 13312, 1)


In [25]:
class Encoder(tf.keras.Model):
    """Convolutional front end followed by a pyramidal bidirectional LSTM stack.

    The Conv1D layer uses kernel size and stride ``mfcc_dims``, so it consumes
    one flattened feature frame per output step. Pyramid levels are then
    applied until the time axis is no longer than ``final_units``; each level
    runs a forward and a backward LSTM, concatenates their outputs and merges
    every two consecutive time steps into one.

    Parameters
    ----------
    lstm_units : int
        Units of every forward/backward LSTM.
    final_units : int
        Pyramid levels are added while the time length exceeds this value.
    batch_sz : int
        Stored for reference; not used by the computation.
    conv_filters : int
        Number of Conv1D filters.
    mfcc_dims : int
        Feature values per frame (Conv1D kernel size and stride).
    """

    def __init__(self, lstm_units, final_units, batch_sz, conv_filters, mfcc_dims):
        super(Encoder, self).__init__()
        self.lstm_units = lstm_units
        self.final_units = final_units
        self.batch_sz = batch_sz
        self.conv_filters = conv_filters
        self.mfcc_dims = mfcc_dims

        # Convolution layer to extract feature after MFCC
        self.conv_feat = tf.keras.layers.Conv1D(filters=self.conv_filters,
                                                kernel_size=self.mfcc_dims,
                                                padding='valid',
                                                activation='relu',
                                                strides=self.mfcc_dims)

        # One (forward, backward) LSTM pair per pyramid level. The number of
        # levels depends on the input length, so the layers are created the
        # first time a level is needed and reused on every later call.
        # Building new layers inside call() on each invocation would give
        # fresh random weights per forward pass that are never trained or
        # checkpointed.
        self.fw_lstms = []
        self.bw_lstms = []

    def call(self, x):
        '''
        Run the pyramidal encoder.

        Returns (outputs, (fw_state_h, fw_state_c), (bw_state_h, bw_state_c)),
        the states being those of the last pyramid level.
        Raises ValueError when the input is already short enough that no
        pyramid level runs (there would be no LSTM state to return).
        '''
        # Convolution Feature Extraction
        x = self.conv_feat(x)

        # each level is initialised with the final states of the level below
        initial_state_fw = None
        initial_state_bw = None
        fw_state_h = fw_state_c = bw_state_h = bw_state_c = None

        level = 0
        while x.shape[1] > self.final_units:
            if level == len(self.fw_lstms):
                self.fw_lstms.append(self.build_lstm(False))
                self.bw_lstms.append(self.build_lstm(True))

            # forward LSTM (reads the sequence left to right)
            fw_output, fw_state_h, fw_state_c = self.fw_lstms[level](x, initial_state=initial_state_fw)

            # backward LSTM (reads right to left); Keras returns its sequence
            # reversed, so flip it back to line up with the forward outputs
            bw_output, bw_state_h, bw_state_c = self.bw_lstms[level](x, initial_state=initial_state_bw)
            bw_output = tf.reverse(bw_output, axis=[1])

            x = tf.concat([fw_output, bw_output], -1)
            x = self.reshape_pyramidal(x)

            initial_state_fw = [fw_state_h, fw_state_c]
            initial_state_bw = [bw_state_h, bw_state_c]
            level += 1

        if level == 0:
            raise ValueError(
                "Encoder input has {} time steps after the convolution, which is not more than "
                "final_units={}; no LSTM level was applied.".format(x.shape[1], self.final_units))

        return x, (fw_state_h, fw_state_c), (bw_state_h, bw_state_c)

    def build_lstm(self, back=True):
        '''
        Build one LSTM layer; ``back=True`` makes it process the sequence backwards.
        '''
        return tf.keras.layers.LSTM(units=self.lstm_units,
                                    return_sequences=True,
                                    return_state=True,
                                    go_backwards=back)

    def reshape_pyramidal(self, outputs):
        '''
        Merge every two consecutive time steps into one, halving the time
        axis and doubling the feature axis. A trailing odd step is dropped.
        '''
        _, time_steps, num_units = outputs.shape
        if time_steps % 2:
            outputs = outputs[:, :-1, :]
            time_steps -= 1

        # -1 on the batch axis keeps this valid when the batch size is not static
        return tf.reshape(outputs, (-1, time_steps // 2, num_units * 2))

In [26]:
# Build the encoder and run it once on the example batch to check output shapes.
encoder = Encoder(lstm_units=LSTM_UNITS, 
                  final_units=FINAL_TIMESTEP, 
                  batch_sz=BATCH_SIZE, 
                  conv_filters=32, 
                  mfcc_dims=N_FILTER+1)

# If set the batch size greater than 4, memory of GPU will run out
sample_output, (fw_sample_state_h, fw_sample_state_c), bw_sample_state = encoder(example_input_batch)
print ('Encoder output shape: (batch size, sequence length, units) {}'.format(sample_output.shape))
print ('Encoder forward state h shape: (batch size, units) {}'.format(fw_sample_state_h.shape))
# report the cell state itself (this line previously printed the hidden state's shape again)
print ('Encoder forward state c shape: (batch size, units) {}'.format(fw_sample_state_c.shape))
print ('Encoder backward state h shape: (batch size, units) {}'.format(bw_sample_state[0].shape))

Encoder output shape: (batch size, sequence length, units) (1, 64, 1024)
Encoder forward state h shape: (batch size, units) (1, 256)
Encoder forward state c shape: (batch size, units) (1, 256)
Encoder backward state h shape: (batch size, units) (1, 256)


In [27]:
class BahdanauAttention(tf.keras.layers.Layer):
    """Additive attention: score = V(tanh(W1(query) + W2(values)))."""

    def __init__(self, units):
        super(BahdanauAttention, self).__init__()
        self.W1 = tf.keras.layers.Dense(units)
        self.W2 = tf.keras.layers.Dense(units)
        self.V = tf.keras.layers.Dense(1)

    def call(self, query, values):
        """Return ``(context_vector, attention_weights)``.

        ``query`` gets a time axis inserted at position 1 so that it
        broadcasts against ``values`` along the time axis. The weights are a
        softmax over axis 1 and the context vector is the weighted sum of
        ``values`` over that same axis.
        """
        # insert a length-1 time axis so the projected query broadcasts over every step
        expanded_query = tf.expand_dims(query, 1)

        # additive combination, then squash to one score per time step (last axis of size 1)
        combined = self.W1(expanded_query) + self.W2(values)
        score = self.V(tf.nn.tanh(combined))

        # normalise the scores across the time axis
        attention_weights = tf.nn.softmax(score, axis=1)

        # weighted sum of the values over the time axis
        context_vector = tf.reduce_sum(attention_weights * values, axis=1)

        return context_vector, attention_weights

In [28]:
# Shape check: attend over the encoder output using the forward hidden state as the query.
attention_layer = BahdanauAttention(10)
attention_result, attention_weights = attention_layer(fw_sample_state_h, sample_output)

print("Attention result shape: (batch size, units) {}".format(attention_result.shape))
print("Attention weights shape: (batch_size, sequence_length, 1) {}".format(attention_weights.shape))

Attention result shape: (batch size, units) (1, 1024)
Attention weights shape: (batch_size, sequence_length, 1) (1, 64, 1)


In [29]:
class Decoder(tf.keras.Model):
    """Single-step attention decoder: embedding -> attention context -> LSTM -> logits.

    Parameters
    ----------
    target_sz : int
        Vocabulary size (embedding rows and number of output logits).
    embedding_dim : int
        Size of the token embedding.
    decoder_units : int
        Units of the decoder LSTM and of the attention projections.
    batch_sz : int
        Stored for reference; not used by the computation.
    """

    def __init__(self, target_sz, embedding_dim, decoder_units, batch_sz):
        super(Decoder, self).__init__()
        self.batch_sz = batch_sz
        self.decoder_units = decoder_units
        self.embedding = tf.keras.layers.Embedding(target_sz, embedding_dim)
        self.attention = BahdanauAttention(self.decoder_units)
        self.lstm = tf.keras.layers.LSTM(units=self.decoder_units, return_sequences=True, return_state=True)
        self.fc = tf.keras.layers.Dense(target_sz)


    def call(self, inputs, enc_hidden_h, enc_hidden_c, enc_output):
        '''
        Decode one step.

        inputs        -- token ids for this step
        enc_hidden_h  -- previous hidden state; also the attention query
        enc_hidden_c  -- previous cell state
        enc_output    -- encoder outputs attended over

        Returns (logits, (state_h, state_c), attention_weights).
        '''
        # the previous hidden state queries the encoder outputs
        context_vector, attention_weights = self.attention(enc_hidden_h, enc_output)

        # embed the current input tokens
        x = self.embedding(inputs)

        # prepend the context vector to the embedding along the feature axis
        x = tf.concat([tf.expand_dims(context_vector, 1), x], axis=-1)

        # Carry the recurrent state across decoding steps. Without an initial
        # state the LSTM restarts from zeros on every call, so the returned
        # states are never actually used. The states are only passed when
        # their width matches this LSTM; otherwise the previous zero-state
        # behaviour is kept.
        initial_state = None
        if (enc_hidden_h is not None and enc_hidden_c is not None
                and enc_hidden_h.shape[-1] == self.decoder_units
                and enc_hidden_c.shape[-1] == self.decoder_units):
            initial_state = [enc_hidden_h, enc_hidden_c]

        output, state_h, state_c = self.lstm(x, initial_state=initial_state)

        # collapse the time axis: one row of features per (batch, step)
        output = tf.reshape(output, (-1, output.shape[-1]))

        # project to vocabulary logits
        x = self.fc(output)

        return x, (state_h, state_c), attention_weights

In [30]:
# Build the decoder and run a single step to check the output shape.
decoder = Decoder(target_sz=PHONEME_SIZE, 
                  embedding_dim=EMBEDDING_DIM, 
                  decoder_units=LSTM_UNITS, 
                  batch_sz=BATCH_SIZE)

# NOTE(review): despite its name this is a (BATCH_SIZE, 1) tensor of random floats
# used as a stand-in decoder input, not a size and not real token ids.
sample_target_size = tf.random.uniform((BATCH_SIZE, 1))
sample_decoder_output, sample_decoder_hidden, attention_weights = decoder(
    inputs=sample_target_size, 
    enc_hidden_h=fw_sample_state_h, 
    enc_hidden_c=fw_sample_state_c, 
    enc_output=sample_output)

print ('Decoder output shape: (batch_size, vocab size) {}'.format(sample_decoder_output.shape))

Decoder output shape: (batch_size, vocab size) (1, 22)


In [31]:
# Adam with its default settings.
optimizer = tf.keras.optimizers.Adam()
# The decoder outputs raw logits; reduction='none' keeps one loss value per
# example so that loss_function can mask padded positions before averaging.
loss_object = tf.keras.losses.SparseCategoricalCrossentropy(
    from_logits=True, reduction='none')

def loss_function(real, pred):
    """Mean cross-entropy over the batch, with padded targets (id 0) contributing zero."""
    per_example = loss_object(real, pred)

    # 1.0 where the target is a real token, 0.0 where it is padding
    keep = tf.cast(tf.math.not_equal(real, 0), dtype=per_example.dtype)

    return tf.reduce_mean(per_example * keep)

In [32]:
# Checkpoints are written under ./training_checkpoints with the file prefix "ckpt";
# the checkpoint object tracks the optimizer together with both models.
checkpoint_dir = './training_checkpoints'
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt")
checkpoint = tf.train.Checkpoint(optimizer=optimizer,
                                 encoder=encoder,
                                 decoder=decoder)

In [33]:
@tf.function
def train_step(inp, targ, targ_tokenizer, enc_hidden):
    """Run one teacher-forced optimisation step and return the batch loss.

    ``inp`` is the encoder input batch, ``targ`` the padded target ids and
    ``targ_tokenizer`` supplies the id of ``<start>``. ``enc_hidden`` is
    accepted but not used. The returned value is the summed step loss divided
    by the target length; gradients are taken from the summed loss.
    """
    step_loss_sum = 0

    with tf.GradientTape() as tape:
        # encode, then seed the decoder with the encoder's forward states
        enc_output, (state_h, state_c), _ = encoder(inp)

        start_id = targ_tokenizer.word_index['<start>']
        dec_input = tf.expand_dims([start_id] * BATCH_SIZE, 1)

        # teacher forcing: the ground-truth token at step t is the input for step t + 1
        for t in range(1, targ.shape[1]):
            expected = targ[:, t]
            predictions, (state_h, state_c), _ = decoder(dec_input, state_h, state_c, enc_output)
            step_loss_sum += loss_function(expected, predictions)
            dec_input = tf.expand_dims(expected, 1)

    batch_loss = (step_loss_sum / int(targ.shape[1]))

    # update both models from the summed loss
    variables = encoder.trainable_variables + decoder.trainable_variables
    gradients = tape.gradient(step_loss_sum, variables)
    optimizer.apply_gradients(zip(gradients, variables))

    return batch_loss

In [None]:
# Force tf.function-decorated code to run eagerly, step by step. Used here as a
# workaround for:
# ValueError: tf.function-decorated function tried to create variables on non-first call.
tf.config.experimental_run_functions_eagerly(True)

EPOCHS = 10
for epoch in range(EPOCHS):
    start = time.time()

    # enc_hidden = encoder.initialize_hidden_state()
    # train_step accepts this argument but does not use it
    enc_hidden = None
    total_loss = 0

    for (batch, (inp, targ)) in enumerate(dataset.take(STEP_PER_EPOCH)):
        # add the trailing channel axis expected by the encoder's Conv1D
        inp = tf.expand_dims(inp, 2)
        batch_loss = train_step(inp, targ, phoneme_tokenizer, enc_hidden)
        
        total_loss += batch_loss
        
        # report progress every 100 batches
        if batch % 100 == 0:
            print('Epoch {} Batch {} Loss {:.4f}'.format(epoch + 1, batch, batch_loss.numpy()))

    # saving (checkpoint) the model every epoch
    checkpoint.save(file_prefix = checkpoint_prefix)

    print('Epoch {} Loss {:.4f}'.format(epoch + 1, total_loss / STEP_PER_EPOCH))
    print('Time taken for 1 epoch {} sec\n'.format(time.time() - start))

Epoch 1 Batch 0 Loss 2.0812


In [None]:
def predict(wave, max_input_len, max_output_len, tokenizer=None):
    """Greedy-decode one flattened feature vector into a phoneme string.

    Parameters
    ----------
    wave : 1-D array-like
        Flattened feature vector in the same layout as one training row.
    max_input_len : int
        Number of attention columns to record per output step.
    max_output_len : int
        Maximum number of decoding steps.
    tokenizer :
        Fitted tokenizer providing ``word_index`` and ``index_word``.

    Returns
    -------
    (result, wave, attention_plot)
        ``result`` is the space-separated prediction (each token followed by a
        space), ``wave`` is the input unchanged, and ``attention_plot`` has
        shape (max_output_len, max_input_len).

    Raises
    ------
    ValueError
        If no tokenizer is supplied.
    """
    if tokenizer is None:
        raise ValueError("predict() needs the fitted tokenizer to map ids to tokens")

    attention_plot = np.zeros((max_output_len, max_input_len))

    # (features,) -> (1, features, 1): batch axis in front, channel axis at the end
    inputs = tf.convert_to_tensor(wave, dtype=tf.float32)
    inputs = tf.expand_dims(inputs, 0)
    inputs = tf.expand_dims(inputs, 2)
    result = ''

    # The encoder takes only the input and returns
    # (outputs, forward states, backward states); decoding starts from the forward states.
    enc_out, (dec_hidden_h, dec_hidden_c), _ = encoder(inputs)

    dec_input = tf.expand_dims([tokenizer.word_index['<start>']], 0)

    for t in range(max_output_len):
        predictions, (dec_hidden_h, dec_hidden_c), attention_weights = decoder(dec_input,
                                                                               dec_hidden_h,
                                                                               dec_hidden_c,
                                                                               enc_out)

        # Store the attention weights for plotting. Copy only as many columns
        # as both sides have, so a max_input_len that differs from the encoder
        # length cannot raise a shape error.
        weights = tf.reshape(attention_weights, (-1, )).numpy()
        width = min(max_input_len, len(weights))
        attention_plot[t, :width] = weights[:width]

        predicted_id = int(tf.argmax(predictions[0]).numpy())

        # id 0 is the padding value and has no entry in index_word; treat it as end of output
        if predicted_id == 0:
            break

        token = tokenizer.index_word[predicted_id]
        result += token + ' '

        if token == '<end>':
            return result, wave, attention_plot

        # the predicted ID is fed back into the model
        dec_input = tf.expand_dims([predicted_id], 0)

    return result, wave, attention_plot

In [None]:
# function for plotting the attention weights
def plot_attention(attention, input_wav, output_phoneme):
    """Show the attention matrix as an image with one y tick label per output token.

    ``output_phoneme`` must be a list of strings. ``input_wav`` is accepted but
    not used for labelling.
    """
    figure = plt.figure(figsize=(10, 10))
    axis = figure.add_subplot(111)
    axis.matshow(attention, cmap='viridis')

    # an empty leading label, then one label per output token
    y_labels = [''] + output_phoneme
    axis.set_yticklabels(y_labels, fontdict={'fontsize': 14})

    # one tick per row
    axis.yaxis.set_major_locator(ticker.MultipleLocator(1))

    plt.show()

In [None]:
def translate(wave, max_in, max_out, tokenizer):
    """Decode ``wave`` with ``predict``, print the result and plot its attention map.

    ``max_in`` and ``max_out`` are forwarded to ``predict`` as the number of
    attention columns and the maximum number of decoding steps.
    """
    result, _, attention_plot = predict(wave, max_in, max_out, tokenizer)

    print(f'Original Input Length: {len(wave)}')
    print(f'Predicted translation: {result}')

    # `result` ends with a space, so split(' ') would add an empty trailing
    # token (one extra row and a blank label); split() yields the tokens only.
    tokens = result.split()

    # keep one row per emitted token and every recorded input column
    attention_plot = attention_plot[:len(tokens), :max_in]
    plot_attention(attention_plot, np.arange(len(wave)), tokens)

In [None]:
# Restore the most recent checkpoint found in checkpoint_dir.
# NOTE(review): tf.train.latest_checkpoint returns None when the directory holds no checkpoint — confirm one exists before relying on the restored weights.
checkpoint.restore(tf.train.latest_checkpoint(checkpoint_dir))

In [None]:
# testing, testing_label, total, loss = read_wav(os.getcwd())
# Build a single simulated test utterance from the same clip pool used for training.
# NOTE(review): this relies on create_dataset's defaults (min_sz=6, max_sz=10), whereas the
# training set was built with max_sz=8 — confirm the longer utterances are intended.
test_wave, test_label, test_phoneme = create_dataset(wav_array, 1, label_array, phonemes_array)

In [None]:
# The model consumes flattened, zero-padded MFCC features (one row of
# simu_wave_mfcc), not the raw waveform held in the object array test_wave,
# so run the test utterance through the same feature pipeline first.
test_features = mfcc.mfcc(test_wave[0], SAMPLE_RATE).flatten(order="C")[:new_size]
test_features = np.concatenate((test_features, np.zeros(new_size - len(test_features))))

# Attention width is the encoder's final time length; the decoding limit is the
# padded target length used in training (not the vocabulary size).
translate(test_features, FINAL_TIMESTEP, phoneme_tensor.shape[1], phoneme_tokenizer)