<a href="https://colab.research.google.com/github/joaorafaelm/notebooks/blob/master/forward_tts_transformer_and_wavernn_vocoder.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Transformer TTS: A Text-to-Speech Transformer in TensorFlow 2
## Audio synthesis with Forward Transformer TTS and WaveRNN Vocoder
### Forward Model

In [1]:
#@title setup
%%capture

# Clone the Transformer TTS and WaveRNN repos
!git clone https://github.com/as-ideas/TransformerTTS.git
!cd TransformerTTS && git checkout 1c1cb03 && cd ..
!git clone https://github.com/fatchord/WaveRNN

# Install requirements
!apt-get install -y espeak
!pip install -r TransformerTTS/requirements.txt

# Download the transformer pre-trained weights
! wget https://public-asai-dl-models.s3.eu-central-1.amazonaws.com/TransformerTTS/ljspeech_wavernn_forward_transformer.zip
! unzip -o ljspeech_wavernn_forward_transformer.zip

# Unzip the wave pretrained model
!unzip -o WaveRNN/pretrained/ljspeech.wavernn.mol.800k.zip -d WaveRNN/pretrained/

# Set up the paths
import sys
from pathlib import Path

import numpy as np
import torch

TTS_path = 'TransformerTTS/'
WaveRNN_path = 'WaveRNN/'
config_path = Path('ljspeech_wavernn_forward_transformer') / 'wavernn'

# wavernn model: the WaveRNN repo is imported via top-level package names
# (`utils`, `models`), so its root must be on sys.path before importing them.
# The exact string 'WaveRNN/' is what gets added here (and removed again later).
sys.path.append(WaveRNN_path)
from utils.dsp import hp
from models.fatchord_version import WaveRNN

# From here on the WaveRNN root is used as a Path for joining file names.
WaveRNN_path = Path(WaveRNN_path)

# Load the WaveRNN hyper-parameters from the repo's hparams.py into the `hp`
# object imported from WaveRNN's utils.dsp.
try:
    hp.configure(WaveRNN_path / 'hparams.py')  # Load hparams from file
except Exception as err:
    # `hp` can't be reconfigured within the same kernel (e.g. when this cell is
    # re-run); keep the existing hparams instead of forcing a runtime restart.
    # Catch `Exception` rather than using a bare `except:` so KeyboardInterrupt /
    # SystemExit still propagate, and report the reason instead of hiding it.
    print(f'hp.configure skipped: {err!r}')

# Run the vocoder on the GPU when one is available, otherwise on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

# Instantiate WaveRNN with the architecture described by the loaded hparams.
wave_model = WaveRNN(
    rnn_dims=hp.voc_rnn_dims,
    fc_dims=hp.voc_fc_dims,
    bits=hp.bits,
    pad=hp.voc_pad,
    upsample_factors=hp.voc_upsample_factors,
    feat_dims=hp.num_mels,
    compute_dims=hp.voc_compute_dims,
    res_out_dims=hp.voc_res_out_dims,
    res_blocks=hp.voc_res_blocks,
    hop_length=hp.hop_length,
    sample_rate=hp.sample_rate,
    mode=hp.voc_mode,
).to(device)

# Load the pretrained vocoder weights (presumably extracted into
# WaveRNN/pretrained/ by the unzip command in the setup section).
weights_file = WaveRNN_path / 'pretrained/latest_weights.pyt'
wave_model.load(str(weights_file))

# Ignore some TF warnings
import tensorflow as tf
tf.get_logger().setLevel('ERROR')

# fix deprecated module on librosa: newer librosa releases no longer ship
# `librosa.output`, so re-create `librosa.output.write_wav` on top of soundfile.
# NOTE(review): presumably needed because WaveRNN saves its generated audio
# through `librosa.output.write_wav` — confirm in WaveRNN's utils/dsp.py.
import soundfile as sf
import librosa


class output:
    """Stand-in for the removed `librosa.output` module."""

    @staticmethod
    def write_wav(path, data, sr, norm=False):
        """Write `data` to `path` as a 24-bit PCM WAV sampled at `sr` Hz.

        Accepts the same call shapes as the old
        `librosa.output.write_wav(path, y, sr, norm=False)`:

        * `norm=True` peak-normalises floating-point audio to [-1, 1]
          (silent input is left unchanged).
        * A channel-first stereo array of shape (2, n) is transposed to the
          (frames, channels) layout soundfile expects.
        """
        data = np.asarray(data)
        if norm and np.issubdtype(data.dtype, np.floating):
            peak = np.max(np.abs(data))
            if peak > 0:
                data = data / peak
        if data.ndim > 1 and data.shape[0] == 2:
            data = data.T
        sf.write(path, data, samplerate=sr, subtype='PCM_24')


