# StyleTTS 2 Demo

### Interactive Settings

In [1]:
# @title Settings & Data Sources
# @markdown ***Language:***
# @markdown <br><small>This language is used to convert your text into phonemes.</small>
# @markdown <br><small>The value must be one of the languages supported by the phonemizer: https://pypi.org/project/phonemizer/</small>
# @markdown <br><small>For example, for the default espeak backend see: https://github.com/espeak-ng/espeak-ng/blob/master/docs/languages.md</small>
# Passed as `language=` to the espeak phonemizer backend when the models are loaded.
language = 'en-us' # @param {type:"string"}

# @markdown ***Multi-speaker mode:***
# @markdown <br><small>Tick the checkbox if you're using a multi-speaker model.</small>
# When unticked, fix_multispeaker() forces config['model_params']['multispeaker'] to False.
is_multispeaker = False # @param {type:"boolean"}

# @markdown ***Number of diffusion steps:***
# @markdown <br><small>The default value is 5.</small>
# @markdown <br><small>More diffusion steps give more diverse output speech at the expense of slower inference.</small>
# Forwarded to the diffusion sampler as `num_steps`.
diffusion_steps = 5 # @param {type:"integer"}

# @markdown ***Embedding scale:***
# @markdown <br><small>This is the classifier-free guidance scale.</small>
# @markdown <br><small>The higher the scale, the more strongly the style is conditioned on the input text, and hence the more emotional the speech.</small>
# Forwarded unchanged to the diffusion sampler as `embedding_scale`.
# NOTE(review): the form type is "integer", so fractional scales cannot be entered
# through the form — confirm whether that restriction is intended.
embedding_scale = 1 # @param {type:"integer"}

# @markdown ***Voice:***
# @markdown <small>(choose a voice)</small>
# NOTE(review): no cell visible in this notebook reads `data_source`; the config and
# checkpoint paths are hard-coded in the loading cells — confirm how the voice
# selection is meant to be applied.
data_source = "ThaTi" # @param ["ThaTi", "BarEm"]

### Load Packages

In [None]:
# NOTE(review): `%cd ..` is relative to the current working directory, so running
# this cell again in the same kernel climbs one more level; the relative
# "Models/..." paths used by later cells presumably stop resolving after that.
# Run this cell once per kernel.
%cd ..

# load packages
import torch
import time
import random
import yaml
from munch import Munch
import numpy as np
import torch  # duplicate of the import above (harmless)
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa
from nltk.tokenize import word_tokenize
import ipywidgets as widgets
import IPython.display as ipd
import phonemizer

# NOTE(review): star imports hide where names come from. build_model,
# load_ASR_models, load_F0_models and recursive_munch, which later cells call,
# are not defined in this notebook and presumably come from these two modules — confirm.
from models import *
from utils import *
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule
from text_utils import TextCleaner
# Shared cleaner instance; inference() calls it on the phoneme string to obtain token ids.
textcleaner = TextCleaner()

%matplotlib inline

### Randomness and GPU/CPU Settings

In [3]:
# Seed the Python, NumPy and PyTorch random number generators with the same value
random.seed(0)
np.random.seed(0)
torch.manual_seed(0)

# cuDNN: request deterministic algorithms and switch the benchmark mode off
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

# Use the GPU when CUDA is available, otherwise fall back to the CPU
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

### Functions

In [4]:
def run_cell(cell_num):
    """Run a specific notebook cell identified by its number.

    Displays a JavaScript snippet that calls
    ``Jupyter.notebook.execute_cells([cell_num])`` in the notebook front end.
    """
    script = f'Jupyter.notebook.execute_cells([{cell_num}])'
    ipd.display(ipd.Javascript(script))

