In [None]:
import torch
import numpy as np
import matplotlib.pyplot as plt

from vector_quantize_pytorch import VectorQuantize

In [None]:
# GPU on which the quantizer, the encoder and the distance computations in the cells below are placed.
device = 'cuda:1'

In [None]:
# Load the saved quantizer checkpoints (m1: HuBERT run, m2: Wav2Vec-BERT 2 run, going by the paths).
# They are mapped onto `device` so they live on the same GPU as the rest of the notebook
# instead of a separately hard-coded one.
# NOTE(review): torch.load unpickles arbitrary objects; only load checkpoints from a trusted source.
# NOTE(review): 'quanitzer' in the first filename is kept exactly as written; presumably that is the name on disk.
m1 = torch.load('../data/vq_hubert_60k_run5/quanitzer__L11_C2048_ckpt11000.pkl', map_location=device)
m2 = torch.load('../data/vq_w2vbert2_60k_run1/quantizer__L19_C2048_ckpt62500.pkl', map_location=device)

In [None]:
# Build the quantizer with the same hyper-parameters the checkpoint is expected to match
# (1024-dim inputs, 2048 codes) and place it on the working device in one step.
vq = VectorQuantize(dim=1024, codebook_size=2048, decay=0.8, commitment_weight=1).to(device)

# Evaluation mode; as the last expression this also displays the module summary.
vq.eval()

In [None]:
# Shallow copy of the Wav2Vec-BERT checkpoint entries (same keys, same tensor objects).
new_state_dict = dict(m2.items())

# Restore the codebook; the returned key-matching report is shown as the cell output.
vq.load_state_dict(new_state_dict)

In [None]:
import sys
# Add the parent directory to the import path so the `src.*` imports below resolve
# (presumably the repository root when the notebook runs from its own folder).
sys.path.append('..')

In [None]:
import pdb
import torch
import joblib
import numpy as np
from tqdm import tqdm
from typing import Tuple
from pathlib import Path
import torch.nn.functional as F
from time import time

import matplotlib.pyplot as plt
from transformers import Wav2Vec2FeatureExtractor, AutoFeatureExtractor, WhisperFeatureExtractor

from src.utils import read_audio, find_files
from src.encoder import Wav2VecBertEncoder, HubertEncoder, WhisperEncoder, hubert_processor, whisper_processor
from src.configs import Wav2VecBertConfig, HubertEncoderConfig, WhisperEncoderConfig

In [None]:
# Point the HuBERT config at a locally stored model directory.
# NOTE(review): this assigns on HubertEncoderConfig itself (it is passed below without being
# instantiated), so the override is visible to any later use of that object.
HubertEncoderConfig.model_id = '../data/model/trimmed/hubert_11/'

# NOTE(review): `processor` is not referenced by any other visible cell.
processor = Wav2Vec2FeatureExtractor.from_pretrained(HubertEncoderConfig.model_id)
# NOTE(review): the next cell rebinds `encoder` to a Wav2VecBertEncoder, so under Run All
# this HuBERT encoder is replaced before it is used. Run only the encoder cell that
# matches the quantizer checkpoint loaded into `vq`.
encoder = HubertEncoder(HubertEncoderConfig, quantize=False, compile=False, device=device)

In [None]:
# Wav2Vec-BERT encoder on the working device, built from a default config and without compilation.
# This is the encoder the embedding loop uses when the notebook is run top to bottom.
encoder = Wav2VecBertEncoder(config=Wav2VecBertConfig(), compile=False, device=device)

In [None]:
# samples: number of audio files to draw.
# layer: index of the encoder output that is quantized
#        (presumably matching the "L19" in the loaded checkpoint name; confirm).
samples, layer = 100, 19

# Draw `samples` distinct .wav files at random from the corpus directory.
# No seed is set, so the selection differs between runs.
audio_files = np.random.choice(
    find_files('/home/meraki/projects/tmp/flatfiles/indicvoices/', '.wav'),
    size=samples,
    replace=False,
)

In [None]:
@torch.inference_mode()
def get_vq_dist(embeddings: torch.Tensor, quant: torch.nn.Module) -> Tuple:
    """
    Find the nearest codebook entry for every embedding vector.

    Args:
        embeddings (torch.Tensor): vectors to look up, feature dimension last
            (B, T, D in the batched case).
        quant (torch.nn.Module): quantizer whose `_codebook.embed` tensor holds
            the centroids (feature dimension last). It must live on the same
            device as `embeddings`.

    Returns:
        Tuple: the `(values, indices)` named tuple produced by `torch.min` over
        the centroid axis, i.e. the Euclidean distance to the closest centroid
        and that centroid's index, one entry per embedding vector.
    """
    # Read the codebook directly instead of running the quantizer's forward pass.
    centroids = quant._codebook.embed
    # Pairwise Euclidean distances; the last axis enumerates the centroids.
    distances = torch.cdist(embeddings, centroids)

    return torch.min(distances, dim=-1)

In [None]:
# Compare real encoder frames against random vectors of the same norm:
#   1. encode every sampled audio file and record, per frame, the distance to and
#      index of the nearest codebook entry;
#   2. for each real frame, draw a random direction scaled to that frame's norm and
#      record the same two quantities as a baseline.
audio_distances = []
audio_tokens = []
embeddings = []

print(f'Computing embeddings')

