In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import joblib
import matplotlib.pyplot as plt
from transformers import Wav2Vec2FeatureExtractor

from src.utils import read_audio, find_files
from src.encoder import Wav2VecBertEncoder, HubertEncoder
from src.configs import Wav2VecBertConfig, HubertEncoderConfig

In [None]:
# Centroids used for the wav2vec2 / Wav2Vec-BERT experiments, loaded from a local pickle.
# Later cells pass this object straight to torch.from_numpy, so it is presumably
# a bare NumPy array of centroids rather than a fitted estimator — TODO confirm.
# NOTE(review): joblib.load unpickles arbitrary objects; only load files you trust.
kmeans_path = '../data/kmeans/kmeans_faiss_centroids.pkl'
# Alternative source for the same quantizer:
# kmeans_path = Wav2VecBertConfig.quantizer_path
wave2vec2_kmeans = joblib.load(kmeans_path)

# HuBERT quantizer; later cells read its `.cluster_centers_` attribute.
# `kmeans_path` is rebound here, so after this cell it points at the HuBERT file.
kmeans_path = HubertEncoderConfig.quantizer_path
hubert_kmeans = joblib.load(kmeans_path)

# Both encoders are constructed with device='cuda'.
wav2vecbert_encoder = Wav2VecBertEncoder(
    config=Wav2VecBertConfig(), device='cuda'
)

# quantize=False presumably makes the encoder return continuous features
# instead of cluster ids — TODO confirm against HubertEncoder.
hubert_encoder = HubertEncoder(quantize=False, device='cuda')
# Feature extractor built from the model id declared in HubertEncoderConfig.
processor = Wav2Vec2FeatureExtractor.from_pretrained(HubertEncoderConfig.model_id)

In [None]:
def get_dist(embeddings, centroids):
    """Return the distance from every embedding to its nearest centroid.

    Args:
        embeddings: floating-point tensor of shape (B, T, D).
        centroids: floating-point tensor of shape (K, D).

    Returns:
        Tensor of shape (B, T): for each embedding, the smallest of its K
        distances to the centroids (the full (B, T, K) matrix is reduced
        over the last axis).
    """
    # Centroids typically arrive via torch.from_numpy and may not share the
    # embeddings' dtype or device; align them so cdist sees matching inputs.
    centroids = centroids.to(device=embeddings.device, dtype=embeddings.dtype)
    distances = torch.cdist(embeddings, centroids)  # (B, T, K)
    return distances.min(dim=-1).values

1. Check the distance of a random (noise) embedding from its nearest centroid
2. Check the distance of a legitimate (real speech) audio embedding from its nearest centroid

In [None]:
# Random baseline signal: 160_000 samples, i.e. 10 seconds at 16 kHz.
# A fixed seed keeps the baseline identical across kernel restarts.
rand_emb = np.random.default_rng(0).random((1, 160_000))
print(rand_emb.shape)

# Real speech sample from LibriSpeech test-clean.
audiopath = '../data/test-clean/LibriSpeech/test-clean/1089/134686/1089-134686-0000.flac'
audio = read_audio(audiopath, 16_000)
# Take index 0 on the leading axis (presumably the channel — TODO confirm)
# and keep at most the first 160_000 samples.
audio = audio[0, :160_000]

print(audio.shape)

In [None]:
# Run both inputs through the HuBERT encoder without tracking gradients.
with torch.no_grad():
    # --- random signal ---
    proc = processor(
        rand_emb,
        sampling_rate=16_000,
        return_attention_mask=True,
        return_tensors="pt",
    )
    proc_ip = proc.input_values.to('cuda')
    proc_am = proc.attention_mask.to('cuda')
    print(proc_ip.shape, proc_am.shape)
    rand_out = hubert_encoder(proc_ip, proc_am).to('cpu')

    # --- real audio ---
    proc = processor(
        audio,
        sampling_rate=16_000,
        return_attention_mask=True,
        return_tensors="pt",
    )
    proc_ip = proc.input_values.to('cuda')
    proc_am = proc.attention_mask.to('cuda')
    print(proc_ip.shape, proc_am.shape)
    audio_out = hubert_encoder(proc_ip, proc_am).to('cpu')

    # Wav2Vec-BERT variant (disabled):
    # rand_out = wav2vecbert_encoder(rand_emb, []).to('cpu')
    # print(rand_out.shape)
    # audio_out = wav2vecbert_encoder(audio, []).to('cpu')
    # print(audio_out.shape)

In [None]:
# Plot the slice at index [0, 0] of the real-audio encoder output
# (presumably the first frame of the first batch item — TODO confirm layout).
plt.plot(audio_out[0, 0])

In [None]:
# Plot the slice at index [0, 0] of the random-signal encoder output
# (presumably the first frame of the first batch item — TODO confirm layout).
plt.plot(rand_out[0, 0])

In [None]:
# rand_out / audio_out were produced by the HuBERT encoder, so they must be
# measured against the HuBERT k-means centroids, not the wav2vec2 ones.
hubert_centroids = torch.from_numpy(hubert_kmeans.cluster_centers_)
d1 = get_dist(rand_out, hubert_centroids)   # random signal
d2 = get_dist(audio_out, hubert_centroids)  # real audio

