In [None]:
import sys
sys.path.append('..')

import torch
import numpy as np
import joblib
import matplotlib.pyplot as plt
from transformers import Wav2Vec2FeatureExtractor

from src.utils import read_audio
from src.encoder import Wav2VecBertEncoder, HubertEncoder
from src.configs import Wav2VecBertConfig, HubertEncoderConfig

In [None]:
# Load the centroids used for the Wav2Vec-BERT features.
# The loaded object is later handed straight to torch.from_numpy, so it is
# presumably a raw NumPy array of centroids rather than a fitted estimator
# -- TODO confirm.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary
# code; only load artifacts that come from a trusted source.
kmeans_path = '../data/kmeans_faiss_centroids.pkl'
# kmeans_path = Wav2VecBertConfig.quantizer_path
wave2vec2_kmeans = joblib.load(kmeans_path)

# Load the HuBERT quantizer from the path declared in the project config.
# `kmeans_path` is rebound here, so after this cell it refers to the HuBERT file.
kmeans_path = HubertEncoderConfig.quantizer_path
hubert_kmeans = joblib.load(kmeans_path)

# Wav2Vec-BERT encoder built from its default project configuration.
wav2vecbert_encoder = Wav2VecBertEncoder(
    config=Wav2VecBertConfig(),
)

# HuBERT encoder with quantization switched off
# (presumably so it returns continuous embeddings -- TODO confirm).
hubert_encoder = HubertEncoder(quantize=False)

# Feature extractor matching the HuBERT model id; in the cells visible here it
# is only referenced from commented-out HuBERT code.
processor = Wav2Vec2FeatureExtractor.from_pretrained(HubertEncoderConfig.model_id)

In [None]:
def get_dist(embeddings, centroids):
    """Return every frame's Euclidean distance to its nearest centroid.

    Args:
        embeddings: Tensor of shape (B, T, D).
        centroids: Tensor of shape (K, D).

    Returns:
        Tensor of shape (B, T). For each frame, the smallest of its K
        centroid distances (the K axis is reduced away by the min).
    """
    # torch.cdist needs both operands to share dtype and device. Centroids
    # created with torch.from_numpy keep NumPy's dtype and live on the CPU,
    # so align them with the embeddings before measuring.
    centroids = centroids.to(device=embeddings.device, dtype=embeddings.dtype)
    distances = torch.cdist(embeddings, centroids)  # (B, T, K)
    return torch.min(distances, dim=-1).values

1. Check the distance between each embedding of a random (noise) input and its nearest centroid.
2. Check the distance between each embedding of a real audio recording and its nearest centroid.

In [None]:
# Seed a dedicated generator so the random baseline, and every number derived
# from it, is identical on each Restart & Run All.
rng = np.random.default_rng(0)
# 160,000 uniform samples in [0, 1): 10 seconds of "audio" at 16 kHz.
rand_emb = rng.random((1, 160_000))
print(rand_emb.shape)

audiopath = '../data/test-clean/LibriSpeech/test-clean/1089/134686/1089-134686-0000.flac'
# audiopath = '/home/romit/.cache/huggingface/datasets/downloads/extracted/81c46ac239ac4614e07a0960bb4b7f62966b99a2c540db203593c975c49d4248/xs_chunks_0000/YOU0000000761_S0000321.wav'
# Read the recording; 16_000 is passed as the second argument
# (presumably the target sample rate -- TODO confirm against read_audio).
audio = read_audio(audiopath, 16_000)
# Keep row 0 and at most the first 160,000 samples.
# NOTE(review): this leaves `audio` one-dimensional while `rand_emb` is
# (1, 160000) -- confirm the encoder accepts both layouts.
audio = audio[0, :160_000]

print(audio.shape)

In [None]:
# Alternative HuBERT path, currently disabled: run each input through the
# feature extractor, then through the HuBERT encoder with its attention mask.
# proc = processor(rand_emb, return_tensors="pt", return_attention_mask=True, sampling_rate=16_000)
# proc_ip, proc_am = proc.input_values, proc.attention_mask
# rand_out = hubert_encoder(proc_ip, proc_am)

# proc = processor(audio, return_tensors="pt", return_attention_mask=True, sampling_rate=16_000)
# proc_ip, proc_am = proc.input_values, proc.attention_mask
# audio_out = hubert_encoder(proc_ip, proc_am)

# Active path: encode the random baseline with the Wav2Vec-BERT encoder.
# NOTE(review): the second positional argument is an empty list; its meaning is
# defined by Wav2VecBertEncoder, which is not visible here -- TODO confirm.
rand_out = wav2vecbert_encoder(rand_emb, [])
print(rand_out.shape)

# Encode the real recording the same way so both outputs can be compared.
audio_out = wav2vecbert_encoder(audio, [])
print(audio_out.shape)

In [None]:
# Plot the vector at [0][0] of the encoder output for the real recording
# (presumably the first frame's embedding, assuming a (B, T, D) layout -- TODO confirm).
# detach()/cpu() make the tensor safe to hand to matplotlib even when it
# tracks gradients or lives on an accelerator.
plt.plot(audio_out[0][0].detach().cpu().numpy())
plt.title('Real audio: encoder output [0][0]')
plt.xlabel('Index')
plt.ylabel('Value')
plt.show()

In [None]:
# Plot the vector at [0][0] of the encoder output for the random input
# (presumably the first frame's embedding, assuming a (B, T, D) layout -- TODO confirm).
# detach()/cpu() make the tensor safe to hand to matplotlib even when it
# tracks gradients or lives on an accelerator.
plt.plot(rand_out[0][0].detach().cpu().numpy())
plt.title('Random input: encoder output [0][0]')
plt.xlabel('Index')
plt.ylabel('Value')
plt.show()

In [None]:
# HuBERT variant, currently disabled (its quantizer exposes .cluster_centers_):
# d1 = get_dist(rand_out, torch.from_numpy(hubert_kmeans.cluster_centers_))
# d2 = get_dist(audio_out, torch.from_numpy(hubert_kmeans.cluster_centers_))

# The Wav2Vec-BERT centroids object is converted as-is (no .cluster_centers_),
# once, and reused for both inputs.
wav2vec_centroids = torch.from_numpy(wave2vec2_kmeans)
d1 = get_dist(rand_out, wav2vec_centroids)   # random baseline
d2 = get_dist(audio_out, wav2vec_centroids)  # real recording

In [None]:
# Show each distance tensor followed by its shape:
# the random baseline first, then the real recording.
for nearest in (d1, d2):
    print(nearest, nearest.shape)

In [None]:
# Mean nearest-centroid distance over the first batch item, displayed as a
# pair: (random baseline, real recording).
d1[0].mean(), d2[0].mean()

In [None]:
# Distribution of nearest-centroid distances for the random baseline
# (first batch item). cpu() keeps the conversion to NumPy valid when the
# tensor lives on an accelerator.
plt.hist(d1[0].detach().cpu().numpy())
plt.title('Random input: distance to nearest centroid')
plt.xlabel('Distance to nearest centroid')
plt.ylabel('Count')

plt.show()

In [None]:
# Distribution of nearest-centroid distances for the real recording
# (first batch item). cpu() keeps the conversion to NumPy valid when the
# tensor lives on an accelerator.
plt.hist(d2[0].detach().cpu().numpy())
plt.title('Real audio: distance to nearest centroid')
plt.xlabel('Distance to nearest centroid')
plt.ylabel('Count')
plt.show()