librosa.output = output

# Generate sample with pre-trained WaveRNN vocoder
hp_data = hp
def generate(mel, file_name="sample.wav", batch_pred=False, hp=hp_data,
             target=10_000, overlap=None):
    """Vocode a mel spectrogram with WaveRNN, save it to a wav file and play it.

    Args:
        mel: numpy spectrogram already scaled for WaveRNN. It is clipped to
            [0, 1] and given a leading batch axis before vocoding
            (assumed (n_mels, frames) — TODO confirm).
        file_name: path of the wav file the vocoder writes and that is then
            played back inline.
        batch_pred: forwarded to `wave_model.generate` as its third positional
            argument (presumably the batched/folded generation flag).
        hp: WaveRNN hparams; supplies the default overlap and `mu_law`.
        target: forwarded as the fourth positional argument (presumably the
            samples-per-fold target used in batched mode). Defaults to the
            value this notebook always used.
        overlap: forwarded as the fifth positional argument; `None` means
            `hp.voc_overlap`.

    Returns:
        None. Nothing is returned, so calling this as a cell's last expression
        does not print anything.
    """
    # Local import: the module-level `ipd` is only bound further down the
    # notebook, so don't depend on that ordering.
    import IPython.display as ipd

    if overlap is None:
        overlap = hp.voc_overlap

    batched_mel = mel.clip(0, 1)[np.newaxis, :, :]
    wave_model.generate(batched_mel, file_name, batch_pred, target, overlap, hp.mu_law)

    # Play back the wav file the vocoder just wrote
    ipd.display(ipd.Audio(file_name))


# ljspeech_wavernn_forward_model
# WaveRNN and TransformerTTS both expose a top-level `utils` package. Take
# WaveRNN off the import path and evict every cached `utils` module — the
# package AND its submodules (e.g. `utils.dsp`). Otherwise a TransformerTTS
# submodule with the same name as a WaveRNN one would resolve to the stale
# WaveRNN module. Objects already built from WaveRNN (`hp`, `wave_model`) keep
# references to the code they were created from and continue to work.
if 'WaveRNN/' in sys.path:
    sys.path.remove('WaveRNN/')
stale_utils = [name for name in sys.modules if name == 'utils' or name.startswith('utils.')]
for module_name in stale_utils:
    sys.modules.pop(module_name)
sys.path.append(TTS_path)

# Load pretrained models
from utils.config_manager import ConfigManager
from utils.audio import Audio

import IPython.display as ipd

config_loader = ConfigManager(str(config_path), model_kind='forward')
audio = Audio(config_loader.config)
# Checkpoint name inside the downloaded archive's forward_weights/ directory
model = config_loader.load_model(str(config_path / 'forward_weights/ckpt-133'))

In [None]:
#@title try it out
sentence = 'Audio synthesis with Forward Transformer TTS and WaveRNN Vocoder' #@param {type:"string"}
speed_regulator = 1.1 #@param {type:"slider", min:0, max:2, step:0.1}
batch_pred = False #@param {type:"boolean"}

out_normal = model.predict(sentence, speed_regulator=speed_regulator)

# Convert spectrogram to wav (with griffin lim)
wav = audio.reconstruct_waveform(out_normal['mel'].numpy().T)
# Print the label first, then show the player. Passing `ipd.display(...)` as a
# print() argument played the audio *before* the label and printed its `None`
# return value.
print("ljspeech_wavernn_forward_model")
ipd.display(ipd.Audio(wav, rate=config_loader.config['sampling_rate']))

# Normalize for WaveRNN: (x + 4) / 8 maps [-4, 4] onto [0, 1]
# (presumably the TTS mel value range; generate() clips anything outside [0, 1]).
mel = (out_normal['mel'].numpy().T + 4.) / 8.
print("after wavernn vocoder")
generate(mel, batch_pred=batch_pred)

ljspeech_wavernn_forward_model None
after wavernn vocoder
| █░░░░░░░░░░░░░░░ 4800/109725 | Batch Size: 1 | Gen Rate: 0.8kHz | 