In [None]:
import os
import yaml
from collections import OrderedDict

import torch
import numpy as np
from soundfile import read, write

from model import Generator_3 as Generator
from model import Generator_6 as F0_Converter
from wavenet import Synthesizer
from utils import *

In [None]:
def load_ckpt(model, ckpt_path):
    """Load the weights stored under ``ckpt['model']`` into ``model`` in place.

    The state dict is first loaded unchanged. If that fails and some keys
    carry the ``module.`` prefix (the prefix ``torch.nn.DataParallel`` adds to
    parameter names), the prefix is stripped and the load is retried once.
    Any other failure is re-raised untouched so the real cause (missing keys,
    shape mismatch, ...) stays visible.

    Args:
        model: ``torch.nn.Module`` whose parameters are overwritten.
        ckpt_path: path to a ``torch.load``-readable file containing a mapping
            with a ``'model'`` entry.

    Raises:
        KeyError: the checkpoint has no ``'model'`` entry.
        RuntimeError: the state dict does not fit ``model``, even after the
            ``module.`` prefix has been removed.
    """
    prefix = 'module.'

    # Returning the storage unchanged keeps every tensor on the CPU, so a
    # checkpoint saved from a GPU can be read on a CPU-only machine.
    ckpt = torch.load(ckpt_path, map_location=lambda storage, loc: storage)
    state_dict = ckpt['model']

    try:
        model.load_state_dict(state_dict)
    except RuntimeError:
        # Only retry when the failure can plausibly be explained by the
        # DataParallel prefix; otherwise surface the original error.
        if not any(key.startswith(prefix) for key in state_dict):
            raise
        stripped = OrderedDict(
            (key[len(prefix):] if key.startswith(prefix) else key, value)
            for key, value in state_dict.items()
        )
        model.load_state_dict(stripped)

def pad_fea(fea, max_len=None):
    """Zero-pad a 2-D feature array along its first axis up to ``max_len`` rows.

    Args:
        fea: 2-D array; rows are appended at the end, columns are untouched.
        max_len: target number of rows. Defaults to the module-level ``T``
            (defined in the configuration cell), which preserves the original
            behaviour for existing callers.

    Returns:
        A new array with ``max_len`` rows; the appended rows are all zeros.

    Raises:
        ValueError: if ``fea`` already has more than ``max_len`` rows.
    """
    if max_len is None:
        max_len = T
    n_rows = len(fea)
    if n_rows > max_len:
        # np.pad would fail here anyway, but with an opaque negative-width message.
        raise ValueError(
            f'feature has {n_rows} frames but at most {max_len} are supported'
        )
    return np.pad(fea, ((0, max_len - n_rows), (0, 0)), 'constant')

def create_feats(wav, gen, spk_id, config):
    """Build the rhythm, content, pitch and timbre inputs for one utterance.

    Relies on the module-level ``fs``, ``device`` and (through ``pad_fea``)
    ``T``, and on feature helpers that are not defined in this notebook
    (presumably provided by ``from utils import *`` -- confirm).

    Args:
        wav: 1-D waveform array.
        gen: ``'M'`` selects the 50-250 F0 search range; any other value
            selects 100-600.
        spk_id: position of the 1.0 in the 82-element one-hot timbre vector
            (converted with ``int``).
        config: not used by this function; kept for interface compatibility.

    Returns:
        A generator yielding four float tensors on ``device``, each with a
        leading dimension of size 1 added, in the order
        (rhythm, content, pitch, timbre).
    """
    f0_lo, f0_hi = (50, 250) if gen == 'M' else (100, 600)

    # A waveform whose length is an exact multiple of 256 gets one extra
    # near-zero sample appended before any analysis is run.
    if wav.shape[0] % 256 == 0:
        wav = np.concatenate((wav, np.array([1e-06])), axis=0)

    _, f0_norm = extract_f0(wav, fs, f0_lo, f0_hi)
    f0, sp, ap = get_world_params(wav, fs)
    f0 = average_f0s([f0])[0]
    wav_mono = get_monotonic_wav(wav, f0, sp, ap, fs)

    rhythm = pad_fea(get_spenv(wav_mono))
    content = pad_fea(get_spmel(wav_mono))
    pitch = pad_fea(quantize_f0_numpy(f0_norm)[0])

    # One-hot speaker vector.
    timbre = np.zeros((82,), dtype=np.float32)
    timbre[int(spk_id)] = 1.0

    features = (rhythm, content, pitch, timbre)
    return (torch.FloatTensor(feat).unsqueeze(0).to(device) for feat in features)

