# kNN-VC and LinearVC experiments using parallel data

Herman Kamper, 2024

In [1]:
from datetime import datetime
from numpy import linalg
from pathlib import Path
from tqdm.notebook import tqdm
import celer
import IPython.display as display
import numpy as np
import sys
import torch
import torchaudio

from utils import fast_cosine_dist

In [27]:
from resample_vad import speakers
device = "cuda"

## Models

In [3]:
wavlm = torch.hub.load("bshall/knn-vc", "wavlm_large", trust_repo=True, device=device)

Using cache found in /home/kamperh/.cache/torch/hub/bshall_knn-vc_master


WavLM-Large loaded with 315,453,120 parameters.


In [4]:
hifigan, _ = torch.hub.load("bshall/knn-vc", "hifigan_wavlm", trust_repo=True, device=device, prematched=True)

Using cache found in /home/kamperh/.cache/torch/hub/bshall_knn-vc_master


Removing weight norm...
[HiFiGAN] Generator loaded with 16,523,393 parameters.


## LinearVC using parallel utterances (single)

In [5]:
wav_dir = Path("/home/kamperh/scratch/vctk/wav/")
k_top = 1

In [20]:
# Projection matrix

source = "p225"  # Southern English
target = "p226"  # Surrey
# target = "p232"  # Southern English
# target = "p228"  # Southern English
# target = "p234"  # Scottish
# target = "p323"  # South African
# target = "p347"  # South African
# target = "p376"  # Indian

# source_wav_fn = wav_dir / source / f"{source}_002.wav"
# target_wav_fn = wav_dir / target / f"{target}_002.wav"
# source_wav_fn = wav_dir / source / f"{source}_008.wav"
# target_wav_fn = wav_dir / target / f"{target}_008.wav"
source_wav_fn = wav_dir / source / f"{source}_023.wav"
target_wav_fn = wav_dir / target / f"{target}_023.wav"

# Features
source_wav, _ = torchaudio.load(source_wav_fn)
source_wav = source_wav.to(device)
target_wav, _ = torchaudio.load(target_wav_fn)
target_wav = target_wav.to(device)
with torch.inference_mode():
    source_feats, _ = wavlm.extract_features(source_wav, output_layer=6)
    target_feats, _ = wavlm.extract_features(target_wav, output_layer=6)
source_feats = source_feats.squeeze()
target_feats = target_feats.squeeze()
# print("source_feats shape", source_feats.shape)
# print("target_feats shape", target_feats.shape)\

# Matching
dists = fast_cosine_dist(source_feats, target_feats, device=device)
best = dists.topk(k=k_top, largest=False, dim=-1)        
linear_target = target_feats[best.indices].mean(dim=1)

# Linear regression
linear = celer.Lasso(alpha=0.3, fit_intercept=False).fit(
    source_feats.squeeze().cpu(), linear_target.cpu()
)
W = linear.coef_.T
W = torch.from_numpy(W).float().to(device)

In [11]:
display.Audio(source_wav.squeeze().cpu(), rate=16000)

In [12]:
display.Audio(target_wav.squeeze().cpu(), rate=16000)

In [14]:
wav_fn = wav_dir / source / f"{source}_057.wav"
# wav_fn = wav_dir / source / f"{source}_051.wav"
wav, _ = torchaudio.load(wav_fn)
wav = wav.to(device)
# wav = F.vad(wav, 16000)
display.Audio(wav.squeeze().cpu(), rate=16000)

In [21]:
with torch.inference_mode():
    feats, _ = wavlm.extract_features(
        wav, output_layer=6
    )

source_to_target_feats = feats @ W

with torch.inference_mode():
    wav_hat = hifigan(source_to_target_feats).squeeze(0)

In [23]:
display.Audio(wav_hat.squeeze().cpu(), rate=16000)

## LinearVC using parallel utterances (dataset)

In [31]:
exp_tag = "2024-09-16"
eval_csv = Path("data/speakersim_vctk_english.csv")
wav_dir = Path("/home/kamperh/scratch/vctk/wav")
output_dir = Path(f"/home/kamperh/scratch/linearvc/vctk/{exp_tag}")

k_top = 1
parallel_utt = "023"

output_dir.mkdir(parents=True, exist_ok=True)

