# StyleTTS Demo (LJSpeech)


### Utils

In [1]:
# Move the working directory up one level so the relative paths used below
# ("Vocoder/", "./Models/...", "./Demo/hifi-gan") resolve.
# NOTE(review): this is not idempotent - re-running the cell moves up again.
%cd ..

/data3/wangqiaolin/code/style_tts/StyleTTS


In [2]:
# load packages
import random
import yaml
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa
from nltk.tokenize import word_tokenize

from models import *
from utils import *

%matplotlib inline

In [3]:
# Device used for every model and tensor in this notebook.
# It is later converted to a torch.device in the vocoder-loading cell.
device = 'cpu'

In [15]:
# Padding symbol. It is placed first in `symbols`, so it maps to index 0.
_pad = '_'

# Punctuation accepted in the input text.
# The backslash is written as an explicit '\\' escape: a bare backslash in
# front of '·' is an invalid escape sequence, which still produces a backslash
# followed by '·' but triggers a warning on current Python versions.
_punctuation_zh = '；：，。！？-“”《》、（）ＢＰ…—~.\\·『』・ '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz'

_numbers = '1234567890'

# IPA letters. Defined here but NOT included in `symbols` below.
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"


# Export all symbols. The order defines the integer id of each symbol, so it
# must not change: pad, punctuation, letters, digits.
symbols = [_pad] + list(_punctuation_zh) + list(_letters) + list(_numbers)

# Symbol -> integer index lookup used by TextCleaner.
dicts = {symbol: index for index, symbol in enumerate(symbols)}

class TextCleaner:
    """Convert a string into a list of integer symbol ids.

    The lookup table is the module-level ``dicts`` mapping. Characters that
    are missing from the table are skipped, and a ``KeyError`` notice showing
    the full input text is printed for each of them.
    """

    def __init__(self, dummy=None):
        # `dummy` is accepted but never used.
        self.word_index_dictionary = dicts

    def __call__(self, text):
        """Return the ids of all known characters of ``text``, in order."""
        table = self.word_index_dictionary
        ids = []
        for ch in text:
            if ch in table:
                ids.append(table[ch])
            else:
                print("KeyError", text)
        return ids
    
# Shared cleaner instance used by the tokenize cell.
# NOTE(review): the name is a misspelling of "textcleaner"; it is kept because
# the tokenize cell refers to it by this exact name.
textclenaer = TextCleaner()

In [5]:
import torchaudio
# Mel-spectrogram extractor shared by `preprocess`.
# NOTE(review): no sample_rate is passed, so torchaudio's default is used even
# though audio is loaded at 24 kHz elsewhere - presumably this matches how the
# model was trained; confirm before changing.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_fft=2048,
    win_length=1200,
    hop_length=300,
    n_mels=80,
)

# Constants used by `preprocess` to normalise the log-mel spectrogram.
mean = -4
std = 4