def convert_sp(model, rhy_input, con_input, pit_input, tim_input):
    """Run the generator and return its decoded output as a NumPy array.

    Args:
        model: object exposing ``rhythm``, ``content_pitch`` and ``decode``.
        rhy_input: tensor passed to ``model.rhythm``.
        con_input, pit_input: tensors concatenated along the last dimension
            and passed to ``model.content_pitch``.
        tim_input: tensor passed through to ``model.decode``.

    Returns:
        The first item of the decoded batch, moved to the CPU and converted
        to a NumPy array. The module-level ``T`` is passed to ``decode``.
    """
    con_pit_input = torch.cat((con_input, pit_input), dim=-1)

    rhythm_code = model.rhythm(rhy_input)
    content_code, pitch_code = model.content_pitch(con_pit_input, rr=False)

    decoded = model.decode(content_code, rhythm_code, pitch_code, tim_input, T)
    return decoded.cpu().numpy()[0]

def convert_pit(model, rhy_input, con_input, pit_input):
    """Run the pitch converter on rhythm and content+pitch inputs.

    ``con_input`` and ``pit_input`` are concatenated along the last
    dimension. That tensor and ``rhy_input`` are then zero-padded at the end
    of dimension 1 up to the module-level ``T`` before being passed to
    ``model``.

    Args:
        model: callable invoked as ``model(rhythm, content_pitch, rr=False)``.
        rhy_input: rhythm tensor (the pad spec assumes three dimensions).
        con_input: content tensor.
        pit_input: pitch tensor.

    Returns:
        Whatever ``model`` returns for the padded inputs.
    """
    def _pad_frames(tensor):
        # Pad pairs run from the last dimension backwards: only the end of
        # dimension 1 grows.
        return torch.nn.functional.pad(tensor, (0, 0, 0, T - tensor.size(1), 0, 0))

    con_pit = torch.cat([con_input, pit_input], dim=-1)
    padded_rhy = _pad_frames(rhy_input)
    padded_con_pit = _pad_frames(con_pit)

    # disable random resampling at inference time
    return model(padded_rhy, padded_con_pit, rr=False)

In [None]:
# --- Configuration and model loading -------------------------------------
config_name = 'spsp2-large' # or 'spsp2-small'

# Read the YAML inside a `with` block so the file handle is closed right away
# and is not left open for the lifetime of the kernel.
with open(f'configs/{config_name}.yaml', 'r') as config_file:
    config_dict = yaml.safe_load(config_file)
config = Dict2Class(config_dict)
config.train = False

T = 192 # maximum number of frames in the output mel-spectrogram
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
fs = 16000  # sample rate passed to the feature helpers and used when writing output
S = Synthesizer(device)
S.load_ckpt('models/wavenet_vocoder.pth')

G = Generator(config).eval().to(device)
load_ckpt(G, f'models/{config_name}-G-800000.ckpt')

# F is built with dim_pit widened to dim_con + dim_pit; convert_pit feeds it the
# concatenation of the content and pitch features. This mutates `config` after G
# has been constructed, so the order of these statements matters.
config.dim_pit = config.dim_con+config.dim_pit
F = F0_Converter(config).eval().to(device)
load_ckpt(F, f'models/{config_name}-F-800000.ckpt')

In [None]:
# --- Convert one source/target pair under every combination of aspects ----
result_dir = 'result'
# exist_ok avoids the check-then-create race of a separate os.path.exists test.
os.makedirs(result_dir, exist_ok=True)

src_wav, src_fs = read('data/test/p225_001.wav')
tgt_wav, tgt_fs = read('data/test/p258_001.wav')

# Every feature helper is called with the module-level `fs`; audio recorded at
# another rate would be processed silently with the wrong rate.
if src_fs != fs or tgt_fs != fs:
    raise ValueError(
        f'expected {fs} Hz audio, got {src_fs} Hz (source) and {tgt_fs} Hz (target)'
    )

# Each letter selects an aspect taken from the target utterance:
# R -> rhythm input, F -> converted pitch, U -> timbre (speaker) vector.
conds = ['R', 'F', 'U', 'RF', 'RU', 'FU', 'RFU']

with torch.no_grad():
    # The features depend only on the audio, not on `cond`, so they are
    # extracted once instead of once per condition.
    # NOTE(review): assumes create_feats is deterministic -- confirm.
    src_rhy, src_con, src_pit, src_tim = create_feats(src_wav, 'F', 0, config)
    tgt_rhy, tgt_con, tgt_pit, tgt_tim = create_feats(tgt_wav, 'M', 31, config)

    for cond in conds:
        inp_rhy, inp_con, inp_pit, inp_tim = src_rhy, src_con, src_pit, src_tim
        if 'R' in cond:
            inp_rhy = tgt_rhy
        if 'U' in cond:
            inp_tim = tgt_tim
        if 'F' in cond:
            # As in the original code, the pitch converter always receives the
            # source rhythm, even when 'R' is also selected.
            inp_pit = convert_pit(F, src_rhy, tgt_con, tgt_pit)
        out_sp = convert_sp(G, inp_rhy, inp_con, inp_pit, inp_tim)
        out_wav = S.spect2wav(out_sp)
        write(os.path.join(result_dir, f'p225_p258_001_{cond}.wav'), out_wav, fs)