# Speech Identity Inference

Let's check if the pretrained model can really identify speakers.

In [None]:
import os
import numpy as np
import pandas as pd
from sklearn import metrics

from tqdm.notebook import tqdm
from IPython.display import Audio
from matplotlib import pyplot as plt
%matplotlib inline

import tensorflow as tf
import tensorflow_io as tfio
import tensorflow_addons as tfa

from train_speech_id_model import BaseSpeechEmbeddingModel
from create_audio_tfrecords import AudioTarReader, PersonIdAudio

# Sample rate (Hz) passed to the in-notebook audio player when spot-checking a clip.
# NOTE(review): assumed to match the rate of the decoded MP3 clips — confirm against the corpus.
sr = 48000

In [None]:
# Build the embedding model architecture; trained weights are loaded from a checkpoint afterwards.
m = BaseSpeechEmbeddingModel()
m.summary()

In [None]:
# ROC AUC previously recorded for earlier checkpoints:
# 90.ckpt: auc = 0.9525
# 110.ckpt: auc = 0.9533
chkpt = 'temp/cp-0110.ckpt'
m.load_weights(chkpt)
# Compile with Adam + semi-hard triplet loss. predict() below does not need this;
# presumably it is kept so the model can be saved (see the commented-out line) — confirm.
m.compile(
    optimizer=tf.keras.optimizers.Adam(0.0006),
    loss=tfa.losses.TripletSemiHardLoss()
)
# m.save('speech-id-model-110')

In [None]:
# Point this at another corpus file (e.g. a different language, or the `_test`
# split commented out below) to evaluate how well the model transfers.
dev_tfrecords_file = 'data/cv-corpus-7.0-2021-07-21-en.tar.gz_dev.tfrecords.gzip'
# dev_tfrecords_file = 'data/cv-corpus-7.0-2021-07-21-en.tar.gz_test.tfrecords.gzip'
raw_records = tf.data.TFRecordDataset(
    dev_tfrecords_file,
    compression_type='GZIP',
    num_parallel_reads=4,
)
# Each element becomes an (encoded audio, speaker label) pair.
dev_dataset = tfrecords_audio_dataset = raw_records.map(PersonIdAudio.deserialize_from_tfrecords)

In [None]:
# Take the first 2500 records, decode each MP3 payload and keep only channel 0.
samples = [
    (tfio.audio.decode_mp3(mp3_bytes)[:, 0], speaker)
    for mp3_bytes, speaker in dev_dataset.take(2500)
]

In [None]:
# Sanity check: play one decoded clip (index 10) to confirm the MP3 decoding sounds right.
Audio(samples[10][0], rate=sr)

In [None]:
# One embedding per clip. Each clip is fed as its own batch of size 1
# (presumably because clips differ in length and cannot be stacked).
embeddings = [
    m.predict(tf.expand_dims(clip, axis=0))[0]
    for clip, _speaker in tqdm(samples)
]

## Check embedding quality

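Ideally, embeddings of clips from the same speaker should be close together (small Euclidean distance), while embeddings of clips from different speakers should be far apart.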

In [None]:
# Count the distinct speaker labels among the loaded clips.
unique_speakers = {speaker.numpy() for _clip, speaker in samples}
n_speakers = len(unique_speakers)
print(f'Loaded {n_speakers} different speakers')

In [None]:
def split_pairwise_distances(embedding_list, speaker_ids):
    """Return Euclidean distances of all unordered clip pairs, split by speaker match.

    Pairs (p, q) with p < q are produced in the same order as a nested
    ``for p ...: for q in range(p + 1, n)`` loop. Distances are computed one
    row at a time, so memory stays O(n * dim) instead of materializing all
    n * (n - 1) / 2 difference vectors at once.

    Args:
        embedding_list: n embedding vectors of equal length.
        speaker_ids: n speaker identifiers, comparable with ``==``.

    Returns:
        ``{'same': [...], 'different': [...]}``: plain lists holding the
        distances of same-speaker pairs and different-speaker pairs.
        They stay lists because they are later concatenated with ``+``.
    """
    emb = np.asarray(embedding_list)
    ids = np.asarray(speaker_ids)
    result = {'same': [], 'different': []}
    for p in range(len(emb) - 1):
        # Distances from clip p to every later clip in one vectorized step.
        dists = np.linalg.norm(emb[p + 1:] - emb[p], axis=1)
        is_same = ids[p + 1:] == ids[p]
        result['same'].extend(dists[is_same])
        result['different'].extend(dists[~is_same])
    return result


# Convert speaker labels to numpy once, instead of comparing tf tensors
# inside an O(n^2) pure-Python loop.
pairwise_diff = split_pairwise_distances(
    embeddings, [speaker.numpy() for _clip, speaker in samples]
)

In [None]:
# Distance distributions of same-speaker vs. different-speaker pairs.
group_names = list(pairwise_diff)
plt.figure(figsize=(12, 8))
plt.boxplot([pairwise_diff[name] for name in group_names])
plt.xticks(range(1, len(group_names) + 1), group_names)
plt.ylabel('Embedding distance')
plt.title('Boxplot of speaker identifiability')

In [None]:
# Operating-point priorities ("same speaker" is the positive prediction here):
# 1. primary: for two clips from different speakers, do not predict `same` (specificity);
# 2. secondary: for two clips from the same speaker, do predict `same` (sensitivity).

# Distance threshold t: pairs closer than t are predicted `same`.
# alpha interpolates between the median same-speaker distance (alpha=0) and the
# median different-speaker distance (alpha=1); a smaller alpha lowers t and so favours specificity.
alpha = 0.2

# t is calibrated from this data, which is only meaningful on the validation (dev) split
t = np.median(pairwise_diff['same']) + alpha * (np.median(pairwise_diff['different']) - np.median(pairwise_diff['same']))

# specificity: fraction of different-speaker pairs with distance above t
specificity = np.sum(np.array(pairwise_diff['different']) > t) / len(pairwise_diff['different'])
# sensitivity: fraction of same-speaker pairs with distance below t
sensitivity = np.sum(np.array(pairwise_diff['same']) < t) / len(pairwise_diff['same'])

print('Sensitivity, specificity = ', sensitivity, specificity)

# ROC labels: 0 = same speaker, 1 = different speaker. Distance is used as the score,
# so a larger distance means "more likely a different speaker".
same_lbl = [0] * len(pairwise_diff['same'])
diff_lbl = [1] * len(pairwise_diff['different'])
# list concatenation (same-speaker distances first, then different-speaker ones)
scores = pairwise_diff['same'] + pairwise_diff['different']

# Halve the distances (and t with them) so the threshold axis lies in [0, 1].
# NOTE(review): this only holds if pairwise distances never exceed 2
# (e.g. L2-normalized embeddings) — confirm for BaseSpeechEmbeddingModel.
scores = np.array(scores) * 0.5
t = t * 0.5

labels = same_lbl + diff_lbl
# notebook display: both lengths should match
len(scores), len(labels)

In [None]:
# ROC with "different speaker" (label 1) as the positive class and the halved distance as the score.
fpr, tpr, thresholds = metrics.roc_curve(labels, scores, pos_label=1)

In [None]:
# ROC curve annotated with its AUC and the evaluated checkpoint; the diagonal marks chance level.
roc_auc = metrics.roc_auc_score(labels, scores)

plt.figure(figsize=(12, 8))
plt.xlabel('False Positive Rate')
plt.ylabel('True Positive Rate')
plt.title(f'ROC curve: AUC = {np.round(roc_auc, 4)} {chkpt}')

plt.plot(fpr, tpr)
plt.plot([0, 1], [0, 1])

In [None]:
# Sensitivity and specificity as a function of the (halved) distance threshold.
# The threshold cell defines "same speaker" as the positive prediction (distance below threshold).
# roc_curve was built with "different speaker" (label 1) as the positive class, therefore:
#   tpr     = P(distance >= thr | different speakers) -> specificity
#   1 - fpr = P(distance <  thr | same speaker)       -> sensitivity
plt.figure(figsize=(12, 8))
plt.title('Point of operation')
plt.plot(thresholds, tpr, label='Specificity')
plt.plot(thresholds, 1 - fpr, label='Sensitivity')
plt.plot([t, t], [0, 1], label='Threshold')
plt.xlabel('Threshold level')
plt.xlim([0, 1])
plt.legend()

## Select the best checkpoint on the validation set

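Strategy: compute the triplet loss on the validation set in its stored order (without sorting it), so that batches contain several clips of the same voice. The triplet loss can only form anchor–positive pairs from samples that share a label within the same batch. Keeping the order fixed also makes the evaluation consistent across checkpoints. The batch size should be as large as possible.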

In [None]:
# Standalone loss object, called directly on (labels, embeddings) batches to score checkpoints.
triplet_loss = tfa.losses.TripletSemiHardLoss()

In [None]:
# Decoding step mapped over the dataset before predictions are computed.
def mp3_decode_fn(audio_bytes, audio_class):
    """Decode one MP3 clip, keep only its first channel, and pass the label through.

    Args:
        audio_bytes: encoded MP3 payload of one dataset record.
        audio_class: the record's label, returned unchanged.

    Returns:
        ``(samples of channel 0, audio_class)``.
    """
    decoded = tfio.audio.decode_mp3(audio_bytes)
    first_channel = decoded[:, 0]
    return first_channel, audio_class

In [None]:
# Embed the first 1300 dev clips (one predict() call per clip) and collect their labels.
all_preds = []
all_labels = []
for audio, label in tqdm(dev_dataset.take(1300).map(mp3_decode_fn)):
    all_preds.append(m.predict(tf.expand_dims(audio, axis=0))[0])
    all_labels.append(label.numpy())

In [None]:
# number of clips embedded (notebook display)
len(all_preds)

In [None]:
# Group predictions and labels into full batches of `batch_size`; a trailing partial batch is dropped.
batch_size = 128
n_batches = len(all_preds) // batch_size
vec_size = len(all_preds[0])
n_used = n_batches * batch_size

np_preds = np.asarray(all_preds[:n_used]).reshape(n_batches, batch_size, vec_size)
np_labls = np.asarray(all_labels[:n_used]).reshape(n_batches, batch_size)

In [None]:
# Mean triplet loss over the full batches built above. Each triplet_loss call
# returns one value for a whole batch, so the sum is divided by the number of
# batches (not by the batch size). With no full batch the result is NaN.
total_loss = 0
for lbl, pred in zip(np_labls, np_preds):
    total_loss += triplet_loss(lbl, pred).numpy()
n_evaluated = len(np_labls)
total_loss = total_loss / n_evaluated if n_evaluated else float('nan')
print(f'Total loss: {total_loss}')

In [None]:
# Score every checkpoint saved in `temp/` by its mean validation triplet loss.
# Checkpoints are visited in sorted name order so the output is reproducible
# (os.listdir order is arbitrary).
# NOTE: `m` is left holding the weights of the last checkpoint evaluated.
all_checkpoints = sorted(
    x.split('.')[0] + '.ckpt' for x in os.listdir('temp') if 'ckpt.index' in x
)
all_results = []
for checkpoint in tqdm(all_checkpoints):
    m.load_weights(os.path.join('temp', checkpoint))

    # Embed the first n_items dev clips, one predict() call per clip.
    all_preds = []
    all_labels = []
    n_items = 4600
    # To use the whole dev set instead, iterate dev_dataset.map(mp3_decode_fn) without `total`.
    for audio, label in tqdm(dev_dataset.take(n_items).map(mp3_decode_fn),
                             total=n_items, leave=False):
        all_preds.append(m.predict(tf.expand_dims(audio, axis=0))[0])
        all_labels.append(label.numpy())

    # Keep only full batches; the trailing partial batch is dropped.
    batch_size = 128
    n_batches = len(all_preds) // batch_size
    vec_size = len(all_preds[0])

    np_preds = np.reshape(all_preds[0:batch_size * n_batches], (n_batches, batch_size, vec_size))
    np_labls = np.reshape(all_labels[0:batch_size * n_batches], (n_batches, batch_size))

    # Mean of the per-batch losses: divide by the number of batches, not the batch size.
    total_loss = 0
    for lbl, pred in zip(np_labls, np_preds):
        total_loss += triplet_loss(lbl, pred).numpy()
    total_loss = total_loss / n_batches if n_batches else float('nan')
    cur_result = {
        'checkpoint': checkpoint,
        'val_loss': total_loss
    }
    print(cur_result)
    all_results.append(cur_result)

In [None]:
# One row per checkpoint, indexed by the number in names like 'cp-0110.ckpt'
# (presumably the epoch). Sorted by that number so that line plots connect the
# points in checkpoint order; the results list may follow arbitrary os.listdir order.
df_val = pd.DataFrame(all_results)
df_val['idx'] = df_val.checkpoint.apply(lambda z: int(z.split('.')[0].split('-')[1]))
df_val = df_val.set_index('idx').sort_index()

In [None]:
# Persist the per-checkpoint validation losses (index = checkpoint number).
df_val.to_csv('val_triplet_loss.csv')

In [None]:
# df_val

In [None]:
# Plot validation triplet loss against checkpoint number.
df_val.plot()