def length_to_mask(lengths):
    """Build a boolean mask that is True at positions beyond each length.

    Args:
        lengths: 1-D tensor of sequence lengths.

    Returns:
        Tensor of shape ``(len(lengths), lengths.max())``; entry ``[b, t]`` is
        True when ``t + 1 > lengths[b]``.
    """
    positions = torch.arange(lengths.max()).type_as(lengths)
    positions = positions.unsqueeze(0).expand(lengths.shape[0], -1)
    return torch.gt(positions + 1, lengths.unsqueeze(1))

def preprocess(wave):
    """Turn a NumPy waveform into a normalized log-mel tensor.

    The waveform is converted to a float tensor, passed through the
    module-level ``to_mel`` transform, given a leading dimension, and the
    log of the result (offset by 1e-5) is normalized with the module-level
    ``mean`` and ``std``.
    """
    mel = to_mel(torch.from_numpy(wave).float())
    log_mel = torch.log(1e-5 + mel.unsqueeze(0))
    return (log_mel - mean) / std

def compute_style(ref_dicts, model, sample_rate=24000):
    """Compute a style embedding for every reference recording.

    Args:
        ref_dicts: Mapping from an arbitrary key to the path of a reference
            audio file.
        model: Object exposing ``style_encoder``, which is called on the
            normalized log-mel of each trimmed recording.
        sample_rate: Sampling rate the audio is loaded at. Defaults to 24000,
            the value that used to be hard-coded.

    Returns:
        dict: ``{key: (embedding, audio)}`` where ``embedding`` is the style
        encoder output with dimension 1 squeezed and ``audio`` is the trimmed
        waveform returned by librosa.
    """
    reference_embeddings = {}
    for key, path in ref_dicts.items():
        # librosa.load resamples to `sr` while loading, so the waveform is
        # already at `sample_rate`. The former `if sr != 24000: resample(...)`
        # branch could never execute and has been removed.
        wave, _ = librosa.load(path, sr=sample_rate)
        # Trim leading/trailing silence (threshold: 30 dB below the reference).
        audio, _ = librosa.effects.trim(wave, top_db=30)
        mel_tensor = preprocess(audio).to(device)

        with torch.no_grad():
            ref = model.style_encoder(mel_tensor.unsqueeze(1))
        reference_embeddings[key] = (ref.squeeze(1), audio)
    return reference_embeddings

def fix_multispeaker(is_multispeaker, config):
    """Switch the multispeaker flag off when single-speaker mode is requested.

    When ``is_multispeaker`` is falsy and ``config['model_params']['multispeaker']``
    is truthy, the flag is set to False in place. The same ``config`` object
    is returned in every case.
    """
    if is_multispeaker:
        return config

    model_params = config['model_params']
    if model_params['multispeaker']:
        model_params['multispeaker'] = False
    return config

def phonemize(text, phonemizer):
    """Phonemize a single piece of text.

    Surrounding whitespace and all double-quote characters are removed, the
    cleaned text is passed to ``phonemizer.phonemize`` as a one-element list,
    and the first element of the result is returned.
    """
    cleaned = text.strip().replace('"', '')
    phonemes = phonemizer.phonemize([cleaned])
    return phonemes[0]

def _shift_right_one_frame(features):
    """Delay ``features`` by one step along the last axis, keeping its shape.

    Frame 0 stays in place, every later frame takes the value of its
    predecessor, and the original last frame is dropped.
    """
    shifted = torch.zeros_like(features)
    shifted[:, :, 0] = features[:, :, 0]
    shifted[:, :, 1:] = features[:, :, 0:-1]
    return shifted


