In [None]:
import torch
import numpy as np
import joblib
import matplotlib.pyplot as plt
from transformers import Wav2Vec2FeatureExtractor, AutoFeatureExtractor
from tqdm import tqdm
from torch.cuda import empty_cache
import torch.nn.functional as F

import sys
sys.path.append('..')

from src.utils import read_audio, find_files
from src.encoder import Wav2VecBertEncoder, HubertEncoder, wav2vec_processor
from src.configs import Wav2VecBertConfig, HubertEncoderConfig

In [None]:
# Load the k-means quantizer whose centroids are compared against encoder outputs below.
# NOTE(review): joblib.load unpickles the file, which can execute arbitrary code;
# only load checkpoints that come from a trusted source.
kmeans_path = '../data/kmeans/kmeans__L-1_C1024_ckpt150.pkl'
wave2vec2_kmeans = joblib.load(kmeans_path)
# The feature extractor is loaded from the model id declared on Wav2VecBertConfig;
# the encoder is built from a fresh config instance and placed on the GPU.
wav2vecbert_processor = AutoFeatureExtractor.from_pretrained(Wav2VecBertConfig.model_id)
wav2vecbert_encoder = Wav2VecBertEncoder(
    config=Wav2VecBertConfig(), device='cuda'
)

# Disabled HuBERT alternative: uncomment to run the same analysis with the HuBERT
# encoder and its quantizer instead of Wav2Vec-BERT.
# kmeans_path = HubertEncoderConfig.quantizer_path
# hubert_kmeans = joblib.load(kmeans_path)
# hubert_encoder = HubertEncoder(quantize=False, device='cuda')
# processor = Wav2Vec2FeatureExtractor.from_pretrained(HubertEncoderConfig.model_id)

In [None]:
def get_dist(embeddings, centroids):
    """Return the Euclidean distance from each embedding to its nearest centroid.

    Args:
        embeddings: floating-point tensor of shape (B, T, D) or (T, D).
        centroids: floating-point tensor of shape (K, D).

    Returns:
        Tensor of shape (B, T) (or (T,) for 2-D input): the minimum over the K
        centroids of the pairwise distances. The K axis is reduced away.
    """
    # Align the centroids with the embeddings before calling cdist; centroids
    # loaded from a fitted k-means model may be float64 / on the CPU while the
    # encoder outputs are float32. This is a no-op when they already match.
    centroids = centroids.to(device=embeddings.device, dtype=embeddings.dtype)

    distances = torch.cdist(embeddings, centroids)  # (..., T, K)
    return torch.min(distances, dim=-1).values

## Rough

1. Check the distance from a random (noise) input's embeddings to their nearest centroids.
2. Check the distance from a real audio clip's embeddings to their nearest centroids.

In [None]:
# 10 seconds of uniform noise in [0, 1) at 16 kHz, used as the "random audio" baseline.
# Drawn from a seeded generator so the baseline is identical after Restart & Run All.
rand_emb = np.random.default_rng(0).random((1, 160_000))
print(rand_emb.shape)

# Reference utterance from LibriSpeech test-clean, relative to the repository layout.
audiopath = '../data/test-clean/LibriSpeech/test-clean/1089/134686/1089-134686-0000.flac'
audio = read_audio(audiopath, 16_000)
# Keep the first row and at most the first 160_000 samples (10 s at 16 kHz).
# assumes read_audio returns a 2-D (channels, samples) array — TODO confirm
audio = audio[0, :160_000]

print(audio.shape)

In [None]:
# Encode the random input and the real utterance with the Wav2Vec-BERT encoder.
# Inference only, so gradients are disabled for the whole cell.
with torch.no_grad():
    # Disabled HuBERT variant of the same two forward passes, kept for reference.
    # proc = processor(rand_emb, return_tensors="pt", return_attention_mask=True, sampling_rate=16_000)
    # proc_ip, proc_am = proc.input_values.to('cuda'), proc.attention_mask.to('cuda')
    # print(proc_ip.shape, proc_am.shape)
    # rand_out = hubert_encoder(proc_ip, proc_am)
    # rand_out = rand_out.to('cpu')
    
    # proc = processor(audio, return_tensors="pt", return_attention_mask=True, sampling_rate=16_000)
    # proc_ip, proc_am = proc.input_values.to('cuda'), proc.attention_mask.to('cuda')
    # print(proc_ip.shape, proc_am.shape)
    # audio_out = hubert_encoder(proc_ip, proc_am)
    # audio_out = audio_out.to('cpu')

    # Random baseline: preprocess, encode on the GPU, then move the result to the CPU.
    # NOTE(review): the encoder output is used as-is here, while later cells index it
    # with [layer] — confirm which form the encoder currently returns.
    ii, am = wav2vec_processor(rand_emb, wav2vecbert_processor)
    rand_out = wav2vecbert_encoder(ii.to('cuda'), am.to('cuda'))
    rand_out = rand_out.to('cpu')
    print(rand_out.shape)

    # Real utterance: same pipeline as above.
    ii, am = wav2vec_processor(audio, wav2vecbert_processor)
    audio_out = wav2vecbert_encoder(ii.to('cuda'), am.to('cuda'))
    audio_out = audio_out.to('cpu')
    print(audio_out.shape)

