In [1]:
from __future__ import absolute_import, division, print_function, unicode_literals

import glob
import os
import argparse
import json
import torch
from scipy.io.wavfile import write
from env import AttrDict
from meldataset import mel_spectrogram, MAX_WAV_VALUE
from models import Generator
from stft import TorchSTFT


# Module-level state: the hyperparameter config (`h`) and torch device are
# assigned by the configuration cell; both start unset.
h = device = None


def load_checkpoint(filepath, device):
    """Load a checkpoint file written with ``torch.save``.

    Args:
        filepath: path to the checkpoint; asserted to be an existing file.
        device: forwarded to ``torch.load`` as ``map_location`` so stored
            tensors are remapped onto it.

    Returns:
        The deserialized checkpoint object (indexed with ``['generator']``
        by the caller in this notebook).
    """
    assert os.path.isfile(filepath)
    print(f"Loading '{filepath}'")
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    loaded = torch.load(filepath, map_location=device)
    print("Complete.")
    return loaded


def get_mel(x):
    """Compute a mel spectrogram of ``x`` using the settings in the global config ``h``.

    ``h`` must already be populated (it is set by the configuration cell).
    """
    mel_args = (h.n_fft, h.num_mels, h.sampling_rate, h.hop_size,
                h.win_size, h.fmin, h.fmax)
    return mel_spectrogram(x, *mel_args)


def scan_checkpoint(cp_dir, prefix):
    """Return the lexicographically greatest path in ``cp_dir`` starting with ``prefix``.

    Returns an empty string when no file matches.
    """
    candidates = glob.glob(os.path.join(cp_dir, prefix + '*'))
    if not candidates:
        return ''
    return max(candidates)

In [2]:
# Load the model hyperparameters from the JSON config into the global `h`.
config_file = './config_v1.json'
with open(config_file) as f:
    json_config = json.load(f)
h = AttrDict(json_config)

# Seed torch (and CUDA, when present) so runs are reproducible.
torch.manual_seed(h.seed)
# Use the GPU when one is available so the generator, the iSTFT module and
# the input tensors all live on the same device.
if torch.cuda.is_available():
    torch.cuda.manual_seed(h.seed)
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

generator = Generator(h).to(device)
stft = TorchSTFT(filter_length=h.gen_istft_n_fft,
                 hop_length=h.gen_istft_hop_size,
                 win_length=h.gen_istft_n_fft).to(device)

state_dict_g = load_checkpoint('./checkpoint/g_00002400', device)
generator.load_state_dict(state_dict_g['generator'])

# Inference mode: disable training-only behaviour and fold weight norm.
generator.eval()
generator.remove_weight_norm()

Loading './checkpoint/g_00002400'
Complete.
Removing weight norm...


In [3]:
from scipy.io.wavfile import read

def load_wav(full_path):
    """Read a wav file and return its samples as int16 together with the sample rate.

    Floating-point data whose maximum is <= 1 is treated as normalised audio
    and scaled by ``MAX_WAV_VALUE`` before the int16 cast. Integer PCM data is
    cast without rescaling.

    Args:
        full_path: path to the wav file.

    Returns:
        ``(samples, sampling_rate)`` where ``samples`` is an int16 numpy array.
    """
    sampling_rate, data = read(full_path)
    # Only floating-point data can be normalised audio: multiplying quiet
    # integer PCM (peak <= 1) by MAX_WAV_VALUE would overflow the int16 cast.
    # `data.max()` (unlike builtin max) also handles multi-channel arrays;
    # `data.size` avoids calling max() on an empty file.
    if data.dtype.kind == 'f' and data.size and data.max() <= 1:
        data = data * MAX_WAV_VALUE
    return data.astype("int16"), sampling_rate

In [4]:
# Keep the iSTFT module on the same device as the generator and input tensors
# (hard-coding 'cuda' mismatches `device` and fails on CPU-only machines).
stft = stft.to(device)
def inference(filename):
    """Resynthesise a wav file through the mel -> generator -> iSTFT pipeline.

    Args:
        filename: path to a wav file readable by ``load_wav``.

    Returns:
        The reconstructed waveform as an int16 numpy array.
    """
    with torch.no_grad():
        wav, sr = load_wav(filename)
        # NOTE(review): `sr` is not checked against h.sampling_rate; inputs are
        # presumably expected to already be at the model's rate — confirm.
        wav = torch.FloatTensor(wav / MAX_WAV_VALUE).to(device)
        x = get_mel(wav.unsqueeze(0))
        spec, phase = generator(x, torch.LongTensor([x.size(2)]).to(device))
        y_g_hat = stft.inverse(spec, phase)
        audio = y_g_hat.squeeze() * MAX_WAV_VALUE
        # Clamp into the int16 range: a sample at +1.0 scales to 32768, which
        # would otherwise wrap around to -32768 on the int16 cast.
        audio = audio.clamp(-MAX_WAV_VALUE, MAX_WAV_VALUE - 1)
        return audio.cpu().numpy().astype('int16')

In [6]:
from IPython.display import Audio as Audio 

# Resynthesise each sample clip through the vocoder, in order.
y1, y2, y3 = map(inference, [
    "./sample/si965.wav",
    "./sample/p225_001.wav",
    "./sample/p226_002.wav",
])



In [7]:
# Play back at the model's configured sample rate instead of a hard-coded value.
Audio(y1, rate=h.sampling_rate)

In [8]:
# Play back at the model's configured sample rate instead of a hard-coded value.
Audio(y2, rate=h.sampling_rate)

In [9]:
# Play back at the model's configured sample rate instead of a hard-coded value.
Audio(y3, rate=h.sampling_rate)