def inference(ps, noise, vocoder, diffusion_steps=5, embedding_scale=1):
    """Synthesize a waveform from a phoneme string.

    Relies on the notebook-level globals ``model``, ``sampler``,
    ``textcleaner`` and ``device``, so the loading cells must have been run.

    Args:
        ps: Phoneme string (e.g. the output of ``phonemize``).
        noise: Noise tensor passed as the first argument to ``sampler``.
        vocoder: Decoder type name. When it equals ``"hifigan"`` the aligned
            features are delayed by one frame before being used.
        diffusion_steps: Forwarded to ``sampler`` as ``num_steps``.
        embedding_scale: Forwarded to ``sampler`` as ``embedding_scale``.

    Returns:
        numpy.ndarray: The decoder output, squeezed and moved to the CPU.
    """
    # The intermediate strings are printed as a trace of the text normalization.
    print(ps)
    ps = word_tokenize(ps)
    print(ps)
    ps = ' '.join(ps)
    print(ps)
    # Tokenizing and re-joining leaves a space before punctuation; remove it.
    for mark in ('.', ',', ';', ':'):
        ps = ps.replace(' ' + mark, mark)
    print(ps)

    tokens = textcleaner(ps)
    tokens.insert(0, 0)  # prepend token id 0
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

    with torch.no_grad():
        # Input token length
        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(tokens.device)
        text_mask = length_to_mask(input_lengths).to(tokens.device)

        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        s_pred = sampler(
            noise,
            embedding=bert_dur[0].unsqueeze(0),
            num_steps=diffusion_steps,
            embedding_scale=embedding_scale
        ).squeeze(0)

        # Split the sampled vector: columns from 128 on feed the predictor,
        # the first 128 columns are handed to the decoder.
        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)

        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1)
        # Every token lasts at least one frame.
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Lengthen the last token by 5 frames.
        pred_dur[-1] += 5

        # Hard alignment matrix: row i is 1 over the frames assigned to token i.
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)
        alignment = pred_aln_trg.unsqueeze(0).to(device)

        # encode prosody
        en = d.transpose(-1, -2) @ alignment
        if vocoder == "hifigan":
            en = _shift_right_one_frame(en)

        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        asr = t_en @ alignment
        if vocoder == "hifigan":
            asr = _shift_right_one_frame(asr)

        # Feed the (possibly shifted) `asr` to the decoder. Previously the
        # unshifted product was recomputed here, so the hifigan shift applied
        # to `asr` above was silently discarded.
        out = model.decoder(
            asr,
            F0_pred,
            N_pred,
            ref.squeeze().unsqueeze(0)
        )

    return out.squeeze().cpu().numpy()

# Constants preprocess() uses to normalize the log-mel spectrogram
mean = -4
std = 4

# Mel-spectrogram transform applied to waveforms in preprocess()
to_mel = torchaudio.transforms.MelSpectrogram(
    n_fft=2048,
    win_length=1200,
    hop_length=300,
    n_mels=80,
)

### Voices Definition

In [5]:
# Maps a voice name to the path of its checkpoint file.
# NOTE(review): no cell visible in this notebook reads `voices`, and the "BarEm"
# option offered by the settings form has no entry here — confirm.
# NOTE(review): this file name starts with "epochs_", while the checkpoint paths
# in the model-building cell start with "epoch_" — verify the file name.
voices = {
    'ThaTi': 'Models/ThaTi-5s8k_ft-LibriTTS_bs8.ml400/epochs_2nd_00043.pth'
}

### Load Models

In [6]:
# load phonemizer (espeak backend for the language chosen in the settings cell)
global_phonemizer = phonemizer.backend.EspeakBackend(language=language, preserve_punctuation=True,  with_stress=True)

# Read the model configuration. The file is opened in a `with` block so the
# handle is closed as soon as the YAML has been parsed.
# Alternative config: "Models/ThaTi-5s8k_ft-LibriTTS_bs8.ml400/config.yml"
with open("Models/LJS_orig/config.yml") as config_file:
    config = yaml.safe_load(config_file)
# Force single-speaker mode when the settings checkbox is unticked.
config = fix_multispeaker(is_multispeaker, config)

# load pretrained ASR model
# (a key missing from the config yields False, which is passed on as-is)
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# load pretrained F0 model
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