In [None]:
# Nearest-centroid distances for the random input (d1) and the real utterance (d2).
# For the HuBERT variant, swap in hubert_kmeans.cluster_centers_ as the centroids.
d1, d2 = (
    get_dist(encoded, torch.from_numpy(wave2vec2_kmeans.cluster_centers_))
    for encoded in (rand_out, audio_out)
)

# Mean distance per input; displayed as the cell output.
d1.mean(), d2.mean()

## Funcs

In [None]:
def hubert_dist(
    encoder,
    processor,
    centroids,
    audio_path
):
    """Nearest-centroid distances for one audio file and for a random-noise input.

    Args:
        encoder: callable taking (input_values, attention_mask) on the GPU and
            returning a tensor of embeddings.
        processor: feature extractor called with the raw waveform; its result must
            expose `input_values` and `attention_mask`.
        centroids: tensor of cluster centres passed straight to `get_dist`.
        audio_path: path of the audio file, read at 16 kHz.

    Returns:
        Tuple `(d1, d2)`: distances for the real audio and for the random input,
        both as CPU tensors without gradient history.
    """

    def _nearest_centroid_dist(waveform):
        # Preprocess, encode on the GPU, bring the embeddings back to the CPU.
        proc = processor(waveform, return_tensors="pt", return_attention_mask=True, sampling_rate=16_000)
        proc_ip, proc_am = proc.input_values.to('cuda'), proc.attention_mask.to('cuda')
        encoded = encoder(proc_ip, proc_am).detach().to('cpu')
        return get_dist(encoded, centroids)

    # Inference only: without no_grad each forward pass keeps its autograd graph
    # alive through the returned distances.
    with torch.no_grad():
        # Audio: first row, at most the first 160_000 samples (10 s at 16 kHz).
        # assumes read_audio returns a 2-D (channels, samples) array — TODO confirm
        audio = read_audio(audio_path, 16_000)
        audio = audio[0, :160_000]
        d1 = _nearest_centroid_dist(audio)

        # Random input: 10 seconds of uniform noise in [0, 1).
        rand_emb = np.random.rand(1, 160_000)
        d2 = _nearest_centroid_dist(rand_emb)

    return d1, d2


def wav2vec2_dist(
    encoder,
    centroids,
    audio_path,
    layer_index=None,
    processor=None,
):
    """Nearest-centroid distances for the embeddings of one audio file.

    Args:
        encoder: callable taking (input_features, attention_mask) on the GPU; its
            result is indexed with `layer_index`.
        centroids: tensor of cluster centres passed straight to `get_dist`.
        audio_path: path of the audio file, read at 16 kHz.
        layer_index: index applied to the encoder output. Defaults to the
            notebook-level `layer` variable, which must then be defined before
            the call.
        processor: feature extractor handed to `wav2vec_processor`. Defaults to
            the notebook-level `wav2vecbert_processor`.

    Returns:
        Distances from each embedding to its nearest centroid, as a CPU tensor
        without gradient history. No random baseline is computed here.
    """
    # Fall back to the notebook globals so existing three-argument calls keep working.
    if layer_index is None:
        layer_index = layer
    if processor is None:
        processor = wav2vecbert_processor

    # Audio: first row, at most the first 160_000 samples (10 s at 16 kHz).
    # assumes read_audio returns a 2-D (channels, samples) array — TODO confirm
    audio = read_audio(audio_path, 16_000)
    audio = audio[0, :160_000]
    ii, am = wav2vec_processor(audio, processor)

    # Inference only: no_grad stops the forward pass from building an autograd graph.
    with torch.no_grad():
        audio_out = encoder(ii.to('cuda'), am.to('cuda'))[layer_index]
    audio_out = audio_out.detach().to('cpu')
    d1 = get_dist(audio_out, centroids)

    return d1

In [None]:
# Collect the FLAC files under the LibriSpeech test-clean directory.
# NOTE(review): ('.flac') is a plain string, not a 1-tuple (that would be ('.flac',));
# confirm which form find_files expects.
audio_files = find_files('../data/test-clean/LibriSpeech/test-clean/', ('.flac'))
print(len(audio_files))

