In [1]:
%%capture
!pip install onnxruntime

Download the ONNX model and its config from the Hugging Face Hub


In [2]:
# Interactive Hugging Face Hub login: renders a token-entry widget in the
# notebook output. NOTE(review): confirm whether the BSC-LT/vocos-mel-22khz
# repo actually requires authentication; if it is public this cell is optional.
from huggingface_hub import notebook_login
notebook_login()

VBox(children=(HTML(value='<center> <img\nsrc=https://huggingface.co/front/assets/huggingface_logo-noborder.sv…

In [3]:
from huggingface_hub import hf_hub_download, notebook_login

# Fetch the exported ONNX vocoder and its training config from the same Hub
# repository; each call returns the local cached file path.
model_path, config_path = (
    hf_hub_download(repo_id="BSC-LT/vocos-mel-22khz", filename=artifact)
    for artifact in ("mel_spec_22khz_univ.onnx", "config.yaml")
)

Download a test audio clip

In [4]:
!curl -O https://www.signalogic.com/melp/HAVEnoise/orig/h_orig.wav


  % Total    % Received % Xferd  Average Speed   Time    Time     Time  Current
                                 Dload  Upload   Total   Spent    Left  Speed
100 64044  100 64044    0     0  84892      0 --:--:-- --:--:-- --:--:-- 84826


Run Vocos ONNX inference: extract log-mel features, predict the spectrogram, and invert it to audio

In [5]:
import onnxruntime
import torch
import torchaudio
import torchaudio.functional as F

import yaml
from IPython.display import Audio, display


# `curl -O` writes into the current working directory, so use a relative path
# rather than the Colab-specific absolute "/content/..." location.
audio_input = "h_orig.wav"

# ---- Load the feature-extractor configuration ---------------------------
with open(config_path, "r") as f:
    config = yaml.safe_load(f)

params = config["feature_extractor"]["init_args"]
sample_rate = params["sample_rate"]
n_fft = params["n_fft"]
hop_length = params["hop_length"]
n_mels = params["n_mels"]
padding = params["padding"]
win_length = n_fft

# Only these two modes are handled below; fail early on anything else
# instead of silently producing misaligned features.
if padding not in ("center", "same"):
    raise ValueError(f"Unsupported padding mode {padding!r}; expected 'center' or 'same'.")

# ---- Load audio and resample to the model's rate ------------------------
signal, fs = torchaudio.load(audio_input)
if fs != sample_rate:
    signal = F.resample(signal, fs, sample_rate)

# ---- Mel feature extraction ---------------------------------------------
# power=1 -> magnitude (not power) mel spectrogram.
mel_transform = torchaudio.transforms.MelSpectrogram(
    sample_rate=sample_rate,
    n_fft=n_fft,
    hop_length=hop_length,
    n_mels=n_mels,
    center=padding == "center",
    power=1,
    # f_min/f_max are hard-coded "to match matcha" (presumably Matcha-TTS
    # mel settings) rather than read from the config.
    f_min=0,
    f_max=8000,
    norm="slaney",
    mel_scale="slaney",
)

# In "center" mode MelSpectrogram pads internally (center=True), so an extra
# manual pad would double-pad. In "same" mode, reflect-pad by
# (win_length - hop_length) // 2 per side, matching the trim applied by the
# inverse STFT below. A separate name keeps `signal` unpadded for playback.
if padding == "same":
    frame_pad = win_length - hop_length
    signal_padded = torch.nn.functional.pad(signal, (frame_pad // 2, frame_pad // 2), mode="reflect")
else:
    signal_padded = signal

# Clip before the log to avoid log(0) on silent bins.
mel = torch.log(torch.clip(mel_transform(signal_padded), min=1e-5))
print("mel input shape", mel.shape)

# ---- ONNX Runtime session -----------------------------------------------
sess_options = onnxruntime.SessionOptions()
model = onnxruntime.InferenceSession(model_path, sess_options=sess_options, providers=["CPUExecutionProvider"])

input_info = model.get_inputs()
for node_arg in input_info:  # avoid shadowing the builtin `input`
    print("Name:", node_arg.name)
    print("Shape:", node_arg.shape)
    print("Type:", node_arg.type)

# ---- ONNX inference -----------------------------------------------------
# Outputs: magnitude plus two components combined as real + 1j * imag
# (presumably cos/sin of the phase, as in the Vocos ISTFT head — TODO confirm).
mag, real, imag = model.run(None, {"mels": mel.float().numpy()})
spectrogram = torch.tensor(mag * (real + 1j * imag))
window = torch.hann_window(win_length)

# ---- Inverse STFT ("same"-padding overlap-add) --------------------------
# NOTE(review): this is the "same"-padding inverse; for a "center"-padded
# model torch.istft(..., center=True) may be the matching inverse — confirm.
trim = (win_length - hop_length) // 2
_, _, n_frames = spectrogram.shape  # expects (batch, freq_bins, frames)
print("Spectrogram synthesized shape", spectrogram.shape)

# Inverse FFT per frame, then apply the synthesis window.
ifft = torch.fft.irfft(spectrogram, n_fft, dim=1, norm="backward")
ifft = ifft * window[None, :, None]

# Overlap-add the frames. Use an explicit stop index: `[trim:-trim]` would
# yield an empty tensor when trim == 0 (win_length == hop_length).
output_size = (n_frames - 1) * hop_length + win_length
keep = slice(trim, output_size - trim)
audio_unnormalized = torch.nn.functional.fold(
    ifft, output_size=(1, output_size), kernel_size=(1, win_length), stride=(1, hop_length),
)[:, 0, 0, keep]

# Sum of squared windows at each sample, used to undo the OLA window gain.
window_sq = window.square().expand(1, n_frames, -1).transpose(1, 2)
window_envelope = torch.nn.functional.fold(
    window_sq, output_size=(1, output_size), kernel_size=(1, win_length), stride=(1, hop_length),
).squeeze()[keep]

# Normalize (the envelope must be non-zero everywhere to divide safely).
assert (window_envelope > 1e-11).all()
audio = audio_unnormalized / window_envelope
print("inference audio tensor:", audio.shape)


mel input shape torch.Size([1, 80, 344])
Name: mels
Shape: ['batch_size', 80, 'time']
Type: tensor(float)
Spectrogram synthesized shape torch.Size([1, 513, 344])
inference audio tensor: torch.Size([1, 88064])


Listen to the original input and the Vocos reconstruction

In [6]:
# Play the input waveform and the vocoder output one after another for comparison.
for caption, waveform in (("Original audio", signal), ("Vocos reconstruction", audio)):
    print(caption)
    display(Audio(data=waveform, rate=params["sample_rate"]))

Original audio


Vocos reconstruction