for a in tqdm(audio_files, total=samples):
    # Second argument is presumably the target sample rate (16 kHz); confirm in src.utils.
    audio = read_audio(a, 16_000)
    # Keep at most 160_000 samples along axis 1 (assumes audio is (channels, samples); TODO confirm).
    audio = audio[:, :160_000]

    ii = audio
    # Mask is 1 over real samples and, after padding, 0 over the padded tail.
    am = torch.ones_like(ii)
    # Right-pad both input and mask with zeros to a fixed length of 160_000 so every
    # file yields tensors of the same shape.
    ii = F.pad(ii, (0, 160_000-ii.shape[1]), value=0)
    am = F.pad(am, (0, 160_000-am.shape[1]), value=0)

    out = encoder(ii.to(device), am.to(device))
    # Select the output at index `layer`; the unpacking after the loop implies it is 3-D
    # (presumably (1, T, D) for mono input; confirm).
    out = out[layer]
    d = get_vq_dist(out, vq)

    embeddings.append(out.cpu().numpy())
    # extend() iterates over the leading axis, adding one per-frame array for each row.
    audio_distances.extend(d.values.cpu().numpy())
    audio_tokens.extend(d.indices.cpu().numpy())

# Flatten everything to one row per frame. The reshape relies on every file producing
# the same (seq_len, dim) and on exactly `samples` files having been processed.
seq_len, dim = embeddings[0].shape[1:]
embeddings = torch.from_numpy(np.array(embeddings)).reshape(samples*seq_len, dim)
audio_distances = np.array(audio_distances).reshape(-1, 1)
audio_tokens = np.array(audio_tokens).reshape(-1, 1)

print(f'Shape of embeddings: {embeddings.shape} and audio_distances: {audio_distances.shape} and audio_tokens: {audio_tokens.shape}')

# L2 norm of every real frame; each random vector below is scaled to one of these.
norms = torch.linalg.vector_norm(embeddings, dim=-1)

random_embeddings = []
random_distances = []
random_tokens = []

print(f'Generating random embeddings')

# One random vector per real frame, processed one at a time.
for norm in tqdm(norms):
    # Gaussian sample normalised to unit length, then rescaled to the real frame's norm.
    random_vec = torch.randn((1, dim))
    random_vec = random_vec / torch.norm(random_vec)
    random_vec = random_vec * norm
    random_embeddings.append(random_vec)

    d = get_vq_dist(random_vec.to(device), vq)

    random_distances.append(d.values.detach().cpu().numpy())
    random_tokens.append(d.indices.detach().cpu().numpy())

# The plotting cells index these arrays as [:, :, 0], i.e. they are expected to be 3-D here.
random_distances = np.array(random_distances)
random_tokens = np.array(random_tokens)

print(f'Shape of random_embeddings: {len(random_embeddings)} and random_distances: {random_distances.shape} and random_tokens: {random_tokens.shape}')



In [None]:
# Side-by-side comparison of real audio frames and norm-matched random vectors.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))

# Left panel: distance to the nearest codebook entry.
ax[0].hist(audio_distances, label='Audio Tokens', alpha=0.75)
ax[0].hist(random_distances[:, :, 0], label='Random Tokens', alpha=0.5)
ax[0].set(title='Histogram of Distances', xlabel='Distance', ylabel='Frequency')
ax[0].legend()

# Right panel: how the assigned token ids spread over the codebook.
ax[1].hist(audio_tokens, label='Audio Tokens', bins=100, alpha=0.75)
ax[1].hist(random_tokens[:, :, 0], label='Random Tokens', bins=100, alpha=0.5)
ax[1].set(title='Histogram of Tokens', xlabel='Token', ylabel='Frequency')
ax[1].legend()

In [None]:
# NOTE(review): this cell repeats the previous plotting cell verbatim; presumably it was
# re-run after recomputing the distances on a different dataset. Confirm or drop.
fig, ax = plt.subplots(nrows=1, ncols=2, figsize=(12, 6))

# Nearest-centroid distance: real frames vs. random vectors.
ax[0].hist(audio_distances, label='Audio Tokens', alpha=0.75)
ax[0].hist(random_distances[:, :, 0], label='Random Tokens', alpha=0.5)
ax[0].set(title='Histogram of Distances', xlabel='Distance', ylabel='Frequency')
ax[0].legend()

# Assigned token ids: real frames vs. random vectors.
ax[1].hist(audio_tokens, label='Audio Tokens', bins=100, alpha=0.75)
ax[1].hist(random_tokens[:, :, 0], label='Random Tokens', bins=100, alpha=0.5)
ax[1].set(title='Histogram of Tokens', xlabel='Token', ylabel='Frequency')
ax[1].legend()

In [None]:
# (unique token ids, occurrence counts) for the current `audio_tokens`.
# The name suggests the IndicVoices sample; TODO confirm.
ind = np.unique(audio_tokens, return_counts=True)

In [None]:
# NOTE(review): same expression as `ind`, on the same `audio_tokens` variable. Under
# Restart & Run All both hold identical results. Presumably `audio_tokens` was recomputed
# on an English dataset before this cell was executed; confirm and make that step explicit.
eng = np.unique(audio_tokens, return_counts=True)

In [None]:
# Histogram over the distinct token ids in `ind` (each id counted once, not weighted by
# its count), showing which parts of the token-id range are in use.
plt.hist(ind[0],alpha=0.5, bins=50)
plt.show()

In [None]:
# Histogram over the distinct token ids in `eng` (each id counted once, not weighted by
# its count), showing which parts of the token-id range are in use.
plt.hist(eng[0],alpha=0.5, bins=50)
plt.show()

In [None]:
# Number of distinct token ids present in both `ind` and `eng` (displayed as a 1-tuple shape).
np.intersect1d(ind[0], eng[0]).shape

In [None]:
# Number of distinct token ids in `ind`.
ind[0].shape

In [None]:
# Number of distinct token ids in `eng`.
eng[0].shape