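Measure how long each step of the kNN-VC pipeline takes (WavLM feature extraction, kNN conversion, HiFiGAN vocoding) when audio is processed in fixed-size chunks, comparing eager PyTorch with ONNX Runtime.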

In [5]:
import yaml
from time import perf_counter
import json

import torchaudio
import torch
from torch.multiprocessing import Queue
from IPython.display import display, Audio
import onnxruntime as ort

from stream_processing.models.knnvc.knnvc import convert_vecs
from stream_processing.models.knnvc.wavlm.model import WavLM, WavLMConfig
from stream_processing.models.knnvc.hifigan import Generator, AttrDict


In [4]:
# Load one LibriSpeech utterance as a 1-D waveform at 16 kHz, the rate this
# notebook assumes everywhere (all playback below uses rate=16000).
# NOTE(review): absolute local path; make it configurable before sharing.
SAMPLE_RATE = 16000
audiofile = "/Users/cafr02/datasets/LibriSpeech/dev-clean/1272/128104/1272-128104-0000.flac"
waveform, file_sr = torchaudio.load(audiofile)
# torchaudio.load returns (channels, samples); averaging over channels yields a
# 1-D signal even for multi-channel files (identical to squeeze() for mono).
audio = waveform.mean(dim=0)
if file_sr != SAMPLE_RATE:
    # Resample instead of silently processing/playing at the wrong rate.
    audio = torchaudio.functional.resample(audio, file_sr, SAMPLE_RATE)
print(audio.shape)
display(Audio(audio.numpy(), rate=SAMPLE_RATE))

torch.Size([93680])


## Naive Torch Models

In [7]:
# Checkpoint/config locations and kNN-VC settings shared by both benchmarks.
hifigan_cfg = "/Users/cafr02/repos/spkanon/checkpoints/knnvc/hifigan.json"
hifigan_ckpt = "/Users/cafr02/repos/spkanon/checkpoints/knnvc/hifigan.pt"
wavlm_ckpt = "/Users/cafr02/repos/spkanon/checkpoints/WavLM-Large.pt"
wavlm_layer = 6  # passed to WavLM.extract_features as output_layer
n_neighbors = 4  # number of neighbours passed to convert_vecs
# NOTE: torch.load unpickles its input; only load checkpoints from trusted sources.
# map_location="cpu" matches the other loads so GPU-saved features still load here.
target_feats = torch.load("../target_feats/135887.pt", map_location="cpu")

# initialize the WavLM model
ckpt = torch.load(wavlm_ckpt, map_location="cpu")
wavlm = WavLM(WavLMConfig(ckpt["cfg"]))
wavlm.load_state_dict(ckpt["model"])
wavlm.eval()

# initialize the HiFiGAN model; the `with` block closes the config file after reading
with open(hifigan_cfg) as cfg_file:
    hifigan_config = json.load(cfg_file)
hifigan = Generator(AttrDict(hifigan_config))
hifigan.load_state_dict(
    torch.load(hifigan_ckpt, map_location="cpu")["generator"]
)
hifigan.eval()
hifigan.remove_weight_norm()



Removing weight norm...


In [8]:
# Run the eager PyTorch pipeline chunk by chunk, accumulating wall-clock time per
# stage: WavLM feature extraction ("wavlm"), kNN conversion ("conv") and HiFiGAN
# vocoding ("hifigan"). The models were loaded with map_location="cpu" and never
# moved, so each perf_counter interval covers the whole synchronous call.
chunk_size = 3200  # samples per chunk (0.2 s at 16 kHz)
out = list()
times = {"wavlm": 0, "conv": 0, "hifigan": 0}
with torch.inference_mode():
    for chunk in range(0, len(audio), chunk_size):
        audio_chunk = audio[chunk : chunk + chunk_size]
        # Zero-pad the final partial chunk exactly like the ONNX benchmark does, so
        # both backends process identical input and their timings are comparable.
        n_missing = chunk_size - audio_chunk.shape[0]
        if n_missing > 0:
            audio_chunk = torch.nn.functional.pad(audio_chunk, (0, n_missing))
        ts_0 = perf_counter()
        wavlm_feats = wavlm.extract_features(
            audio_chunk.unsqueeze(0), output_layer=wavlm_layer
        )[0][0]
        ts_1 = perf_counter()
        times["wavlm"] += ts_1 - ts_0
        conv_feats = convert_vecs(wavlm_feats, target_feats, n_neighbors)
        ts_2 = perf_counter()
        times["conv"] += ts_2 - ts_1
        out.append(hifigan(conv_feats.unsqueeze(0)).squeeze())
        ts_3 = perf_counter()
        times["hifigan"] += ts_3 - ts_2

out = torch.cat(out)
display(Audio(out.numpy(), rate=16000))
print(times)
    

{'wavlm': 1.2109307470045678, 'conv': 0.12665192000076786, 'hifigan': 1.0310972499983109}


I tried `torch.compile`, but it fails on WavLM with a `BackendCompilerFailed` error:

```log
---------------------------------------------------------------------------
BackendCompilerFailed                     Traceback (most recent call last)
Cell In[8], line 1
----> 1 convert_compiled(audio, converter)

File ~/repos/stream_processing/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:328, in _TorchDynamoContext.__call__.<locals>._fn(*args, **kwargs)
    326 dynamic_ctx.__enter__()
    327 try:
--> 328     return fn(*args, **kwargs)
    329 finally:
    330     set_eval_frame(prior)

Cell In[7], line 9
      7 with torch.inference_mode():
      8     start = perf_counter()
----> 9     wavlm_feats = converter.wavlm.extract_features(
     10         audio_chunk.unsqueeze(0), output_layer=cfg["wavlm_layer"]
     11     )[0][0]
     12     conv_feats = convert_vecs(wavlm_feats, converter.target_feats, cfg["n_neighbors"])
     13     out.append(converter.hifigan(conv_feats.unsqueeze(0)).squeeze())

File ~/repos/stream_processing/venv/lib/python3.10/site-packages/torch/_dynamo/eval_frame.py:490, in catch_errors_wrapper.<locals>.catch_errors(frame, cache_entry, frame_state)
    487             return hijacked_callback(frame, cache_entry, hooks, frame_state)
    489 with compile_lock, _disable_current_modes():
--> 490     return callback(frame, cache_entry, hooks, frame_state)
...

You can suppress this exception and fall back to eager by setting:
    import torch._dynamo
    torch._dynamo.config.suppress_errors = True
```

## ONNX Runtime

In [18]:
# Same chunked pipeline, with WavLM and HiFiGAN run through ONNX Runtime. The model
# files are keyed by chunk_size; the final partial chunk is zero-padded to
# chunk_size below (presumably because the exported graphs expect a fixed input
# length — TODO confirm against the export script).
# Pin the CPU execution provider so the comparison with the CPU PyTorch run is
# explicit and does not depend on which providers this onnxruntime build ships.
providers = ["CPUExecutionProvider"]
session_wavlm = ort.InferenceSession(
    f"../onnx/wavlm_{chunk_size}.onnx", providers=providers
)
session_hifigan = ort.InferenceSession(
    f"../onnx/hifigan_{chunk_size}.onnx", providers=providers
)

out = list()
times = {"wavlm": 0, "conv": 0, "hifigan": 0}
# inference_mode matches the PyTorch benchmark, so the kNN ("conv") step is timed
# under the same autograd settings in both runs.
with torch.inference_mode():
    for chunk in range(0, len(audio), chunk_size):
        audio_chunk = audio[chunk : chunk + chunk_size]
        n_missing = chunk_size - audio_chunk.shape[0]
        if n_missing > 0:
            audio_chunk = torch.nn.functional.pad(audio_chunk, (0, n_missing))
        ts_0 = perf_counter()
        wavlm_in = audio_chunk.unsqueeze(0).numpy()
        wavlm_feats = session_wavlm.run(["output"], {"input": wavlm_in})[0]
        wavlm_feats = torch.tensor(wavlm_feats.squeeze())
        ts_1 = perf_counter()
        times["wavlm"] += ts_1 - ts_0
        conv_feats = convert_vecs(
            wavlm_feats, target_feats, n_neighbors
        ).unsqueeze(0)
        ts_2 = perf_counter()
        times["conv"] += ts_2 - ts_1
        hifigan_out = session_hifigan.run(
            None, {"input": conv_feats.numpy()}
        )[0].squeeze()
        ts_3 = perf_counter()
        times["hifigan"] += ts_3 - ts_2
        out.append(torch.tensor(hifigan_out, dtype=torch.float32))

out = torch.cat(out)
display(Audio(out.numpy(), rate=16000))
print(times)

{'wavlm': 0.40129758199691423, 'conv': 0.14110283300578885, 'hifigan': 0.6870545409947226}
