# Train speech identity verification model

Having prepared the tfrecords file, we are ready to train and evaluate some models.

We use triplet loss, described at https://www.tensorflow.org/addons/tutorials/losses_triplet at the time this was written.

Note that triplet loss requires both positive and negative examples to be present in each batch, i.e., audio clips from the same person and from a different person. When samples are shuffled randomly, larger batches increase the probability that this holds, so we'd like the batches to be as large as we can afford. Another option (not covered here) would be to construct the batches explicitly so that every batch is valid.

In [None]:
import os

import tensorflow as tf
import tensorflow_io as tfio
import tensorflow_addons as tfa
from tensorflow.keras import Model
from tensorflow.keras import layers as L

%load_ext autoreload
%autoreload 2
from create_audio_tfrecords import AudioTarReader, PersonIdAudio

In [None]:
# Collect the path of every training shard found under data/
train_files = [
    os.path.join('data', name)
    for name in os.listdir('data')
    if name.endswith('train.tfrecords.gzip')
]

# Sanity check that the tfrecords files can be read back.
# The shards are GZIP-compressed, and each serialized record is decoded by
# mapping the deserialization function over the dataset.
tfrecords_audio_dataset = tf.data.TFRecordDataset(
    train_files,
    compression_type='GZIP',
    num_parallel_reads=4,
)
tfrecords_audio_dataset = tfrecords_audio_dataset.map(
    PersonIdAudio.deserialize_from_tfrecords
)

In [None]:
# count number of records
# (this iterates over the whole dataset once; the count is used later to
# derive steps_per_epoch for training)
n_train_samples = sum(1 for _ in tfrecords_audio_dataset)
print(n_train_samples)

## Model definition

In [None]:
# Number of mel bands of the spectrogram features; also the size of the last
# axis of the model input
n_mel_bins = 80

def normalized_mel_spectrogram(x, sr=48000):
    """Convert audio samples into a normalized log-mel spectrogram.

    Args:
        x: audio samples to transform (assumed to be a mono 1-D waveform such
            as ``tfio.audio.decode_mp3(data)[:, 0]`` -- TODO confirm for
            other callers).
        sr: sample rate of ``x`` in Hz, used to build the mel filterbank.

    Returns:
        The log-mel spectrogram, with ``n_mel_bins`` values on its last axis,
        shifted and scaled to zero mean and unit standard deviation computed
        over the whole spectrogram. When the standard deviation is exactly
        zero (a constant spectrogram, e.g. digital silence) the result is all
        zeros instead of NaN.
    """
    spec_stride = 256
    spec_len = 1024

    # spectrogram with FFT/window size spec_len and hop size spec_stride
    spectrogram = tfio.audio.spectrogram(
        x, nfft=spec_len, window=spec_len, stride=spec_stride
    )

    num_spectrogram_bins = spec_len // 2 + 1  # spectrogram.shape[-1]
    lower_edge_hertz, upper_edge_hertz, num_mel_bins = 80.0, 10000.0, n_mel_bins

    # Project the linear-frequency bins onto the mel scale.
    linear_to_mel_weight_matrix = tf.signal.linear_to_mel_weight_matrix(
        num_mel_bins, num_spectrogram_bins, sr, lower_edge_hertz,
        upper_edge_hertz)
    mel_spectrograms = tf.tensordot(
        spectrogram, linear_to_mel_weight_matrix, 1)
    # Make the static shape explicit: the spectrogram's leading dimensions
    # followed by the number of mel bins.
    mel_spectrograms.set_shape(spectrogram.shape[:-1].concatenate(
        linear_to_mel_weight_matrix.shape[-1:]))

    # Compute a stabilized log to get log-magnitude mel-scale spectrograms.
    log_mel_spectrograms = tf.math.log(mel_spectrograms + 1e-6)
    avg = tf.math.reduce_mean(log_mel_spectrograms)
    std = tf.math.reduce_std(log_mel_spectrograms)

    # divide_no_nan returns 0 where std == 0, so a constant spectrogram does
    # not turn into NaNs that would propagate into the loss; every other
    # input is normalized exactly as with a plain division.
    return tf.math.divide_no_nan(log_mel_spectrograms - avg, std)


