In [3]:
import argparse
from glob import glob
import os
import warnings

import torch
import tqdm
from omegaconf import OmegaConf
from scipy.io.wavfile import write

from model.generator import Generator
from utils.stft import TacotronSTFT
from utils.utils import read_wav_np



# Trained UniVNet bandwidth-extension checkpoint.
# NOTE(review): absolute local path — adjust for your machine.
CHECKPOINT_PATH = "/home/alex/projects/univnet/models/univnet_bwe_0041.pt"

# map_location="cpu" lets the checkpoint load regardless of which device it was
# saved from; the weights are copied onto the GPU model when it is built below.
checkpoint = torch.load(CHECKPOINT_PATH, map_location="cpu")
# Hyper-parameters are stored in the checkpoint under "hp_str" and parsed by OmegaConf.
hp = OmegaConf.create(checkpoint["hp_str"])

# Mel front-end for the *input* signal. Hop length, window length and sampling
# rate are halved relative to the hparams, presumably because this BWE model
# consumes audio at half the output sampling rate — TODO confirm.
# NOTE(review): filter_length and mel_fmax are NOT halved; check that mel_fmax
# does not exceed the Nyquist frequency of the half-rate input.
stft = TacotronSTFT(
    hp.audio.filter_length,
    hp.audio.hop_length // 2,
    hp.audio.win_length // 2,
    hp.audio.n_mel_channels,
    hp.audio.sampling_rate // 2,
    hp.audio.mel_fmin,
    hp.audio.mel_fmax,
    center=False,
)


def _strip_module_prefix(state_dict, prefix="module."):
    """Return a copy of ``state_dict`` with a leading ``prefix`` removed from its keys.

    Checkpoints saved from a ``torch.nn.DataParallel`` / ``DistributedDataParallel``
    wrapper may prefix every parameter name with ``"module."``, while a bare model
    expects unprefixed names. Keys without the prefix are copied unchanged.

    Args:
        state_dict: Mapping of parameter name -> tensor.
        prefix: Key prefix to remove when present.

    Returns:
        A new dict; the input mapping is not modified.
    """
    return {
        (key[len(prefix):] if key.startswith(prefix) else key): value
        for key, value in state_dict.items()
    }


# Build the generator on the GPU and load the trained weights into it.
model = Generator(hp).cuda()
saved_state_dict = checkpoint["model_g"]
# Normalize parameter names so a DataParallel-saved checkpoint matches the bare
# model; load_state_dict (strict by default) reports any remaining mismatch.
new_state_dict = _strip_module_prefix(saved_state_dict)
model.load_state_dict(new_state_dict)
# Project-specific eval(): with inference=True it also removes weight norm
# (this cell prints "Removing weight norm...") — see model/generator.py.
model.eval(inference=True)



Removing weight norm...


In [6]:
# Input utterance to upsample, and the directory the result is written to.
# NOTE(review): absolute local paths — adjust for your machine.
INPUT_PATH = "/home/alex/projects/univnet/input_audio/BORKA_BARKER_03.wav"
OUTPUT_DIR = "/home/alex/projects/univnet/output_test"

filepath = INPUT_PATH
filename = os.path.basename(filepath)
sr, audio = read_wav_np(filepath)

# The mel front-end was built for hp.audio.sampling_rate // 2, so the input file
# is expected at that rate; any other rate silently produces a mismatched mel.
expected_sr = hp.audio.sampling_rate // 2
if sr != expected_sr:
    warnings.warn(
        f"{filename} is sampled at {sr} Hz but the mel front-end expects "
        f"{expected_sr} Hz; the generated audio will likely be wrong."
    )

wav = torch.from_numpy(audio).unsqueeze(0)  # leading dim, presumably batch

with torch.no_grad():
    mel = stft.mel_spectrogram(wav)
    if mel.dim() == 2:
        mel = mel.unsqueeze(0)  # make sure there is a leading (batch) dimension
    mel = mel.cuda()

    audio = model.inference(mel)
    # Under no_grad the output carries no autograd graph, so no detach() is needed.
    # NOTE(review): the dtype/scale of model.inference's output is defined in
    # model/generator.py; scipy's write() picks the WAV sample format from the dtype.
    audio = audio.cpu().numpy()

os.makedirs(OUTPUT_DIR, exist_ok=True)  # write() fails if the directory is missing
out_path = os.path.join(OUTPUT_DIR, filename)
write(out_path, hp.audio.sampling_rate, audio)