In [29]:
# Projection matrices
projmats = {}
for source in tqdm(sorted(speakers)):
    for target in tqdm(sorted(speakers), leave=False):
        if source == target:
            continue

        # Features
        source_wav_fn = wav_dir / source / f"{source}_{parallel_utt}.wav"
        target_wav_fn = wav_dir / target / f"{target}_{parallel_utt}.wav"
        source_wav, _ = torchaudio.load(source_wav_fn)
        source_wav = source_wav.to(device)
        target_wav, _ = torchaudio.load(target_wav_fn)
        target_wav = target_wav.to(device)
        with torch.inference_mode():
            source_feats, _ = wavlm.extract_features(source_wav, output_layer=6)
            target_feats, _ = wavlm.extract_features(target_wav, output_layer=6)
        source_feats = source_feats.squeeze()
        target_feats = target_feats.squeeze()

        # Matching without DTW
        dists = fast_cosine_dist(source_feats, target_feats, device=device)
        best = dists.topk(k=k_top, largest=False, dim=-1)        
        linear_target = target_feats[best.indices].mean(dim=1)

        # # Matching with DTW
        # source_feats_np = source_feats.cpu().numpy()
        # target_feats_np = target_feats.cpu().numpy()
        # s = np.ascontiguousarray(np.float64(source_feats_np))
        # t = np.ascontiguousarray(np.float64(target_feats_np))
        # path, _ = _dtw.multivariate_dtw(s, t, "cosine")
        # path.reverse()
        # source_path, target_path = zip(*path)
        # i_frame = 0
        # linear_target_idx = []
        # for i_source, i_target in path:
        #     if i_source == i_frame:
        #         linear_target_idx.append(i_target)
        #         i_frame += 1
        # linear_target = target_feats_np[linear_target_idx, :]
        # linear_target = torch.from_numpy(linear_target).float()

        # W, _, _, _ = linalg.lstsq(source_feats.cpu(), linear_target.cpu())
    
        # linear = Ridge(alpha=5e3, fit_intercept=False).fit(
        #     source_feats.squeeze().cpu(), linear_target.cpu()
        # )
        # W = linear.coef_.T
        
        # linear = Lasso(alpha=0.2, fit_intercept=False).fit(
        #     source_feats.squeeze().cpu(), linear_target.cpu()
        # )
        # W = linear.coef_.T

        linear = celer.Lasso(alpha=0.3, fit_intercept=False).fit(
            source_feats.squeeze().cpu(), linear_target.cpu()
        )
        W = linear.coef_.T
        
        W = torch.from_numpy(W).float().to(device)
        projmats[f"{source}-{target}"] = W

        break
    break

  0%|          | 0/31 [00:00<?, ?it/s]

  0%|          | 0/31 [00:00<?, ?it/s]

In [32]:
print("Writing:", output_dir / "projmats.pt")
torch.save(projmats, output_dir / "projmats.pt")
projmats = torch.load(output_dir / "projmats.pt")

Writing: /home/kamperh/scratch/linearvc/vctk/2024-09-16/projmats.pt


In [34]:
print("Writing to:", output_dir)
with open(eval_csv) as f:
    for line in tqdm(f.readlines()):
        line = line.strip()
        if line[-1] == "0":
            (source, target, source_key, _, _) = line.split(",")

            source_wav_fn = (
                wav_dir / source / Path(source_key).stem
            ).with_suffix(".wav")
            source_wav, _ = torchaudio.load(source_wav_fn)
            source_wav = source_wav.to(device)
            with torch.inference_mode():
                source_feats, _ = wavlm.extract_features(
                    source_wav, output_layer=6
                )

            W_source_to_target = projmats[f"{source}-{target}"]

            source_to_target_feats = source_feats @ W_source_to_target

            with torch.inference_mode():
                wav_hat = hifigan(source_to_target_feats).squeeze(0)            

            cur_output_dir = Path(output_dir) / source_key.split("/")[0]
            cur_output_dir.mkdir(parents=True, exist_ok=True)
            output_fn = (cur_output_dir / source_key.split("/")[1]).with_suffix(
                ".wav"
            )
            torchaudio.save(output_fn, wav_hat.squeeze().cpu()[None], 16000)

            print(output_fn)
            assert False

Writing to: /home/kamperh/scratch/linearvc/vctk/2024-09-16


  0%|          | 0/9301 [00:00<?, ?it/s]

/home/kamperh/scratch/linearvc/vctk/2024-09-16/p225-p226/p225_051.wav


AssertionError: 

In [37]:
class Arguments: pass
args = Arguments()
args.format = "vctk"
args.eval_csv = eval_csv
args.converted_dir = output_dir
args.groundtruth_dir = wav_dir

print("Run:")
print(
    f"./linearvc/speaker_similarity.py --format {args.format}"
    f" {args.eval_csv} {args.converted_dir} {args.groundtruth_dir}"
)
print(
    f"./intelligibility.py --format {args.format} {args.converted_dir}"
    f" /home/kamperh/endgame/datasets/VCTK-Corpus/txt/"
)

# speaker_similarity(args)

Run:
./linearvc/speaker_similarity.py --format vctk data/speakersim_vctk_english.csv /home/kamperh/scratch/linearvc/vctk/2024-09-16 /home/kamperh/scratch/vctk/wav
./intelligibility.py --format vctk /home/kamperh/scratch/linearvc/vctk/2024-09-16 /home/kamperh/endgame/datasets/VCTK-Corpus/txt/


In [None]:
# ## kNN-VC using single utterance as reference (dataset)

# The single utterance here is the one item in the parallel utterance pairs used for LinearVC above.