In [11]:
import os
import sys

# append top level dir to our path

sys.path.append("../")

In [12]:
import numpy as np
import torch
from torch.nn import functional as F

import utils
from data_loader import MelSpectrogramFixed
from inference import get_yaapt_f0, load_audio
from model.vc_dddm_mixup import DDDM, Wav2vec2
from model_f0_vqvae import Quantizer
from vocoder.hifigan import HiFi

mel_fn, w2v, f0_quantizer, model, net_v = None, None, None, None, None


def _get_speaker_embedding(a):
    global mel_fn, w2v, f0_quantizer, model, net_v

    if mel_fn is None:
        mel_fn = MelSpectrogramFixed(
            sample_rate=hps.data.sampling_rate,
            n_fft=hps.data.filter_length,
            win_length=hps.data.win_length,
            hop_length=hps.data.hop_length,
            f_min=hps.data.mel_fmin,
            f_max=hps.data.mel_fmax,
            n_mels=hps.data.n_mel_channels,
            window_fn=torch.hann_window,
        ).cuda()

    if w2v is None:
        # Load pre-trained w2v (XLS-R)
        w2v = Wav2vec2().cuda()

    if f0_quantizer is None:
        # Load model
        f0_quantizer = Quantizer(hps).cuda()
        utils.load_checkpoint(a.ckpt_f0_vqvae, f0_quantizer)
        f0_quantizer.eval()

    if model is None:
        model = DDDM(
            hps.data.n_mel_channels,
            hps.diffusion.spk_dim,
            hps.diffusion.dec_dim,
            hps.diffusion.beta_min,
            hps.diffusion.beta_max,
            hps,
        ).cuda()
        utils.load_checkpoint(a.ckpt_model, model, None)
        model.eval()

    if net_v is None:
        # Load vocoder
        net_v = HiFi(
            hps.data.n_mel_channels,
            hps.train.segment_size // hps.data.hop_length,
            **hps.model,
        ).cuda()
        utils.load_checkpoint(a.ckpt_voc, net_v, None)
        net_v.eval().dec.remove_weight_norm()

    # Convert audio
    # print(">> Converting each utterance...")
    src_name = os.path.splitext(os.path.basename(a.src_path))[0]
    audio = load_audio(a.src_path)

    src_mel = mel_fn(audio.cuda())
    src_length = torch.LongTensor([src_mel.size(-1)]).cuda()
    w2v_x = w2v(F.pad(audio, (40, 40), "reflect").cuda())

    try:
        f0 = get_yaapt_f0(audio.numpy())
    except:
        f0 = np.zeros((1, audio.shape[-1] // 80), dtype=np.float32)

    ii = f0 != 0
    f0[ii] = (f0[ii] - f0[ii].mean()) / f0[ii].std()
    f0 = torch.FloatTensor(f0).cuda()
    f0_code = f0_quantizer.code_extraction(f0)

    trg_name = os.path.splitext(os.path.basename(a.trg_path))[0]
    trg_audio = load_audio(a.trg_path)

    trg_mel = mel_fn(trg_audio.cuda())
    trg_length = torch.LongTensor([trg_mel.size(-1)]).to(device)

    with torch.no_grad():
        c = model.encode_speaker(
            w2v_x,
            f0_code,
            src_length,
            trg_mel,
            trg_length,
        )
        return c.cpu().detach().numpy()

In [13]:
from dotmap import DotMap


def get_speaker_embedding(path_to_wav):
    global hps, device, a
    a = DotMap()
    a.src_path = path_to_wav
    a.trg_path = path_to_wav
    a.ckpt_model = ".././ckpt/model_base.pth"
    a.ckpt_voc = ".././vocoder/voc_ckpt.pth"
    a.ckpt_f0_vqvae = ".././f0_vqvae/f0_vqvae.pth"
    a.t = 6
    config = os.path.join(os.path.split(a.ckpt_model)[0], "config.json")
    hps = utils.get_hparams_from_file(config)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return _get_speaker_embedding(a)

In [14]:
speaker_embeddings = []

In [15]:
import glob

import tqdm

for p in tqdm.tqdm(glob.glob("D://vox1_test_wav/**/*.wav", recursive=True)):
    folders = p.split(os.sep)
    speaker_id = folders[-3]
    clip_id = folders[-2]
    utterance_id = folders[-1]
    emb = get_speaker_embedding(p)
    speaker_embeddings.append(
        {
            "speaker_id": speaker_id,
            "clip_id": clip_id,
            "utterance_id": utterance_id,
            "embedding": emb,
        }
    )

  0%|          | 0/4882 [00:00<?, ?it/s]

DEBUG:urllib3.connectionpool:Starting new HTTPS connection (1): huggingface.co:443
DEBUG:urllib3.connectionpool:https://huggingface.co:443 "HEAD /facebook/wav2vec2-xls-r-300m/resolve/main/config.json HTTP/1.1" 200 0
INFO:root:Loaded checkpoint '.././f0_vqvae/f0_vqvae.pth' (iteration 77)
INFO:root:Loaded checkpoint '.././ckpt/model_base.pth' (iteration 223)
INFO:root:Loaded checkpoint '.././vocoder/voc_ckpt.pth' (iteration 1169)
Removing weight norm...


  phi[lag_min:lag_max] = formula_nume/np.sqrt(formula_denom)
100%|██████████| 4882/4882 [2:05:51<00:00,  1.55s/it]  


In [16]:
import pandas as pd

print(len(speaker_embeddings))
df = pd.DataFrame(speaker_embeddings)

import pickle

with open("speaker_embeddings.pickle", "wb") as f:
    pickle.dump(df, f)

df

4882


Unnamed: 0,speaker_id,clip_id,utterance_id,embedding
0,id10270,5r0dWxy17C8,00001.wav,"[[[-0.9422593], [-0.37463865], [0.19461837], [..."
1,id10270,5r0dWxy17C8,00002.wav,"[[[-1.4136827], [-0.7891733], [0.3994071], [0...."
2,id10270,5r0dWxy17C8,00003.wav,"[[[-1.1209184], [-0.8950034], [1.3301648], [0...."
3,id10270,5r0dWxy17C8,00004.wav,"[[[-0.57233125], [-0.1521129], [0.46998602], [..."
4,id10270,5r0dWxy17C8,00005.wav,"[[[-0.7774991], [-0.42959046], [0.93772507], [..."
...,...,...,...,...
4877,id10270,5r0dWxy17C8,00004.wav,"[[[-0.2039178], [1.0564528], [-0.015266762], [..."
4878,id10270,5r0dWxy17C8,00005.wav,"[[[-0.6076309], [0.59529996], [0.37479544], [-..."
4879,id10270,5r0dWxy17C8,00006.wav,"[[[-0.24143918], [-0.19603784], [-0.43727106],..."
4880,id10270,5r0dWxy17C8,00007.wav,"[[[-0.22903721], [-0.55159664], [0.6781072], [..."


In [17]:
sensitivity = 0

for i, row1 in tqdm.tqdm(df.iterrows()):
    for j, row2 in df.iterrows():
        if row1["speaker_id"] == row2["speaker_id"]:
            continue

        emb1 = np.asarray(row1["embedding"]).squeeze()
        emb2 = np.asarray(row2["embedding"]).squeeze()

        dist = np.linalg.norm(emb1 - emb2)
        if dist > sensitivity:
            sensitivity = dist

print(f"sensitivity={sensitivity}")

4882it [24:05,  3.38it/s]

sensitivity=24.855695724487305





In [18]:
with open("voxceleb1_sensitivity.txt", "w") as file:
    file.write(f"sensitivity={sensitivity}")