def BaseSpeechEmbeddingModel(inputLength=None, rnn_func=L.LSTM, rnn_units=64):
    """Build the recurrent network that maps a mel spectrogram to an embedding.

    Args:
        inputLength: number of spectrogram frames per example; ``None``
            leaves the time axis unspecified.
        rnn_func: recurrent layer class wrapped by each ``Bidirectional``
            layer.
        rnn_units: number of units of each recurrent layer, and size of the
            output embedding.

    Returns:
        A Keras ``Model`` whose input is named ``'input'`` and has shape
        ``[batch, inputLength, n_mel_bins]``, and whose output layer is named
        ``'output'`` and yields L2-normalized embeddings of shape
        ``[batch, rnn_units]``.
    """
    # The network receives the normalized mel spectrogram directly. The first
    # channel of the decoded mp3 (tfio.audio.decode_mp3(data)[:, 0]) could be
    # fed instead by declaring an (inputLength,) input and wrapping
    # normalized_mel_spectrogram in a Lambda layer.
    # Extra normalization of the spectrogram (BatchNormalization or
    # LayerNormalization) is currently not applied.
    spectrogram_in = L.Input((inputLength, n_mel_bins), name='input')

    # Two stacked bidirectional recurrent layers: the first keeps the whole
    # sequence ([b_s, seq_len, vec_dim]); the second returns a single vector
    # per example.
    features = spectrogram_in
    for keep_sequence in (True, False):
        features = L.Bidirectional(
            rnn_func(rnn_units, return_sequences=keep_sequence)
        )(features)

    # Linear projection: no activation on the final dense layer
    features = L.Dense(rnn_units, activation=None)(features)

    # L2 normalize embeddings
    embedding = L.Lambda(
        lambda z: tf.math.l2_normalize(z, axis=1), name='output'
    )(features)

    return Model(inputs=[spectrogram_in], outputs=[embedding])

In [None]:
# Instantiate the embedding model with its default settings and print the
# layer-by-layer summary
m = BaseSpeechEmbeddingModel()
m.summary()

In [None]:
# samples = [x for x in tfrecords_audio_dataset.take(2)]
# samples[0][0].shape, samples[0][1].shape

## Training

Note: we need to `shuffle -> repeat -> batch` in this order.

In [None]:
# number of examples per training batch
batch_size = 96

# when True, mp3_decode_fn converts the decoded audio into a normalized mel
# spectrogram; when False it returns the decoded samples unchanged
return_mel_spec = True
def mp3_decode_fn(audio_bytes, audio_class):
    """Decode one mp3 record into the features fed to the model.

    Args:
        audio_bytes: encoded mp3 content of a single recording.
        audio_class: label of the recording, passed through untouched.

    Returns:
        ``(audio_data, audio_class)``, where ``audio_data`` is the first
        channel of the decoded audio, converted into a normalized mel
        spectrogram when the module-level ``return_mel_spec`` flag is set.
    """
    # Only the first channel is kept. Whether limiting the output size helps
    # (e.g. slicing with [0:48000 * 4, 0]) is still to be checked.
    first_channel = tfio.audio.decode_mp3(audio_bytes)[:, 0]
    if not return_mel_spec:
        return first_channel, audio_class
    return normalized_mel_spectrogram(first_channel), audio_class

# Training input pipeline, in the required shuffle -> repeat -> batch order.
train_set = tfrecords_audio_dataset.shuffle(
        # Reduce memory usage: shuffle the still-encoded records, so the
        # shuffle buffer holds mp3 bytes rather than decoded spectrograms
        10 * batch_size,
        reshuffle_each_iteration=True
    ).map(
        # Decode (and featurize) the records in parallel
        mp3_decode_fn,
        num_parallel_calls=tf.data.AUTOTUNE
    ).repeat(
    ).padded_batch(
        batch_size,  # batch size
        # Pad the variable-length time axis up to the longest example of the
        # batch. The feature rank follows return_mel_spec:
        # [frames, n_mel_bins] spectrograms, or 1-D decoded samples.
        padded_shapes=(
            [None, n_mel_bins] if return_mel_spec else [None],
            []
        ),
        drop_remainder=True
    ).prefetch(  # Overlap producer and consumer works
        tf.data.AUTOTUNE
    )

