In [99]:
import os
import pickle
from model_bl import D_VECTOR
from collections import OrderedDict
import numpy as np
import torch
import soundfile as sf
from scipy import signal
from librosa.filters import mel
from numpy.random import RandomState
from scipy.signal import get_window
from math import ceil
from model_vc import Generator
from synthesis import build_model, wavegen
import librosa

In [77]:
class SpeakerEmbedder():
    """Compute speaker embeddings for wav files with a pretrained D_VECTOR encoder.

    ``melspec`` turns a wav file into a normalised 80-band log-mel spectrogram.
    Calling an instance crops that spectrogram and runs it through the encoder.
    """

    def __init__(self, model_path='3000000-BL.ckpt', device='cuda'):
        """Build the encoder and load its pretrained weights.

        Args:
            model_path: checkpoint whose ``'model_b'`` entry holds the encoder
                state dict.
            device: torch device the encoder runs on (default ``'cuda'``).
        """
        self.device = device
        self.C = D_VECTOR(dim_input=80, dim_cell=768, dim_emb=256).eval().to(device)
        c_checkpoint = torch.load(model_path, map_location=device)
        # Checkpoint keys carry a 7-character prefix that is stripped before
        # loading (presumably 'module.' from nn.DataParallel -- TODO confirm).
        new_state_dict = OrderedDict(
            (key[7:], val) for key, val in c_checkpoint['model_b'].items()
        )
        # Load into the encoder owned by this instance.
        self.C.load_state_dict(new_state_dict)

    @staticmethod
    def melspec(path):
        """Load an audio file and return its normalised log-mel spectrogram.

        Processing steps:
        - Downmix multi-channel audio to mono.
        - Resample to 16 kHz if needed.
        - High-pass filter at 30 Hz (5th-order Butterworth, zero-phase).
        - Take the magnitude of a 1024-point Hann STFT with hop 256.
        - Project onto an 80-band mel filterbank (90-7600 Hz).
        - Convert to dB (floor -100 dB, offset -16 dB) and map linearly into
          [0, 1].

        Args:
            path: path to an audio file readable by ``soundfile.read``.

        Returns:
            float32 array of shape (n_frames, 80).
        """
        def butter_highpass(cutoff, fs, order=5):
            # Digital Butterworth high-pass; cutoff is normalised by Nyquist.
            nyq = 0.5 * fs
            normal_cutoff = cutoff / nyq
            b, a = signal.butter(order, normal_cutoff, btype='high', analog=False)
            return b, a

        def pySTFT(x, fft_length=1024, hop_length=256):
            # Frame the reflect-padded signal through a strided view (no copy),
            # window every frame and return |rFFT| as (n_bins, n_frames).
            x = np.pad(x, int(fft_length // 2), mode='reflect')

            noverlap = fft_length - hop_length
            shape = x.shape[:-1] + ((x.shape[-1] - noverlap) // hop_length, fft_length)
            strides = x.strides[:-1] + (hop_length * x.strides[-1], x.strides[-1])
            result = np.lib.stride_tricks.as_strided(x, shape=shape,
                                                     strides=strides)

            fft_window = get_window('hann', fft_length, fftbins=True)
            result = np.fft.rfft(fft_window * result, n=fft_length).T

            return np.abs(result)

        sr = 16000  # the filterbank and high-pass filter below are built for 16 kHz
        # Keyword arguments work both with old librosa and with >= 0.10,
        # where librosa.filters.mel is keyword-only.
        mel_basis = mel(sr=sr, n_fft=1024, fmin=90, fmax=7600, n_mels=80).T
        min_level = np.exp(-100 / 20 * np.log(10))  # -100 dB as linear amplitude
        b, a = butter_highpass(30, sr, order=5)

        x, fs = sf.read(path)
        if x.ndim > 1:
            # soundfile returns (frames, channels) for multi-channel audio.
            x = x.mean(axis=1)
        if fs != sr:
            # Without this, non-16 kHz input would silently yield wrong features.
            x = signal.resample_poly(x, sr, int(fs))
        # Remove drifting noise
        y = signal.filtfilt(b, a, x)
        # Optional augmentation: add a little random noise for model robustness.
        # Disabled here so the features are deterministic:
        # prng = RandomState(225)
        # wav = y * 0.96 + (prng.rand(y.shape[0])-0.5)*1e-06
        wav = y
        # Compute spect
        D = pySTFT(wav).T
        # Convert to mel and normalize
        D_mel = np.dot(D, mel_basis)
        D_db = 20 * np.log10(np.maximum(min_level, D_mel)) - 16
        S = np.clip((D_db + 100) / 100, 0, 1)
        return S.astype(np.float32)

    def __call__(self, path, len_crop=128):
        """Return the speaker embedding of the audio file at ``path``.

        A random window of ``len_crop`` frames is taken from the mel
        spectrogram, and every valid start offset can be chosen. Utterances
        with at most ``len_crop`` frames are used whole.

        Args:
            path: audio file passed to ``melspec``.
            len_crop: number of frames fed to the encoder (default 128).

        Returns:
            numpy array holding the squeezed encoder output.
        """
        tmp = self.melspec(path)
        n_frames = tmp.shape[0]
        if n_frames > len_crop:
            # randint's upper bound is exclusive; +1 makes the last window reachable.
            left = np.random.randint(0, n_frames - len_crop + 1)
            crop = tmp[left:left + len_crop]
        else:
            # Too short to crop. NOTE(review): assumes the encoder accepts
            # variable-length sequences -- confirm against D_VECTOR.
            crop = tmp
        melsp = torch.from_numpy(crop[np.newaxis, :, :]).to(self.device)
        with torch.no_grad():
            emb = self.C(melsp)
        return emb.squeeze().cpu().numpy()

In [82]:
class SpeechConverter():
    """Run the AutoVC generator to re-voice a mel spectrogram with a target speaker."""

    def __init__(self, model_path='autovc.ckpt', device='cuda:0'):
        """Build the generator and load its weights.

        Args:
            model_path: checkpoint whose ``'model'`` entry holds the Generator
                state dict (default ``'autovc.ckpt'``).
            device: torch device to run on (default ``'cuda:0'``).
        """
        self.device = device
        self.G = Generator(32, 256, 512, 32).eval().to(self.device)

        g_checkpoint = torch.load(model_path, map_location=self.device)
        self.G.load_state_dict(g_checkpoint['model'])

    def __call__(self, x_org, emb_org, emb_trg):
        """Convert ``x_org`` from the source speaker to the target speaker.

        Args:
            x_org: source mel spectrogram, 2-D (n_frames, n_mels).
            emb_org: 1-D embedding of the source speaker.
            emb_trg: 1-D embedding of the target speaker.

        Returns:
            Converted spectrogram as a numpy array with ``x_org``'s frame
            count. Inputs are cast to float32 before reaching the model.
        """
        n_frames = x_org.shape[0]
        x_pad, _ = SpeechConverter.pad_seq(np.asarray(x_org, dtype=np.float32))
        uttr_org = torch.from_numpy(x_pad[np.newaxis, :, :]).to(self.device)
        emb_org = torch.from_numpy(
            np.asarray(emb_org, dtype=np.float32)[np.newaxis, :]).to(self.device)
        emb_trg = torch.from_numpy(
            np.asarray(emb_trg, dtype=np.float32)[np.newaxis, :]).to(self.device)
        with torch.no_grad():
            _, x_identic_psnt, _ = self.G(uttr_org, emb_org, emb_trg)
        # Keep only the original frames; the rest is padding added by pad_seq.
        return x_identic_psnt[0, 0, :n_frames, :].cpu().numpy()

    @staticmethod
    def pad_seq(x, base=32):
        """Zero-pad ``x`` along axis 0 up to the next multiple of ``base``.

        Returns:
            (padded array, number of padding rows appended).
        """
        len_out = int(base * ceil(float(x.shape[0]) / base))
        len_pad = len_out - x.shape[0]
        assert len_pad >= 0
        return np.pad(x, ((0, len_pad), (0, 0)), 'constant'), len_pad

In [94]:
class Vocoder():
    """Wrap a WaveNet vocoder that turns mel spectrograms into waveforms via ``wavegen``."""

    def __init__(self, checkpoint_path="checkpoint_step001000000_ema.pth", device="cuda"):
        """Build the vocoder model and load its weights.

        Args:
            checkpoint_path: file whose ``"state_dict"`` entry holds the model
                weights (default ``"checkpoint_step001000000_ema.pth"``).
            device: torch device the model is moved to (default ``"cuda"``).
        """
        self.device = torch.device(device)
        model = build_model().to(self.device)
        # map_location loads tensors onto the chosen device rather than the
        # device the checkpoint was saved from.
        checkpoint = torch.load(checkpoint_path, map_location=self.device)
        model.load_state_dict(checkpoint["state_dict"])
        self.model = model

    def __call__(self, spec):
        """Synthesize a waveform from the spectrogram ``spec``.

        NOTE(review): ``wavegen`` comes from the ``synthesis`` module and may
        pick its own device internally -- confirm it matches ``self.device``.
        """
        waveform = wavegen(self.model, spec)
        return waveform

In [133]:
# Source and target utterances. AISHELL file names encode the speaker
# (BAC009S<speaker>W<utterance>), so these are presumably two different speakers.
AISHELL_TEST_DIR = '/home/tony/D/corpus/data_aishell/wav/test'
SRC_WAV = os.path.join(AISHELL_TEST_DIR, 'BAC009S0764W0491.wav')
TRG_WAV = os.path.join(AISHELL_TEST_DIR, 'BAC009S0768W0185.wav')

SE = SpeakerEmbedder()
x_org = SpeakerEmbedder.melspec(SRC_WAV)
# Embeddings use a random crop of the spectrogram, so they vary between runs.
emb_org = SE(SRC_WAV)
# The target embedding must come from the other speaker; reusing SRC_WAV
# would only reconstruct the source voice.
emb_trg = SE(TRG_WAV)

In [134]:
# Re-voice the source spectrogram with the target speaker's embedding.
SP = SpeechConverter()
x_trg = SP(x_org=x_org, emb_org=emb_org, emb_trg=emb_trg)

In [None]:
# Autoregressive WaveNet synthesis: this cell takes several minutes
# (see the progress bar in its output).
VD = Vocoder()
wav = VD(spec=x_trg)

 30%|██▉       | 30629/103680 [04:34<11:00, 110.53it/s]

In [None]:
from io import BytesIO

In [None]:
# Encode the synthesized waveform as an in-memory 16 kHz WAV file for playback.
# librosa.output was removed in librosa 0.8, so soundfile writes it instead.
# File-like targets need an explicit format. The buffer is then rewound so
# readers start at the WAV header rather than at the end of the data.
f = BytesIO()
sf.write(f, wav, 16000, format='WAV')
f.seek(0);

In [None]:
from IPython.display import Audio

In [None]:
# getvalue() returns the whole WAV buffer regardless of the stream position
# (read() returns empty bytes once the buffer has been consumed or not rewound).
Audio(data=f.getvalue(), rate=16000)

In [None]:
# Reference playback: the original source utterance.
Audio(rate=16000,
      filename='/home/tony/D/corpus/data_aishell/wav/test/BAC009S0764W0491.wav')

In [None]:
# Reference playback: an utterance from the target speaker, for comparison.
Audio(rate=16000,
      filename='/home/tony/D/corpus/data_aishell/wav/test/BAC009S0768W0185.wav')