# Import Modules

In [159]:
import librosa
import torch
import matplotlib.pyplot as plt
import numpy as np

from sklearn.cluster import AgglomerativeClustering

from embedder import SpeechEmbedder
from utils.hparams import HParam
from utils.audio import Audio

# Build & Load

In [160]:
hp = HParam('config/default.yaml')
torch.cuda.set_device(hp.gpu)

embedder = SpeechEmbedder(hp).cuda()
chkpt_embed = torch.load('pretrained_d-vector_embedding.pt')
embedder.load_state_dict(chkpt_embed)
embedder.eval()

audio = Audio(hp)

dvec_wav, _ = librosa.load('data/short.wav', sr=16000)
dvec_wav=librosa.util.normalize(dvec_wav)
dvec_mel = audio.get_mel(dvec_wav)
dvec_mel = torch.from_numpy(dvec_mel).float().cuda()
dvec = embedder(dvec_mel)

# Clustering

In [161]:
display(dvec_wav.shape)

(41254,)

In [162]:
window_size = 20
mel_length = dvec_mel.shape[1]
mel_size = dvec_mel.shape[0]
display(mel_length)

258

In [163]:
dvec = torch.zeros(256,mel_length-window_size).cuda()

In [164]:
for i in range(mel_length-window_size):
    dvec[:,i] = embedder(dvec_mel[:,i:window_size+i])
display(dvec.shape)

torch.Size([256, 238])

In [165]:
dvec_numpy = dvec.cpu().detach().numpy()

In [166]:
clustering = AgglomerativeClustering(n_clusters=2,linkage='average').fit(dvec_numpy)

In [167]:
clustering.labels_

array([1, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0,
       0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0, 0, 0, 0, 0, 0, 1,
       0, 1, 1, 0, 0, 0, 1, 1, 0, 1, 1, 0, 1, 0, 0, 0, 1, 0, 0, 0, 1, 1,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 1, 1, 0,
       0, 1, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 0, 1, 0, 0, 0, 0,
       1, 1, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 1, 1, 0, 1, 0,
       1, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 0, 0, 1, 0, 0, 0, 0, 0, 0,
       0, 1, 0, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1, 0, 0, 1, 0, 0, 0, 0,
       1, 0, 0, 0, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0,
       0, 0, 1, 1, 0, 1, 0, 1, 0, 1, 0, 1, 0, 0, 0, 0, 0, 0, 1, 1, 0, 0,
       1, 1, 0, 1, 0, 0, 0, 0, 0, 0, 0, 0, 0, 1])

# Plotting