
Use WORLD vocoder #9

Closed
reuben opened this issue Mar 27, 2018 · 70 comments
Labels: help wanted, improvement

Comments


reuben commented Mar 27, 2018

No description provided.

erogol added the help wanted label on Apr 13, 2018

erogol commented Apr 13, 2018

Since I cannot find a handy WORLD interface for feature extraction, there has been no progress on this. If anyone finds something, please let me know.

Plain WaveNet is really slow, especially at inference time. The new Parallel WaveNet paper seems interesting, but training the proposed teacher-student framework is also a big piece of work. Before starting that phase, I want to finish some other experiments.


m-toman commented Apr 13, 2018

I've used this python wrapper once: https://github.com/JeremyCCHsu/Python-Wrapper-for-World-Vocoder

Alternatively there is this helper script in Merlin that might be useful: https://github.com/CSTR-Edinburgh/merlin/blob/master/misc/scripts/vocoder/world/extract_features_for_merlin.py


erogol commented Apr 14, 2018

@m-toman thanks for the pointers. How was your experience with the first one? I've tried it on a single audio file, but synthesizing from the extracted features produced artifacts in the generated audio.

Do you also have any results with WORLD? Would you say things are better with it?


m-toman commented Apr 15, 2018

@erogol
I did not try your codebase yet, to be honest. But here is my experience with WORLD:

Like in the Merlin codebase, I also replaced F0 extraction with REAPER, which was a huge improvement, at least over the older method in WORLD (DIO).

I've only briefly tried the internal spectrum compression method, but did not have any luck with it and went back to using MGCs/MFCCs.

It's fast enough for live synthesis on mobile (especially via the streaming implementation).

It generally sounds pretty good, but I suspect we can get better results with neural vocoders in the future (although they typically use the same acoustic features for conditioning). The fast WaveNet method seems a bit bloated to me, though.
I'd love to try WaveRNN at some point (https://arxiv.org/abs/1802.08435) but currently can't find the time.


erogol commented Apr 16, 2018

@m-toman thanks for sharing your bits & pieces. I think I will try WORLD soon and post the results here.

erogol added the improvement label on Apr 23, 2018

erogol commented Apr 24, 2018

WaveNet with these values is impractical to use: r9y9/wavenet_vocoder#28 (comment)

erogol pushed a commit that referenced this issue Apr 25, 2018
@stevemurr

Hi @erogol, I'd like to help you integrate the WORLD vocoder. It seems we need a way to extract the parameters needed for the vocoder. One option for extraction and synthesis could be the pyworld library.

import pyworld as pw
import soundfile as sf

path = "/path/to/audio.wav"  # input waveform (placeholder path)

# numpy_array, sample_rate
x, fs = sf.read(path)
# f0, spectral envelope, aperiodicity
f0, sp, ap = pw.wav2world(x, fs)

# resynthesized numpy_array
y = pw.synthesize(f0, sp, ap, fs, pw.default_frame_period)

sf.write("/path/to/resynthesized.wav", y, fs)

Since the linear spectrogram is already being predicted by the post net, it seems we would also need to predict F0 and the aperiodicities. Is it possible to pass F0 and aperiodicity targets and predict them with a post net as well?
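
For what it's worth, a rough sketch of how compact, network-friendly WORLD targets could be derived (assuming a pyworld version that exposes the code_*/decode_* helpers; the 60-coefficient envelope here is an arbitrary choice, not something from this repo):

import numpy as np
import pyworld as pw
import soundfile as sf

x, fs = sf.read("/path/to/audio.wav")           # placeholder path
x = np.ascontiguousarray(x, dtype=np.float64)   # pyworld works on float64 arrays

f0, sp, ap = pw.wav2world(x, fs)

# Compress the FFT-sized envelopes into a compact form that is easier to predict.
fft_size = pw.get_cheaptrick_fft_size(fs)
coded_sp = pw.code_spectral_envelope(sp, fs, 60)   # e.g. 60 coefficients
coded_ap = pw.code_aperiodicity(ap, fs)            # a few band aperiodicities
lf0 = np.where(f0 > 0, np.log(f0), 0.0)            # log-F0, 0 for unvoiced frames

# At synthesis time, the coded features are expanded back before pw.synthesize.
sp_hat = pw.decode_spectral_envelope(coded_sp, fs, fft_size)
ap_hat = pw.decode_aperiodicity(coded_ap, fs, fft_size)
f0_hat = np.where(lf0 > 0, np.exp(lf0), 0.0)
y = pw.synthesize(f0_hat, sp_hat, ap_hat, fs, pw.default_frame_period)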


erogol commented May 1, 2018

@stevemurr thanks for the post. I have also just started to use pyworld and am testing its performance in different settings. I can share the notebook of these experiments if you are interested.

I have extracted the WORLD features and am now writing a data loader. After I finish this, I can push the branch for you to use.

I think what you suggest is possible, but I am not sure about the quality. The only way to check is to try :). It would be worthwhile, since WORLD is much faster than Griffin-Lim with no quality sacrifice in general. The only downside is that the features need to be extracted up front for any dataset, since extraction is too slow to be done on the fly.
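
As a rough illustration of that offline step, the features could be precomputed once per dataset with something like the sketch below (paths and file layout are placeholders, not the actual branch's script):

import glob
import os
from multiprocessing import Pool

import numpy as np
import pyworld as pw
import soundfile as sf

WAV_DIR = "/data/dataset/wavs"     # placeholder
OUT_DIR = "/data/dataset/world"    # placeholder

def extract(wav_path):
    x, fs = sf.read(wav_path)
    x = np.ascontiguousarray(x, dtype=np.float64)
    f0, sp, ap = pw.wav2world(x, fs)
    base = os.path.splitext(os.path.basename(wav_path))[0]
    # One .npz per utterance with the three WORLD streams.
    np.savez(os.path.join(OUT_DIR, base + ".npz"), f0=f0, sp=sp, ap=ap)

if __name__ == "__main__":
    os.makedirs(OUT_DIR, exist_ok=True)
    with Pool() as pool:
        pool.map(extract, sorted(glob.glob(os.path.join(WAV_DIR, "*.wav"))))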

@stevemurr

@erogol Intriguing! I'd love to check out the notebook with your comparisons. I'm very curious about the viability of WORLD as an intermediate vocoder choice until fast neural vocoders become ubiquitous.

I've achieved results comparable to Google's paper using https://github.com/Rayhane-mamah/Tacotron-2 and r9y9's WaveNet implementation. The downside is that I had to train the WaveNet for roughly 1.6 million steps, which took a couple of weeks on a single 1080. The upside is that it's useful for offline synthesis.


erogol commented May 2, 2018

@stevemurr FYI (for your interest): https://gist.github.com/erogol/92cdeca0e12c9ea3e79e518111b354c7
https://github.com/mozilla/TTS/tree/world_new currently has only a WORLD feature extraction script. So the next thing is writing a data loader. I will get to this as soon as I finish my current experiment.


erogol commented May 2, 2018

What I observe is that encoding with F0 tuned by Harvest works best.
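
For reference, Harvest-based extraction with pyworld looks roughly like this (a sketch; the input path is a placeholder):

import numpy as np
import pyworld as pw
import soundfile as sf

x, fs = sf.read("/path/to/audio.wav")
x = np.ascontiguousarray(x, dtype=np.float64)

# Harvest tends to give a smoother, more reliable F0 contour than the older Dio.
f0, t = pw.harvest(x, fs)
f0 = pw.stonemask(x, f0, t, fs)     # optional refinement of the F0 estimates

sp = pw.cheaptrick(x, f0, t, fs)    # spectral envelope
ap = pw.d4c(x, f0, t, fs)           # aperiodicity
y = pw.synthesize(f0, sp, ap, fs, pw.default_frame_period)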


erogol commented May 3, 2018

Data loader added...

@stevemurr

@erogol Thanks for the notebook examples; I agree with your assessment of Harvest. Taking a look at the data loader, it seems some modifications to train.py will do the trick. Let me know if there are any specific tasks you need help with for the WORLD integration.


erogol commented May 6, 2018

@stevemurr I coded train.py with some quick, minimal testing. Now I am trying to update the network architecture to accommodate the WORLD features. One open question is how to replace the intermediate mel-spectrogram prediction. Do you think it makes sense to use a mel-scale spectral envelope at that stage?

I am also trying to normalize the WORLD features in a non-disruptive way for efficient training. Any ideas on this?

@stevemurr

@erogol Sorry for the late reply!

I'd like to share a tutorial from one of r9y9's repos - he describes a traditional TTS process using linguistic features and acoustic features to build a duration model and an acoustic model - this can hopefully serve as a guide in targeting WORLD. Since we are using the encoder/decoder architecture we can hopefully learn the alignments from text to WORLD targets. A portion of the tutorial does cover data preparation for WORLD including normalization strategies.

Before training neural networks, we need to normalize data. Following Merlin’s demo script, we will apply min/max normalization for linguistic features and mean/variance normalization to duration/acoustic features. You can compute necessary statistics using nnmnkwii.util.minmax and nnmnkwii.util.meanvar. The computation is online, so we can use the functionality for any large dataset.

Following this, I assume we should attempt mean/variance normalization for the WORLD features - a function for this is provided in his nnmnkwii library at nnmnkwii.util.meanvar.
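
In plain numpy, that mean/variance normalization amounts to something like the sketch below (nnmnkwii computes the statistics online instead of concatenating everything in memory, which matters for large datasets):

import numpy as np

def compute_meanvar(feature_list):
    """Per-dimension mean/std over a list of [T, D] WORLD feature matrices."""
    stacked = np.concatenate(feature_list, axis=0)
    mean = stacked.mean(axis=0)
    std = stacked.std(axis=0) + 1e-8    # avoid division by zero on constant dims
    return mean, std

def normalize(feats, mean, std):
    return (feats - mean) / std

def denormalize(feats, mean, std):
    return feats * std + mean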

how to replace intermediate mel-spectrogram prediction. Do you think it makes sense to use mel-scale spectral envelope at that stage?

Is there a difference between spectrogram and spectral envelope?

I have an idea for using multiple encoders and decoders that I wanted to get your thoughts on - hopefully it's not too crazy sounding :). Currently the network predicts a mel spectrogram, which works well for discovering correct alignments with the text. What if we leave the current architecture as-is but add a series of encoder/decoders after the mel is predicted? We would then encode the predicted mel and decode it into the target F0, and do the same for the aperiodicity. In training we would probably want to pass the ground-truth mel to the F0 and aperiodicity encoders (see the sketch below). Let me know your thoughts on this, or if you have a more optimal solution.
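
A hypothetical sketch of that idea in PyTorch (all module names and sizes here are made up for illustration and are not from the TTS codebase):

import torch
import torch.nn as nn

class WorldPostnet(nn.Module):
    """Hypothetical sketch: encode (predicted or ground-truth) mel frames and
    decode per-frame F0 and band aperiodicity targets from the shared encoding."""

    def __init__(self, n_mels=80, hidden=256, n_bap=4):
        super().__init__()
        self.encoder = nn.Sequential(
            nn.Conv1d(n_mels, hidden, kernel_size=5, padding=2),
            nn.ReLU(),
            nn.Conv1d(hidden, hidden, kernel_size=5, padding=2),
            nn.ReLU(),
        )
        self.f0_head = nn.Conv1d(hidden, 1, kernel_size=1)      # per-frame log-F0
        self.ap_head = nn.Conv1d(hidden, n_bap, kernel_size=1)  # per-frame band aperiodicities

    def forward(self, mel):                       # mel: [B, T, n_mels]
        h = self.encoder(mel.transpose(1, 2))     # [B, hidden, T]
        lf0 = self.f0_head(h).transpose(1, 2)     # [B, T, 1]
        bap = self.ap_head(h).transpose(1, 2)     # [B, T, n_bap]
        return lf0, bap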


erogol commented May 16, 2018

@stevemurr unfortunately I am busy with the other experiments for now, will return to this thread after a while.


Maxxiey commented Sep 17, 2018

Hi @stevemurr, I am thinking of another way to integrate Tacotron with the WORLD vocoder. What if I just replace the mel spectrogram with WORLD parameters (using pyworld to extract them) and let the model predict these parameters directly? I am not sure whether it can work, and I am going to give it a try. Did this idea ever occur to you, and have you succeeded in integrating Tacotron with the WORLD vocoder?

Thanks~


m-toman commented Sep 17, 2018

@Maxxiey I suspect this is a reasonable approach; after all, that's what systems like Merlin do, as do the somewhat older papers by Heiga Zen (although with the Vocaine vocoder): just put MCEPs, F0 and BAPs (and potentially a V/UV flag, depending on which F0 extractor you use) in a single vector, usually with delta and delta-delta features (plus MLPG afterwards, though I'm not sure that is necessary with Taco). A small sketch of the stacking follows below the list.

Other options:
I've hooked up the WaveRNN repo with another Taco implementation here: https://github.com/m-toman/tacorn
Then of course we have the nvidia wavenet implementation here https://github.com/NVIDIA/nv-wavenet
and interesting slides about wavenet on CPU: http://on-demand.gputechconf.com/gtc/2017/presentation/s7544-andrew-gibiansky-efficient-inference-for-wavenet.pdf
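
A minimal sketch of that single-vector stacking (the delta computation here is a simplified stand-in for the delta-window machinery in Merlin/nnmnkwii, and MLPG is left out):

import numpy as np

def delta(x, window=(-0.5, 0.0, 0.5)):
    """Simple delta features along time for a [T, D] matrix (edge-padded)."""
    pad = np.pad(x, ((1, 1), (0, 0)), mode="edge")
    return sum(w * pad[i:i + len(x)] for i, w in enumerate(window))

def stack_world_frame_features(lf0, vuv, mgc, bap):
    """Concatenate per-frame WORLD streams (plus deltas) into one target vector.
    Expected shapes: lf0 [T], vuv [T], mgc [T, n_mgc], bap [T, n_bap]."""
    static = np.concatenate([lf0[:, None], vuv[:, None], mgc, bap], axis=1)
    return np.concatenate([static, delta(static), delta(delta(static))], axis=1)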


Maxxiey commented Sep 17, 2018

@m-toman Thank you very much for your reply. Your opinion encourages me, a newbie in the TTS field, a lot.
The idea of putting F0, spectral envelope and aperiodicity into a single vector is quite straightforward, and I think someone has already been working on it: https://github.com/geneing/deepvoice3_pytorch. I am going to try this in the coming days and hope to get somewhere (whether the result is good or not). Thank you again for your help~


erogol commented Sep 19, 2018

@m-toman thanks for visiting TTS. Yes, I was planning to do that. I even wrote a script to extract all the WORLD features, but could not find time to go further.

The DeepVoice3 paper uses WORLD and reports results very close to the NN-based vocoder. It would be preferable to WaveNet, since it'd be easier to train and run inference with. I believe WORLD would give better results; at least without a network in the loop, WORLD recovers audio better than the Griffin-Lim algorithm.

This branch is outdated but it might be useful for you to take a look at https://github.com/mozilla/TTS/tree/world_new/scripts


erogol commented Sep 19, 2018

@m-toman How was the quality and run-time with WaveRNN?



m-toman commented Sep 19, 2018

@erogol I only tried the adapted model by fatchord... While I don't have actual numbers, I would say it was probably about 1 minute for a longer sentence instead of 10 minutes with the WaveNet implementation by r9y9, on a GTX 1080 Ti.

The GTA samples produced during training were pretty good (https://www.dropbox.com/sh/2gtunx8d1r92fqb/AADh9CJEtvHnQ7YlwNClk8X5a?dl=0&m=). I wasn't that happy with the actual end-to-end synthesis results, but perhaps that was the fault of the Tacotron model - I didn't train it very long. Perhaps I can produce a couple of samples soon.

My main issue is that so much more work is going on around WaveNet. On the other side of the spectrum, we don't have to train WORLD for every speaker, and it also compiles easily and runs fast enough on lots of platforms. That's why I think it would certainly be interesting to try out.
At the moment I also don't have the time, but perhaps later on, in case @Maxxiey hasn't already done it by then ;).

By chance, do you know what the main differences are between your Tacotron implementation and https://github.com/Rayhane-mamah/Tacotron-2? (except, of course, TensorFlow vs PyTorch and the WaveNet integration)


Maxxiey commented Sep 20, 2018

@erogol Thanks, this will come in handy~


erogol commented Sep 20, 2018

@m-toman I haven't checked https://github.com/Rayhane-mamah/Tacotron-2, but I don't use the Tacotron 2 model. The only similarity is the use of Location Sensitive Attention; the rest is the same old Tacotron. I found no improvement from the other alternative layers proposed in TC2. It probably worked for them because they end the system with WaveNet instead of GL.

erogol changed the title from "Use WORLD or WaveNet as vocoder" to "Use WORLD vocoder" on Mar 11, 2019

erogol commented Mar 11, 2019

@tsungruihon it looks erroneous to me. Are you sure your inputs and outputs are set up correctly in every part of your network?

@begeekmyfriend

It is really strange that whether it learns alignment with WORLD acoustic features seems to depend on the dataset. For some datasets it never converges, but for others it learns quickly. I suspect the acoustic features of the WORLD vocoder are very sensitive for training.
[alignment plots: step-25000-align, step-4000-align]

@OswaldoBornemann

@erogol @begeekmyfriend here is the latest alignment graph. Yesterday I tried to generate audio using pyworld, but the values in the wav array were all zeros. I will check it out today.
[alignment plot]


erogol commented Mar 12, 2019

@tsungruihon it looks better. How many epochs?

@OswaldoBornemann

@erogol @begeekmyfriend nearly 140k steps.
But now the strangest thing is that I cannot synthesize audio. I compared the target and predicted output for the same input, e.g. text = 'la4 wu3 you2 tu4'. The loss between target and prediction is small, but when I denormalize lf0 to f0, the values of f0_pred are quite large.

world_targets = torch.cat([lf0s_input.unsqueeze_(-1), mgcs_input, baps_input], dim=-1)
world_out, alignments, stop_tokens = model.forward(chars_var.long(), world_targets)
criterion = L1LossMasked()
print(criterion(world_out, world_targets, mgcs_length))
[output]
tensor(0.4070, grad_fn=<DivBackward0>)

Below is how I extract the lf0, mgc and bap outputs, like @begeekmyfriend did. Here n_mgc is 60.

lf0_outputs = world_out[:, :, 0]
mgc_outputs = world_out[:, :, 1 : 1 + hp.n_mgc]
bap_outputs = world_out[:, :, 1 + hp.n_mgc:]

lf0_targets = world_targets[:, :, 0]
mgc_targets = world_targets[:, :, 1 : 1 + hp.n_mgc]
bap_targets = world_targets[:, :, 1 + hp.n_mgc:]

print(sum(sum(lf0_outputs - lf0_targets)))
[output]
tensor(-491.7143, grad_fn=<AddBackward0>)
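
For context, the usual log-F0 round trip looks like the sketch below; because of the exp(), even a small prediction error in lf0 becomes a large error in Hz, which could explain the big values (just a guess based on the numbers above):

import numpy as np

def f0_to_lf0(f0):
    # Voiced frames carry log-F0; unvoiced frames stay at 0.
    return np.where(f0 > 0, np.log(f0), 0.0)

def lf0_to_f0(lf0, voiced_threshold=1.0):
    # Anything below the threshold is treated as unvoiced; exp() recovers Hz for the rest.
    return np.where(lf0 < voiced_threshold, 0.0, np.exp(lf0))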

@OswaldoBornemann

@erogol @begeekmyfriend I just plotted lf0_predict, lf0_target, mgc_predict, mgc_target, bap_predict and bap_target, trying to figure out what's going on. The result tells me something is wrong.

[screenshots: predicted vs. target lf0, mgc and bap]


begeekmyfriend commented Mar 12, 2019

@tsungruihon Maybe you need some post-processing before synthesis. See my code. By the way, you can also do a resynthesis with the WORLD vocoder alone and record the feature values to check them. https://github.com/Rayhane-mamah/Tacotron-2/files/2713952/world_vocoder_resynth_scripts.zip Here is the resynth script.


OswaldoBornemann commented Mar 12, 2019

@begeekmyfriend Sure, my friend. I followed your code as below:

lf0_outputs = world_out[:, :, 0]
mgc_outputs = world_out[:, :, 1:61]
bap_outputs = world_out[:, :, 61:]

lf0 = lf0_outputs[0].data.numpy()
lf0 = np.where(lf0 < 1, 0.0, lf0)
f0_pred = f0_denormalize(lf0)

mgc_targets = mgc_outputs[0].data.numpy()
sp_pred = sp_denormalize(mgc_targets)

bap_targets = bap_outputs[0].data.numpy()
ap_pred = ap_denormalize(bap_targets, lf0)

But when I ran wav = pw.synthesize(f0_pred, sp_pred, ap_pred, 48000), the Jupyter notebook kernel restarted automatically every time.

@begeekmyfriend

@tsungruihon There is a core dump in the pw.synthesize method. Please save those feature values and use the resynth script to test them.
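
A few cheap sanity checks worth running before the native call (these are guesses at the cause, not verified against this exact setup): the three streams need matching frame counts, sp and ap need the same number of bins, all values must be finite, and pyworld works on C-contiguous float64 arrays rather than float32 tensors straight from PyTorch. A defensive wrapper could look like this:

import numpy as np
import pyworld as pw

def checked_world_synthesize(f0_pred, sp_pred, ap_pred, fs=48000):
    """Cast model outputs into the layout pyworld expects and sanity-check them."""
    f0 = np.ascontiguousarray(np.asarray(f0_pred).reshape(-1), dtype=np.float64)
    sp = np.ascontiguousarray(sp_pred, dtype=np.float64)
    ap = np.ascontiguousarray(ap_pred, dtype=np.float64)
    assert f0.shape[0] == sp.shape[0] == ap.shape[0], "frame counts differ"
    assert sp.shape == ap.shape, "sp and ap must have the same number of bins"
    # NaNs or infinities from a bad denormalization are another likely crash cause.
    assert np.isfinite(f0).all() and np.isfinite(sp).all() and np.isfinite(ap).all()
    return pw.synthesize(f0, sp, ap, fs, pw.default_frame_period)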

@OswaldoBornemann

@erogol @begeekmyfriend I would like to share the model's predictions - lf0, mgc and bap - and the denormalized results f0, sp and ap. It still shows Segmentation fault (core dumped).

world_features_prediction.zip


begeekmyfriend commented Mar 18, 2019

[alignment plots at steps 42000, 48000 and 60000]
mandarin_male_world.zip
The training is only at 80K steps and still ongoing. In my humble opinion, we need to do some normalization of the WORLD acoustic feature values for quicker alignment. If you hit a core dump in the WORLD library, just keep training, @tsungruihon.
BTW, I am using Tacotron-2 with this branch.


OswaldoBornemann commented Mar 18, 2019

Thanks, my friend. Really grateful and appreciative! @begeekmyfriend

@begeekmyfriend

WORLD feature extraction from GanTTS helps convergence. Any feedback is welcome! begeekmyfriend/Tacotron-2@e40a7b7
[alignment plot at step 13000]

@begeekmyfriend

Here is a Biaobei Mandarin demo from T2 + WORLD. The F0 feature value prediction is tough for this model.
xmly_biaobei_world.zip


mrgloom commented Apr 29, 2019

@begeekmyfriend
I wonder why you did not use the default settings as in the pyworld demo? Is there something specific to the TTS task?
https://github.com/JeremyCCHsu/Python-Wrapper-for-World-Vocoder/blob/master/demo/demo.py

@begeekmyfriend

The feature extraction part is derived from the gantts project.


begeekmyfriend commented May 16, 2019

I no longer recommend this vocoder. The feature values are too sensitive to be predicted. In my humble opinion, a mel spectrogram plus a neural vocoder such as WaveRNN is the most suitable solution for TTS so far. I give up.

By the way, the implementation of WORLD + Tacotron 2 is still kept in my fork branch.

@carlfm01

@begeekmyfriend did you ever try using LPCNet + Tacotron 2?

@carlfm01

@tsungruihon I saw your comment about how to connect LPCNet and Tacotron; this guide may be useful for you: MlWoo/LPCNet@324b212

@OswaldoBornemann

@carlfm01 thanks a lot!

@begeekmyfriend

I have tried WaveRNN.
wavernn_mandarin_male_22050.zip

erogol closed this as completed on May 27, 2019

mrgloom commented May 30, 2019

A test of the robustness of WORLD vocoder features to smoothing (as I understand it, an L1 loss will introduce smoothing into the predicted time series):

Based on: https://github.com/JeremyCCHsu/Python-Wrapper-for-World-Vocoder/blob/master/demo/demo.py

This file was used as the source: https://google.github.io/tacotron/publications/tacotron2/demos/romance_gt.wav
It is from the Tacotron 2 demo page: https://google.github.io/tacotron/publications/tacotron2/

Results: reconstruct_smooth_features_example.zip

Here I just use a box filter:

import os
from shutil import rmtree

import numpy as np
import soundfile as sf
import pyworld as pw
from scipy import signal

import matplotlib
matplotlib.use('Agg')
import matplotlib.pyplot as plt

def savefig(filename, figlist, log=True):
    EPSILON = 1e-8
    n = len(figlist)
    f = figlist[0]
    if len(f.shape) == 1:
        plt.figure()
        for i, f in enumerate(figlist):
            plt.subplot(n, 1, i+1)
            if len(f.shape) == 1:
                plt.plot(f)
                plt.xlim([0, len(f)])
    elif len(f.shape) == 2:
        plt.figure()
        for i, f in enumerate(figlist):
            plt.subplot(n, 1, i+1)
            if log:
                x = np.log(f + EPSILON)
            else:
                x = f + EPSILON
            plt.imshow(x.T, origin='lower', interpolation='none', aspect='auto', extent=(0, x.shape[0], 0, x.shape[1]))
    else:
        raise ValueError('Input dimension must < 3.')
    plt.savefig(filename)


def reconstruct_smooth_features():
    output_dir = 'temp'
    if os.path.isdir(output_dir):
        rmtree(output_dir)
    os.mkdir(output_dir)

    x, fs = sf.read('utterance/romance_gt.wav')
    print('x.shape', x.shape) #
    print('fs', fs) #

    f0, sp, ap = pw.wav2world(x, fs)
    print('f0.shape', f0.shape) #
    print('sp.shape', sp.shape) #
    print('ap.shape', ap.shape) #

    def smooth_1d(y, n=3):
        box = np.ones(n) / n
        y_smooth = np.convolve(y, box, mode='same')
        return y_smooth

    def smooth_2d(y, n=3):
        box = np.ones((n,n)) / (n*n)
        y_smooth = signal.convolve(y, box, mode='same')
        return y_smooth

    y = pw.synthesize(f0, sp, ap, fs, pw.default_frame_period)
    sf.write(os.path.join(output_dir, 'original_reconstruction.wav'), y, fs)

    for kernel_size in [3,5,7,9,15,31,65]:
        print('-'*60)
        f0_smooth = smooth_1d(np.copy(f0), n=kernel_size)
        sp_smooth = smooth_2d(np.copy(sp), n=kernel_size)
        ap_smooth = smooth_2d(np.copy(ap), n=kernel_size)

        savefig(os.path.join(output_dir, str(kernel_size).zfill(4)+'_f0.png'), [f0, f0_smooth])
        savefig(os.path.join(output_dir, str(kernel_size).zfill(4)+'_sp.png'), [sp, sp_smooth])
        savefig(os.path.join(output_dir, str(kernel_size).zfill(4)+'_ap.png'), [ap, ap_smooth])

        y_smooth = pw.synthesize(f0_smooth, sp_smooth, ap_smooth, fs, pw.default_frame_period)
        sf.write(os.path.join(output_dir, str(kernel_size).zfill(4)+'_smoothed_reconstruction.wav'), y_smooth, fs)

        print('kernel_size:', kernel_size)
        print('np.max(np.abs(f0-f0_smooth))', round(np.max(np.abs(f0-f0_smooth)), 2))
        print('np.max(np.abs(sp-sp_smooth))', round(np.max(np.abs(sp-sp_smooth)), 2))
        print('np.max(np.abs(ap-ap_smooth))', round(np.max(np.abs(ap-ap_smooth)), 2))
        print('np.max(np.abs(y-y_smooth))', round(np.max(np.abs(y-y_smooth)), 2))


if __name__ == '__main__':
    reconstruct_smooth_features()


mrgloom commented May 30, 2019

Roughly the same test for the Griffin-Lim algorithm; even at kernel_size 5 it already produces bad results.

Results: griffin_lim_smooth_mel_reconstruction.zip

import os
import shutil

from scipy import signal
import numpy as np

from utils.audio import AudioProcessor

if __name__ == "__main__":

    ap = AudioProcessor(
        num_mels = 80,
        num_freq =  1025,
        sample_rate = 22050,
        frame_length_ms = 50,
        frame_shift_ms = 12.5,
        preemphasis = 0.98,
        min_level_db = -100,
        ref_level_db = 20,
        power = 1.5,
        griffin_lim_iters = 60,
        signal_norm = True,
        symmetric_norm = False,
        max_norm = 1,
        clip_norm = True,
        mel_fmin = 0.0,
        mel_fmax = 8000.0,
        do_trim_silence = False
    )

    def smooth_2d(y, n):
        box = np.ones((n,n)) / (n*n)
        y_smooth = signal.convolve(y, box, mode='same')
        return y_smooth

    for kernel_size in [3,5,7,9,15,31,65]:
        for griffin_lim_iters in [30,60,90]:
            print('-' * 60)
            print('kernel_size', kernel_size)
            print('griffin_lim_iters', griffin_lim_iters)

            ap.griffin_lim_iters = griffin_lim_iters

            output_dir = 'griffin_lim_iters_'+str(griffin_lim_iters)+'_kernel_size_'+str(kernel_size)
            os.makedirs(output_dir)

            wav_filepath = './example_data/LJ001-0001.wav'

            wav = ap.load_wav(wav_filepath)
            mel = ap.melspectrogram(wav)
            shutil.copy(wav_filepath, os.path.join(output_dir, 'original.wav'))

            wav_reconstructed = ap.inv_mel_spectrogram(mel)
            ap.save_wav(wav_reconstructed, os.path.join(output_dir, 'original_reconstructed.wav'))

            mel_smooth = smooth_2d(mel, kernel_size)
            wav_smooth_reconstructed = ap.inv_mel_spectrogram(mel_smooth)
            ap.save_wav(wav_smooth_reconstructed, os.path.join(output_dir, 'smooth_reconstructed.wav'))