In [None]:
# Index applied to the encoder output. Read as a global by wav2vec2_dist and by the
# embedding-collection loop, and embedded in the saved figure's filename.
layer = 15

In [None]:
embeddings = []

# Inference only: no_grad avoids building an autograd graph for every file.
with torch.no_grad():
    # tqdm infers the total from the slice, so the bar is right even with < 100 files.
    for a in tqdm(audio_files[:100]):
        audio = read_audio(a, 16_000)
        audio = audio[0, :160_000]
        ii, am = wav2vec_processor(audio, wav2vecbert_processor)

        # Left-pad with zeros to 500 along the second-to-last axis of the features and
        # the last axis of the mask, so every item stacks to the same shape later.
        # assumes ii is (batch, frames, features) and am is (batch, frames) — TODO confirm
        ii = F.pad(ii, (0, 0, 500-ii.shape[1], 0, 0, 0), value=0)
        am = F.pad(am, (500-am.shape[1], 0), value=0)

        out = wav2vecbert_encoder(ii.to('cuda'), am.to('cuda'))[layer]
        embeddings.append(out.detach().cpu().numpy())

In [None]:
# Number of files encoded by the collection loop.
len(embeddings)

In [None]:
# Flatten every collected embedding into one (num_vectors, 1024) matrix. The row count
# is inferred with -1 so it follows however many files the loop above processed.
temp = torch.from_numpy(np.array(embeddings)).reshape(-1, 1024)

In [None]:
# L2 norm of every embedding vector (one value per row of `temp`).
norms = torch.linalg.vector_norm(temp, ord=2, dim=-1)

In [None]:
# Distribution of the embedding norms; these norms are reused below to scale the
# random vectors.
plt.hist(norms)
plt.show()

In [None]:
# Build random vectors whose norms match real embeddings: draw a Gaussian direction,
# normalise it to unit length, then scale it to the norm of one real embedding.
random_vectors = []
for norm in tqdm(norms[:10000]):
    random_vec = torch.randn(1, 1024)
    # torch.norm is deprecated; vector_norm (already used for `norms`) reduces over
    # all elements by default, giving the same length for this single row.
    random_vec = random_vec / torch.linalg.vector_norm(random_vec)
    random_vec = random_vec * norm
    random_vectors.append(random_vec)

In [None]:
# Stack the (1, 1024) vectors into one (N, 1, 1024) tensor directly in torch instead
# of round-tripping a Python list of tensors through NumPy.
random_embeddings = torch.stack(random_vectors)

In [None]:
# Sanity check: the norms of the random vectors should reproduce the histogram of the
# real embedding norms they were scaled to.
plt.hist(torch.linalg.vector_norm(random_embeddings, dim=(-1)))
plt.show()

In [None]:
audio_pool = []

# The centroids do not change between files, so convert them once outside the loop.
wav2vec2_centroids = torch.from_numpy(wave2vec2_kmeans.cluster_centers_)

# Inference only: no_grad keeps the encoder from building an autograd graph per file.
with torch.no_grad():
    for f in tqdm(audio_files[:500]):
        audio_dist = wav2vec2_dist(wav2vecbert_encoder, wav2vec2_centroids, f)

        # Add to a pool of distances (extend iterates over the first axis of the result).
        audio_pool.extend(audio_dist)

        # Release cached GPU memory between files.
        empty_cache()

In [None]:
# Nearest-centroid distance for every norm-matched random vector.
rand_pool = [
    get_dist(rand_vec, torch.from_numpy(wave2vec2_kmeans.cluster_centers_))
    for rand_vec in tqdm(random_embeddings)
]

In [None]:
# Flatten both pools into 1-D arrays of distances: `a` for real audio, `r` for the
# random vectors.
a = np.concatenate(audio_pool)
r = np.concatenate(rand_pool)

In [None]:
# Total number of audio-token distances collected.
a.shape

In [None]:
# Overlay the two distance distributions on one set of axes.
fig, ax = plt.subplots()
# The audio pool is subsampled (with replacement) to 10000 values before plotting.
ax.hist(np.random.choice(a, 10000), alpha=1, label='Audio')
ax.hist(r, alpha=0.75, label='Random')
ax.legend(loc='upper left')
ax.set_xlabel('Distance of a token from centroid')
# The layer index is part of the filename, so runs for different layers do not clash.
fig.savefig(f'wav2vec2_clusterdiff_xs_{layer}.png')
plt.show()

In [None]:
# Value distribution of one slice (index [0, 2]) of the first file's encoder output.
plt.hist(embeddings[0][0, 2])