# When the embeddings come from wav2vecbert_encoder instead, use:
# d1 = get_dist(rand_out, torch.from_numpy(wave2vec2_kmeans))
# d2 = get_dist(audio_out, torch.from_numpy(wave2vec2_kmeans))

In [None]:
def hubert_dist(
    encoder,
    processor,
    centroids,
    audio_path
):
    """Nearest-centroid distances for one audio file and for a random signal.

    The file at ``audio_path`` is read at 16 kHz and cut to at most 160_000
    samples; a random signal of 160_000 samples is generated alongside it.
    Both go through ``processor`` and ``encoder`` (inputs moved to 'cuda',
    outputs moved back to 'cpu') and are then compared against ``centroids``
    with ``get_dist``.

    Args:
        encoder: callable taking (input_values, attention_mask).
        processor: feature extractor exposing ``input_values`` and
            ``attention_mask`` on its result.
        centroids: tensor of cluster centres passed to ``get_dist``.
        audio_path: path of the audio file to read.

    Returns:
        Tuple ``(d1, d2)``: distances for the real audio, then distances for
        the random signal.
    """

    def _nearest(waveform):
        # Shared pipeline: feature extraction -> encoder on GPU -> distances on CPU.
        proc = processor(waveform, return_tensors="pt", return_attention_mask=True, sampling_rate=16_000)
        proc_ip, proc_am = proc.input_values.to('cuda'), proc.attention_mask.to('cuda')
        return get_dist(encoder(proc_ip, proc_am).to('cpu'), centroids)

    # Inference only: without no_grad the returned tensors may keep autograd
    # graphs alive when the caller pools them across many files, and may
    # refuse conversion to NumPy.
    with torch.no_grad():
        # Audio
        audio = read_audio(audio_path, 16_000)
        audio = audio[0, :160_000]
        d1 = _nearest(audio)

        # Random embedding
        rand_emb = np.random.rand(1, 160_000)  # 10 second random audio
        d2 = _nearest(rand_emb)

    return d1, d2


def wav2vec2_dist(
    encoder,
    centroids,
    audio_path
):
    """Nearest-centroid distances for one audio file and for a random signal.

    The file at ``audio_path`` is read at 16 kHz and cut to at most 160_000
    samples; a random signal of 160_000 samples is generated alongside it.
    Each is passed to ``encoder`` together with an empty list, the output is
    moved to 'cpu', and ``get_dist`` compares it against ``centroids``.

    Args:
        encoder: callable taking (waveform, list); the meaning of the second
            argument is not visible here — TODO confirm.
        centroids: tensor of cluster centres passed to ``get_dist``.
        audio_path: path of the audio file to read.

    Returns:
        Tuple ``(d1, d2)``: distances for the real audio, then distances for
        the random signal.
    """
    # Inference only: without no_grad the returned tensors may keep autograd
    # graphs alive when the caller pools them across many files.
    with torch.no_grad():
        # Audio
        audio = read_audio(audio_path, 16_000)
        audio = audio[0, :160_000]
        audio_out = encoder(audio, []).to('cpu')
        d1 = get_dist(audio_out, centroids)

        # Random embedding
        # NOTE(review): this is a (1, 160_000) NumPy array, while `audio` above
        # has had its leading axis indexed away — confirm the encoder accepts both.
        rand_emb = np.random.rand(1, 160_000)  # 10 second random audio
        rand_out = encoder(rand_emb, []).to('cpu')
        d2 = get_dist(rand_out, centroids)

    return d1, d2

In [None]:
# Gather the LibriSpeech test-clean files matching the '.flac' extension.
# NOTE(review): the extension was previously written as ('.flac'), which is a
# plain string, not a 1-tuple; confirm which form find_files expects.
audio_files = find_files(
    '../data/test-clean/LibriSpeech/test-clean/',
    '.flac',
)
print(len(audio_files))

In [None]:
from tqdm import tqdm

In [None]:
rand_pool = []
audio_pool = []

# Loop-invariant: build the centroid tensor once instead of once per file.
hubert_centroids = torch.from_numpy(hubert_kmeans.cluster_centers_)

for f in tqdm(audio_files):
    audio_dist, rand_dist = hubert_dist(hubert_encoder, processor, hubert_centroids, f)
    # Wav2Vec-BERT variant:
    # audio_dist, rand_dist = wav2vec2_dist(wav2vecbert_encoder, torch.from_numpy(wave2vec2_kmeans), f)

    # extend() iterates over the leading axis, so each pooled item is one
    # row of per-token distances.
    audio_pool.extend(audio_dist)
    rand_pool.extend(rand_dist)

In [None]:
# Flatten each pool of per-file distance rows into one 1-D array:
# `a` holds the real-audio distances, `r` the random-signal distances.
a, r = (np.concatenate(rows) for rows in (audio_pool, rand_pool))

In [None]:
# Overlay the two distance distributions on one axes; the random histogram
# is drawn second and semi-transparent so the audio one stays visible.
fig, ax = plt.subplots()
ax.hist(a, alpha=1, label='Audio')
ax.hist(r, alpha=0.75, label='Random')
ax.legend(loc='upper left')
ax.set_xlabel('Distance of a token from centroid')
# Save before show(). For the wav2vec2 run, save to 'wav2vec2_clusterdiff.png' instead.
fig.savefig('hubert_clusterdiff.png')
plt.show()