In [None]:
# samples = [x for x in train_set.take(2)]
# samples[0][0].shape, samples[0][1].shape

In [None]:
# Train with Adam, using the semi-hard triplet loss on the output embeddings
m.compile(
    optimizer=tf.keras.optimizers.Adam(learning_rate=1e-3),
    loss=tfa.losses.TripletSemiHardLoss(),
)

In [None]:
# train_set repeats indefinitely, so the length of an epoch has to be given
# explicitly: the number of full batches contained in one pass over the data
history = m.fit(
    train_set,
    steps_per_epoch = n_train_samples // batch_size,
    epochs=100)

## Check similarities

In [None]:
from create_audio_tfrecords import AudioTarReader
# path of the dataset tar archive handed to AudioTarReader
audio_tarfile = 'data/cv-corpus-7.0-2021-07-21-pt.tar.gz'

# reader used below to retrieve the per-user audio of the 'dev' split
atr = AudioTarReader(audio_tarfile)

In [None]:
# Per-user audio of the 'dev' split. It is used below as a mapping
# user key -> collection of that user's mp3-encoded clips
# (NOTE(review): exact return type of retrieve_per_user_data to be confirmed).
val_audio_content = atr.retrieve_per_user_data('dev')

In [None]:
from tqdm.notebook import tqdm
def get_embedding(data, model):
    """Compute one embedding per encoded audio clip.

    Args:
        data: iterable of mp3-encoded clips; only the first channel of each
            decoded clip is used.
        model: Keras model that takes a batch of normalized mel spectrograms
            and returns one embedding per example.

    Returns:
        A list with one embedding per clip, in the order of ``data``.
    """
    preds = []
    for x in tqdm(data):
        audio_data = tfio.audio.decode_mp3(x)[:, 0]
        audio_data = normalized_mel_spectrogram(audio_data)
        # Each clip is a batch of one. predict_on_batch runs the model on
        # that single batch directly, whereas model.predict sets up a whole
        # prediction loop (and prints its own progress bar) on every call,
        # which is needlessly slow inside a Python loop.
        cur_pred = model.predict_on_batch(
            tf.expand_dims(audio_data, axis=0)
        )[0]
        preds.append(cur_pred)

    return preds

In [None]:
# keep only the users that have more than one clip, so that same-user pairs
# can be formed
audio_content_with_repeats = [x for x in val_audio_content if len(val_audio_content[x]) > 1]
# number of clips of each kept user
print([len(val_audio_content[x]) for x in audio_content_with_repeats])

In [None]:
# number of clips of the first kept user
len(val_audio_content[audio_content_with_repeats[0]])

In [None]:
all_keys = audio_content_with_repeats
# clips of two different users, picked by hard-coded position (4 and 18);
# this requires at least 19 users with more than one clip
samples1 = val_audio_content[all_keys[4]]
samples2 = val_audio_content[all_keys[18]]
# embeddings of every clip of each of the two users
preds1 = get_embedding(samples1, m)
preds2 = get_embedding(samples2, m)

In [None]:
import numpy as np
def get_dists(list1, list2):
    """Return the Euclidean distance between every pair of the two lists.

    Each element of ``list1`` is compared against each element of ``list2``.
    The result is a flat list of ``len(list1) * len(list2)`` distances,
    ordered by the element of ``list1`` first. Progress over ``list1`` is
    reported with tqdm.
    """
    return [
        np.linalg.norm(first - second)
        for first in tqdm(list1)
        for second in list2
    ]

# Same-user distances, over distinct pairs of clips only. Comparing a list
# against itself would also include the distance of each clip to itself
# (always 0) and count every pair twice; the zeros bias the mean same-user
# distance downward and make the embeddings look better than they are.
local_dists1 = [
    np.linalg.norm(a - b)
    for i, a in enumerate(preds1)
    for b in preds1[i + 1:]
]
local_dists2 = [
    np.linalg.norm(a - b)
    for i, a in enumerate(preds2)
    for b in preds2[i + 1:]
]
# Distances between clips of the two different users
cross_dists = get_dists(preds1, preds2)

# mean same-user distance (user 1, user 2) vs. mean cross-user distance
np.mean(local_dists1), np.mean(local_dists2), np.mean(cross_dists)