def length_to_mask(lengths):
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    Returns a tensor with one row per entry of ``lengths`` and
    ``lengths.max()`` columns; an element is True where its position lies
    beyond the corresponding length (i.e. padding) and False otherwise.
    """
    longest = lengths.max()
    batch = lengths.shape[0]
    # 0-based positions, repeated for every sequence and cast like `lengths`.
    positions = torch.arange(longest).unsqueeze(0).expand(batch, -1).type_as(lengths)
    # Position p (1-based: p + 1) is padding when it exceeds the length.
    return torch.gt(positions + 1, lengths.unsqueeze(1))

def preprocess(wave):
    """Turn a NumPy waveform into a normalised log-mel spectrogram tensor.

    The waveform is converted to a float tensor, passed through the
    module-level ``to_mel`` transform, given a leading dimension, converted
    to log scale (with a 1e-5 offset) and normalised with the module-level
    ``mean`` and ``std``.
    """
    mel = to_mel(torch.from_numpy(wave).float())
    log_mel = torch.log(1e-5 + mel.unsqueeze(0))
    return (log_mel - mean) / std

def compute_style(ref_dicts):
    """Compute a style embedding for every reference recording.

    Args:
        ref_dicts: mapping from a reference name to the path of its audio file.

    Returns:
        dict mapping each name to a tuple ``(style_embedding, audio)`` where
        ``audio`` is the silence-trimmed waveform as returned by librosa.

    Relies on the notebook globals ``model``, ``device`` and ``preprocess``.
    """
    reference_embeddings = {}
    for key, path in ref_dicts.items():
        # librosa.load resamples to the requested rate itself, so the audio is
        # already at 24 kHz here and no separate resampling step is needed.
        wave, _ = librosa.load(path, sr=24000)
        # Strip leading/trailing silence; the returned index range is unused.
        audio, _ = librosa.effects.trim(wave, top_db=30)
        mel_tensor = preprocess(audio).to(device)

        with torch.no_grad():
            ref = model.style_encoder(mel_tensor.unsqueeze(1))
        reference_embeddings[key] = (ref.squeeze(1), audio)

    return reference_embeddings

### Load models

In [7]:
# load hifi-gan

import sys
sys.path.insert(0, "./Demo/hifi-gan")

import glob
import os
import argparse
import json
import torch
from scipy.io.wavfile import write
from attrdict import AttrDict
from vocoder import Generator
import librosa
import numpy as np
import torchaudio

# Placeholder for the vocoder hyper-parameters; it is assigned from the
# vocoder's config.json further down in this cell.
h = None

def load_checkpoint(filepath, device):
    """Load a checkpoint file with ``torch.load`` and return its contents.

    Args:
        filepath: path of the checkpoint; must be an existing file
            (otherwise the assertion fails).
        device: passed to ``torch.load`` as ``map_location``.

    Progress messages are printed before and after loading.
    """
    assert os.path.isfile(filepath)
    print("Loading '%s'" % (filepath,))
    # NOTE: torch.load unpickles the file - only load trusted checkpoints.
    loaded = torch.load(filepath, map_location=device)
    print("Complete.")
    return loaded

def scan_checkpoint(cp_dir, prefix):
    """Return the last path (in sorted order) in ``cp_dir`` whose name starts
    with ``prefix``, or an empty string when nothing matches."""
    matches = sorted(glob.glob(os.path.join(cp_dir, prefix + '*')))
    if not matches:
        return ''
    return matches[-1]

# Locate the generator checkpoint (last "g_*" file in sorted order).
cp_g = scan_checkpoint("Vocoder/", 'g_')
if not cp_g:
    # scan_checkpoint returns '' when nothing matches; fail here with a clear
    # message instead of going on to open an unrelated "config.json".
    raise FileNotFoundError("No generator checkpoint matching 'g_*' found in 'Vocoder/'")

# The vocoder configuration lives next to the checkpoint.
config_file = os.path.join(os.path.split(cp_g)[0], 'config.json')
with open(config_file) as f:
    data = f.read()
json_config = json.loads(data)
h = AttrDict(json_config)

# Normalise the notebook-wide `device` to a torch.device and build the vocoder.
device = torch.device(device)
generator = Generator(h).to(device)

# Restore the trained weights and switch to inference mode.
state_dict_g = load_checkpoint(cp_g, device)
generator.load_state_dict(state_dict_g['generator'])
generator.eval()
generator.remove_weight_norm()

Loading 'Vocoder/g_00750000'
Complete.
Removing weight norm...


In [8]:
# load StyleTTS
model_path = "./Models/Paimon/epoch_2nd_00098.pth"
model_config_path = "./Models/Paimon/config_paimon.yml"

# Read the training configuration; the context manager closes the file handle.
with open(model_config_path) as config_stream:
    config = yaml.safe_load(config_stream)

# load pretrained ASR model
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# load pretrained F0 model
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

model = build_model(Munch(config['model_params']), text_aligner, pitch_extractor)

# NOTE: torch.load unpickles the file - only load trusted checkpoints.
params = torch.load(model_path, map_location='cpu')
params = params['net']
# Restore every sub-module that has saved weights, except discriminators.
for key in model:
    if key in params and "discriminator" not in key:
        print('%s loaded' % key)
        model[key].load_state_dict(params[key])

# Put every sub-module in inference mode and move it to the target device.
for key in model:
    model[key].eval()
    model[key].to(device)



predictor loaded
decoder loaded
pitch_extractor loaded
text_encoder loaded
style_encoder loaded
text_aligner loaded


### Synthesize speech

In [9]:
# Use the first 10 training samples as style references.

train_path = config.get('train_data', None)
val_path = config.get('val_data', None)
train_list, val_list = get_data_path_list(train_path, val_path)

# Map "<file name without .wav>" -> audio path.
# Slicing (rather than indexing 0..9) keeps this working when the training
# list has fewer than 10 entries.
ref_dicts = {}
for entry in train_list[:10]:
    # The audio path is the first '|'-separated field of each entry.
    filename = entry.split('|')[0]
    name = filename.split('/')[-1].replace('.wav', '')
    ref_dicts[name] = filename

reference_embeddings = compute_style(ref_dicts)

In [28]:
!pip install pypinyin

Looking in indexes: https://pypi.tuna.tsinghua.edu.cn/simple
Collecting pypinyin
  Using cached https://pypi.tuna.tsinghua.edu.cn/packages/00/fc/3e82bf38739a7b2c4f699245ce6c84ff254723c678c2cdc5d2ecbddf9afb/pypinyin-0.49.0-py2.py3-none-any.whl (1.4 MB)
Installing collected packages: pypinyin
Successfully installed pypinyin-0.49.0
[0m

In [41]:
from pypinyin import Style, pinyin
from pypinyin.style._utils import get_finals, get_initials


def chinese_cleaners1(text):
    """Convert text to a space-separated string of TONE3-style pinyin.

    For every item returned by ``pypinyin.pinyin`` only the first reading is
    kept; the readings are joined with single spaces.
    """
    from pypinyin import Style, pinyin

    readings = pinyin(text, style=Style.TONE3)
    first_readings = []
    for candidates in readings:
        first_readings.append(candidates[0])
    return ' '.join(first_readings)
  
  
def chinese_cleaners2(text):
    """Split text into a list of pinyin initials and tone-marked finals.

    Each item returned by ``pinyin(text, style=Style.TONE3)`` contributes up
    to two entries, in this order:

    * its initial (``get_initials`` in strict mode), and
    * its final: when the reading ends in a digit, the final of the reading
      without that digit, with the digit appended again; when it ends in
      another alphanumeric character, the final of the whole reading;
      otherwise the reading itself, unchanged.

    Empty entries and entries consisting only of digits are dropped.
    """
    phones = []
    for candidates in pinyin(text, style=Style.TONE3):
        reading = candidates[0]
        initial = get_initials(reading, strict=True)

        last_char = reading[-1]
        if last_char.isdigit():
            final = get_finals(reading[:-1], strict=True) + last_char
        elif last_char.isalnum():
            final = get_finals(reading, strict=True)
        else:
            final = reading

        for piece in (initial, final):
            # Remove the case of individual tones as a phoneme
            if len(piece) != 0 and not piece.isdigit():
                phones.append(piece)
    return phones
  
if __name__ == '__main__':
    # Quick demonstration of both cleaners; each result is printed in turn.
    for cleaner, sample in (
        (chinese_cleaners2, '这是语音测试！'),
        (chinese_cleaners1, '语音合成是将人类的语言用人工的形式产生。'),
    ):
        res = cleaner(sample)
        print(res)
  

['zh', 'e4', 'sh', 'i4', 'ü3', 'in1', 'c', 'e4', 'sh', 'i4', '！']
yu3 yin1 he2 cheng2 shi4 jiang1 ren2 lei4 de yu3 yan2 yong4 ren2 gong1 de xing2 shi4 chan3 sheng1 。
ðɪs ɪz ɐ klˈʌb tˈɛst fɔːɹ wˈʌn tɹˈeɪn.dʒˌiːdˌiːpˈiː


In [42]:
# Text to synthesize. It is given as space-separated, tone-numbered pinyin -
# the same form that chinese_cleaners1 prints in the cell above - because the
# tokenizer maps it character by character through the symbol table.
text = ''' yu3 yin1 he2 cheng2 shi4 jiang1 ren2 lei4 de yu3 yan2 yong4 ren2 gong1 de xing2 shi4 chan3 sheng1 。'''

In [43]:
# tokenize

# Symbol ids of the text, wrapped in a leading and trailing 0
# (index 0 is the padding symbol in the symbol table).
token_ids = [0] + textclenaer(text) + [0]
# Add a leading dimension of size 1 in front of the id sequence.
tokens = torch.LongTensor(token_ids).to(device).unsqueeze(0)

In [44]:
# Synthesized waveform (NumPy array) for every reference, keyed by reference name.
converted_samples = {}

with torch.no_grad():
    # The text is encoded once; only the style-dependent part is in the loop.
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    m = length_to_mask(input_lengths).to(device)
    t_en = model.text_encoder(tokens, input_lengths, m)

    for key, (ref, _) in reference_embeddings.items():
        # Style vector of this reference.
        s = ref.squeeze(1)

        # Predict a duration for every token, at least 1 each.
        d = model.predictor.text_encoder(t_en, s, input_lengths, m)
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Build a 0/1 alignment matrix with one row per token and one column
        # per output frame: row i is 1 over the consecutive frames assigned to
        # token i.
        pred_aln_trg = torch.zeros(input_lengths, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)
        # Add a leading dimension and move to the device once; reused below.
        alignment = pred_aln_trg.unsqueeze(0).to(device)

        # encode prosody
        en = d.transpose(-1, -2) @ alignment
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        out = model.decoder(t_en @ alignment,
                            F0_pred, N_pred, ref.squeeze().unsqueeze(0))

        # Run the vocoder once on the decoder output.
        c = out.squeeze()
        y_g_hat = generator(c.unsqueeze(0))
        y_out = y_g_hat.squeeze()

        converted_samples[key] = y_out.cpu().numpy()

In [45]:
import IPython.display as ipd

# Play each synthesized sample followed by the reference it was styled after.
for key, wave in converted_samples.items():
    print('Synthesized: %s' % key)
    display(ipd.Audio(wave, rate=24000))
    try:
        print('Reference: %s' % key)
        # The reference audio is the last element of the stored tuple.
        display(ipd.Audio(reference_embeddings[key][-1], rate=24000))
    except KeyError:
        # No reference is stored under this key; skip only that case instead
        # of silently swallowing every possible error.
        continue

Synthesized: 100100005


Reference: 100100005


Synthesized: 100100007


Reference: 100100007


Synthesized: 100100012


Reference: 100100012


Synthesized: 100100022


Reference: 100100022


Synthesized: 100100024


Reference: 100100024


Synthesized: 100100025


Reference: 100100025


Synthesized: 100100029


Reference: 100100029


Synthesized: 100100031


Reference: 100100031


Synthesized: 100100033


Reference: 100100033


Synthesized: 100100038


Reference: 100100038