# load BERT model
from Utils.PLBERT_mlng.util import load_plbert
BERT_path = config.get('PLBERT_dir', False)
plbert = load_plbert(BERT_path)

### Build Models

In [None]:
from collections import OrderedDict

# Build all sub-modules, switch them to eval mode and move them to the device.
model = build_model(recursive_munch(config['model_params']), text_aligner, pitch_extractor, plbert)
for key in model:
    model[key].eval()
    model[key].to(device)

# NOTE: torch.load unpickles the file, which can execute arbitrary code —
# only load checkpoints from a trusted source.
# Alternative checkpoint: "Models/ThaTi-5s8k_ft-LibriTTS_bs8.ml400/epoch_2nd_00049.pth"
params_whole = torch.load("Models/LJS_orig/epoch_2nd_00100.pth", map_location='cpu')
params = params_whole['net']

# Copy the checkpoint weights into every sub-module that has an entry in it.
for key in model:
    if key not in params:
        continue
    print('%s loaded' % key)
    try:
        model[key].load_state_dict(params[key])
    except RuntimeError:
        # load_state_dict raises RuntimeError when the parameter names do not
        # match. Retry with a leading `module.` prefix removed; names without
        # that prefix are kept unchanged instead of losing their first seven
        # characters.
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (name[len(prefix):] if name.startswith(prefix) else name, value)
            for name, value in params[key].items()
        )
        # strict=False tolerates leftover mismatches; report them so they do
        # not go unnoticed.
        load_result = model[key].load_state_dict(new_state_dict, strict=False)
        if load_result.missing_keys or load_result.unexpected_keys:
            print('%s: missing keys %s, unexpected keys %s' % (
                key, load_result.missing_keys, load_result.unexpected_keys))
for key in model:
    model[key].eval()

# Init sampler
sampler = DiffusionSampler(
    model.diffusion.diffusion,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
    clamp=False
)

### Synthesize speech

In [8]:
# @markdown ***Text to synthesize:***
# @markdown <br><small>Write a text to synthesize.</small>
# The form field is the single source of `text`. Previously a second assignment
# right below overwrote the form value, so editing the field had no effect.
text = '''A big canvas tent was the first thing to come within his vision.''' # @param {type:"string"}

# Other inputs left by the author, kept for reference:
#   StyleTTS 2 is a text to speech model that leverages style diffusion and adversarial training with large speech language models to achieve human-level text to speech synthesis.
#   dˈuː juː hæv ˌɛni pˈeɪn?
#   ɐ bˈɪɡ kˈænvəs tˈɛnt wʌzðə fˈɜːst θˈɪŋ tə kˈʌm wɪðˌɪn hɪz vˈɪʒən.

In [None]:
# inference() tokenizes with nltk's word_tokenize; if the tokenizer data is
# missing (presumably the "punkt" package), download it once:
# import nltk
# nltk.download('punkt')

# perf_counter is a monotonic clock intended for measuring elapsed time.
start = time.perf_counter()
noise = torch.randn(1, 1, 256).to(device)
ps = phonemize(text, global_phonemizer)
wav = inference(
    ps,
    noise,
    config['model_params']['decoder']['type'],
    diffusion_steps=diffusion_steps,
    embedding_scale=embedding_scale,
)
elapsed = time.perf_counter() - start

# Real-time factor: processing time divided by the duration of the generated
# audio (number of samples / 24000 samples per second).
rtf = elapsed / (len(wav) / 24000)
# `.5f` prints five decimals; the former `5f` only set a minimum width of 5,
# which never changed the output.
print(f"RTF = {rtf:.5f}")
ipd.display(ipd.Audio(wav, rate=24000))

In [None]:
# "Synthesize" button widget
button = widgets.Button(description="Synthesize")

# On every click, re-execute notebook cell 9 through run_cell (the cell number is hard-coded)
button.on_click(lambda _clicked: run_cell(9))

# Show the button in the cell output
ipd.display(button)