# StyleTTS Demo (LJSpeech)


### Utils

In [1]:
# NOTE: machine-specific absolute path. Point this at your local StyleTTS checkout;
# the relative paths used by later cells (./Demo/..., ./Models/...) resolve against it.
%cd /home2/giangnth/multi-speaker/StyleTTS

/home2/giangnth/multi-speaker/StyleTTS


In [2]:
# load packages
import random
import yaml
from munch import Munch
import numpy as np
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio
import librosa
from nltk.tokenize import word_tokenize

from models import *
from utils import *

%matplotlib inline

In [3]:
# Run on the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [4]:
# Symbol inventory: one pad symbol, punctuation, Latin letters plus digits, and IPA characters.
_pad = "$"
_punctuation = ';:,.!?¡¿—…"«»“”-()# '
_letters = 'ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
_letters_ipa = "ɑɐɒæɓʙβɔɕçɗɖðʤəɘɚɛɜɝɞɟʄɡɠɢʛɦɧħɥʜɨɪʝɭɬɫɮʟɱɯɰŋɳɲɴøɵɸθœɶʘɹɺɾɻʀʁɽʂʃʈʧʉʊʋⱱʌɣɤʍχʎʏʑʐʒʔʡʕʢǀǁǂǃˈˌːˑʼʴʰʱʲʷˠˤ˞↓↑→↗↘'̩'ᵻ"


# Export all symbols: the position of a symbol in this list is its integer id,
# so the pad symbol is id 0.
symbols = [_pad, *_punctuation, *_letters, *_letters_ipa]

# Symbol -> id lookup table. When a character occurs more than once in `symbols`,
# the id of its last occurrence is the one that is kept.
dicts = {symbol: index for index, symbol in enumerate(symbols)}

class TextCleaner:
    """Callable that converts a string into a list of integer symbol ids.

    The lookup table is the module-level ``dicts`` mapping. Characters that
    are not in the table are printed and left out of the result.
    """

    def __init__(self, dummy=None):
        # `dummy` is accepted but not used.
        self.word_index_dictionary = dicts

    def __call__(self, text):
        """Return the ids of the known characters of ``text``, in order."""
        table = self.word_index_dictionary
        indexes = []
        for char in text:
            if char in table:
                indexes.append(table[char])
            else:
                # Unknown symbol: report it and carry on without it.
                print(char)
        return indexes

# Shared TextCleaner instance. The name is misspelled ("textclenaer") but is kept
# because the tokenize cell and `inference` call it under this exact name.
textclenaer = TextCleaner()

In [5]:
# Mel-spectrogram transform used by `preprocess` on reference audio.
# NOTE(review): `sample_rate` is left at torchaudio's default while the audio is
# loaded at 24 kHz elsewhere in this notebook — presumably this mirrors the
# training-time preprocessing; confirm before changing.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80,
    n_fft=2048,
    win_length=1200,
    hop_length=300,
)

# Normalisation constants applied by `preprocess`: (log_mel - mean) / std.
mean = -4
std = 4

def length_to_mask(lengths):
    """Build a padding mask from a batch of sequence lengths.

    Args:
        lengths: integer tensor of sequence lengths (expected to be 1-D, one
            entry per sequence).

    Returns:
        Boolean tensor of shape ``(len(lengths), lengths.max())`` in which
        ``True`` marks positions at or beyond the corresponding sequence length.
    """
    batch_size = lengths.shape[0]
    positions = torch.arange(lengths.max()).unsqueeze(0)
    positions = positions.expand(batch_size, -1).type_as(lengths)
    # Position j (0-based) is padding when j + 1 > length.
    return (positions + 1) > lengths.unsqueeze(1)

def preprocess(wave):
    """Turn a NumPy waveform into a normalised log-mel tensor.

    The waveform is converted to a float tensor and passed through the
    module-level ``to_mel`` transform. A leading dimension is added, then the
    result is log-compressed and normalised with the module-level ``mean``
    and ``std``.
    """
    mel = to_mel(torch.from_numpy(wave).float())
    # The small offset keeps log() finite for zero-energy bins.
    log_mel = torch.log(1e-5 + mel.unsqueeze(0))
    return (log_mel - mean) / std

def compute_style(ref_dicts):
    """Compute a style embedding for every reference recording.

    Args:
        ref_dicts: mapping from a reference name to the path of an audio file.

    Returns:
        dict mapping each reference name to ``(style, audio)``. ``style`` is
        the output of ``model.style_encoder`` with dimension 1 squeezed, and
        ``audio`` is the silence-trimmed waveform (NumPy array at 24 kHz).

    Relies on the module-level ``model``, ``device`` and ``preprocess``.
    """
    reference_embeddings = {}
    for key, path in ref_dicts.items():
        # librosa.load resamples to the requested rate, so the audio is already
        # at 24 kHz here and no further resampling step is required.
        wave, _ = librosa.load(path, sr=24000)
        # Strip leading/trailing silence; the returned index pair is not needed.
        audio, _ = librosa.effects.trim(wave, top_db=30)
        mel_tensor = preprocess(audio).to(device)

        with torch.no_grad():
            ref = model.style_encoder(mel_tensor.unsqueeze(1))
        reference_embeddings[key] = (ref.squeeze(1), audio)

    return reference_embeddings

### Load models

In [6]:
# load phonemizer
import phonemizer
# espeak-backed phonemizer. In this notebook it is referenced only from the
# commented-out lines of the tokenize cell.
global_phonemizer = phonemizer.backend.EspeakBackend(
    language='en-us',
    preserve_punctuation=True,
    with_stress=True,
)

In [7]:
# load hifi-gan

import sys
sys.path.insert(0, "./Demo/hifi-gan")

import glob
import os
import argparse
import json
import torch
from scipy.io.wavfile import write
from attrdict import AttrDict
from vocoder import Generator
import librosa
import numpy as np
import torchaudio

# Placeholder for the vocoder hyper-parameters; rebound further down in this cell
# to the AttrDict built from the vocoder's config.json.
h = None

def load_checkpoint(filepath, device):
    """Load a checkpoint file with ``torch.load`` and return its contents.

    Args:
        filepath: path of the checkpoint; it must be an existing file
            (enforced with an assertion).
        device: passed to ``torch.load`` as ``map_location``.

    Returns:
        Whatever object ``torch.load`` produced for the file.
    """
    assert os.path.isfile(filepath)
    print("Loading '{}'".format(filepath))
    # NOTE: torch.load unpickles the file — only use it on trusted checkpoints.
    loaded = torch.load(filepath, map_location=device)
    print("Complete.")
    return loaded

def scan_checkpoint(cp_dir, prefix):
    """Return the lexicographically last path in ``cp_dir`` starting with ``prefix``.

    Returns an empty string when nothing matches.
    """
    matches = sorted(glob.glob(os.path.join(cp_dir, prefix + '*')))
    return matches[-1] if matches else ''

# Vocoder (HiFi-GAN) checkpoint and config. Both paths are derived from a single
# directory so that `cp_g` / `config_file` always describe the files actually loaded.
vocoder_dir = "./Demo/hifi-gan/Vocoder"
cp_g = os.path.join(vocoder_dir, "g_00750000")  # pinned generator checkpoint
print(cp_g)
config_file = os.path.join(os.path.split(cp_g)[0], 'config.json')
with open(config_file) as f:
    json_config = json.load(f)
# Attribute-style access to the config, as expected by Generator.
h = AttrDict(json_config)

# From here on `device` is a torch.device rather than a plain string.
device = torch.device(device)
generator = Generator(h).to(device)

state_dict_g = load_checkpoint(cp_g, device)
generator.load_state_dict(state_dict_g['generator'])
# Inference only: switch to eval mode and remove weight normalisation.
generator.eval()
generator.remove_weight_norm()




Loading './Demo/hifi-gan/Vocoder/g_00750000'
Complete.
Removing weight norm...


In [8]:
# load StyleTTS
model_path = "./Models/phuthang/epoch_2nd_00080.pth"
model_config_path = "./Models/phuthang/config.yml"

# Read the YAML config inside a context manager so the file handle is closed.
with open(model_config_path) as config_stream:
    config = yaml.safe_load(config_stream)

# load pretrained ASR model
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# load pretrained F0 model
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

model = build_model(Munch(config['model_params']), text_aligner, pitch_extractor)

# NOTE: torch.load unpickles the file — only use it on trusted checkpoints.
checkpoint = torch.load(model_path, map_location='cpu')
params = checkpoint['net']
for key in model:
    # Restore every sub-module present in the checkpoint except discriminators.
    if key in params and "discriminator" not in key:
        print('%s loaded' % key)
        model[key].load_state_dict(params[key])

# Put every sub-module in eval mode and move it to the target device.
for key in model:
    model[key].eval()
    model[key].to(device)



predictor loaded
decoder loaded
pitch_extractor loaded
text_encoder loaded
style_encoder loaded
text_aligner loaded


### Synthesize speech

In [10]:
# Reference recordings: one style embedding is computed per entry below.

# The train/validation lists are still loaded from the config, but in the code
# shown here the references are the hand-picked files below rather than entries
# taken from `train_list`.
train_path = config.get('train_data', None)
val_path = config.get('val_data', None)
train_list, val_list = get_data_path_list(train_path, val_path)

# Reference name -> audio file path (machine-specific absolute paths).
ref_dicts = {
    "PT": "/home2/giangnth/multi-speaker/phuthang/wav_normalized_3db_24k/0050013_01.wav",
    "LJ": "/home2/giangnth/multi-speaker/StyleTTS1/LJSpeech-1.1/wavs/LJ001-0009.wav",
    "LT": "/home2/giangnth/multi-speaker/lantrinh/wav_origin/wav_origin/0004.wav",
    "LT2": "/home2/giangnth/multi-speaker/lantrinh/trimmed_wav_processed/0031.wav",
    "TTH": "/home2/giangnth/multi-speaker/bdmk_23_16.wav",
    "1": "/home2/giangnth/multi-speaker/data-vlsp_denoised/VLSP-2023-EMO/train-set/000065.wav",
    "008346": "/home2/giangnth/multi-speaker/data-vlsp_denoised/VLSP-2023-EMO/train-set/008346.wav",
    "SGT2650": "/home2/giangnth/multi-speaker/data-vlsp_denoised/VLSP-2023-EMO/train-set/005286.wav",
}
reference_embeddings = compute_style(ref_dicts)

In [12]:
# synthesize a text
# `ps` is the phoneme string to synthesize and `tone` is its companion tone string.
# The tokenize cell converts both to id sequences, and they are passed together to
# the text encoder.
# NOTE(review): the two strings appear to be aligned character-for-character
# (separators in `tone` sit at the positions of the separators in `ps`) — confirm
# before editing either one.
# Alternative samples are kept commented out; uncomment one ps/tone pair at a time.
# ps = "kak NM9j-m9Xw tsOXNm so kaw tsuNm-biJ mot mEt tam va tsi tseJ JaXw haj den ba saXN-ti-mEt # ko nOj ###"
# tone = "666 2222$4444 111111 11 111 11111$222 888 666 555 22 333 1111 1111 111 555 11 1111$11$666 $ 11 555 $$$"
ps = "tM9N-tM de dam-baw zik-vu tsO kak xEXk haN # viet thEw da bo-suNm them h9n tsin xoNm xoNm Gi zuNm-lM9N baXN-thoNm kuok-te doNm-th9j kip-th9j diew-tsiJ lMw-lM9N kwpa kak hM9N d9Xt-lien va kap bien xak JM i 9Xj # aj bi zi # mot xoNm mot hM9N hoNm-koNm ###"
tone = "1111$77 33 333$333 888$77 111 666 6666 222 $ 8888 1111 44 33$1111 1111 111 5555 1111 1111 11 1111$7777 1111$11111 6666$55 2222$2222 888$2222 2222$3333 111$7777 1111 666 5555 6666$2222 22 666 3333 666 11 1 111 $ 11 11 11 $ 888 1111 888 5555 2222$1111 $$$"
# ps = "doj-v9j NM9j tsE # haN Nan noj-zuNm zaj-tsi h9Xp-z9Xn hoj-tu d9Xj-du tsOXNm thiet-bi JO GOn Ep-pe-te p-l9Xj bokp-s koNm ###"
# tone = "555$555 2222 333 $ 222 222 777$1111 333$555 6666$4444 777$77 2222$33 111111 66666$77 33 777 66$11$11 1$1111 6666$1 7777 $$$"
# ps ="xuj mot thuNm taw kOn bokp-h9j lEXJ # do aw za sE # Jiew NM9j di ts9 thaXj JaXw JaXt lM9 # bO bOXkp ###"
# tone="111 888 22222 555 222 6666$111 7777 $ 33 22 11 11 $ 2222 2222 11 777 11111 1111 8888 777 $ 33 88888 $$$"
# Source sentences (Vietnamese) that appear to correspond to two of the commented-out samples above:
#   "Khui một thùng táo" — "Open a crate of apples" (the "xuj mot thuNm taw ..." sample)
#   "Các người mẫu trong sô cao trung bình một mét tám và chỉ chênh nhau hai đến ba xăng-ti-mét, cô nói."
#     — "The models in the show average one metre eighty in height and differ by only
#       two to three centimetres, she said." (the "kak NM9j-m9Xw ..." sample)

In [13]:
# tokenize
# (Phonemizing raw text is disabled; `ps` and `tone` are supplied pre-phonemized.)
# ps = global_phonemizer.phonemize([text])
# ps = word_tokenize(ps[0])
# ps = ' '.join(ps)

# Convert both strings to id lists and wrap each in id 0 (the pad symbol) at both ends.
tokens = [0] + textclenaer(ps) + [0]
tones = [0] + textclenaer(tone) + [0]

# Add a leading batch dimension of size 1 and move to the target device.
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
tones = torch.LongTensor(tones).to(device).unsqueeze(0)

In [14]:
# Synthesize `tokens`/`tones` once per reference style.
# Result: reference name -> waveform (NumPy array) produced by the vocoder.
converted_samples = {}

with torch.no_grad():
    num_tokens = tokens.shape[-1]
    input_lengths = torch.LongTensor([num_tokens]).to(device)
    m = length_to_mask(input_lengths).to(device)
    # The text encoding does not depend on the reference, so compute it once.
    t_en = model.text_encoder(tokens, tones, input_lengths, m)

    for key, (ref, _) in reference_embeddings.items():
        s = ref.squeeze(1)
        d = model.predictor.text_encoder(t_en, s, input_lengths, m)

        # Predict a duration (in frames) per token; every token gets at least one frame.
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Build a (tokens x frames) 0/1 alignment: row i is 1 over the
        # consecutive frames assigned to token i.
        pred_aln_trg = torch.zeros(num_tokens, int(pred_dur.sum()))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            n_frames = int(pred_dur[i])
            pred_aln_trg[i, c_frame:c_frame + n_frames] = 1
            c_frame += n_frames
        alignment = pred_aln_trg.unsqueeze(0).to(device)

        # encode prosody: expand per-token features to per-frame features
        en = d.transpose(-1, -2) @ alignment
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        out = model.decoder(t_en @ alignment,
                            F0_pred, N_pred, ref.squeeze().unsqueeze(0))

        # Run the vocoder exactly once per reference (previously it was run
        # twice and the first result was thrown away).
        c = out.squeeze()
        y_g_hat = generator(c.unsqueeze(0))
        converted_samples[key] = y_g_hat.squeeze().cpu().numpy()

In [15]:
import IPython.display as ipd
import soundfile as sf
def inference(wave_path, phone, tone,
              out_dir="/home2/giangnth/multi-speaker/StyleTTS/Demo/aug_data",
              metadata_path="metadata_aug_phuthang.txt"):
    """Synthesize one utterance in every reference style and record the results.

    For each entry of the module-level ``reference_embeddings`` the utterance
    is synthesized, written to ``out_dir`` as ``<reference name>_<basename of
    wave_path>`` at 24 kHz, and one line
    ``<wav path>|<phone>|<tone>|0`` is appended to ``metadata_path``.

    Args:
        wave_path: path whose basename is reused for the output file names.
        phone: phoneme string to synthesize.
        tone: tone string accompanying ``phone``.
        out_dir: directory receiving the synthesized wav files.
        metadata_path: text file to which the metadata lines are appended.

    Returns:
        None.

    Relies on the module-level ``model``, ``generator``, ``device``,
    ``textclenaer``, ``length_to_mask`` and ``reference_embeddings``.
    """
    name = os.path.basename(wave_path)

    # phone, tone -> id sequences wrapped in id 0 (the pad symbol) at both ends
    tokens = [0] + textclenaer(phone) + [0]
    tones = [0] + textclenaer(tone) + [0]
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)
    tones = torch.LongTensor(tones).to(device).unsqueeze(0)

    # infer
    converted_samples = {}

    with torch.no_grad():
        num_tokens = tokens.shape[-1]
        input_lengths = torch.LongTensor([num_tokens]).to(device)
        m = length_to_mask(input_lengths).to(device)
        # The text encoding does not depend on the reference, so compute it once.
        t_en = model.text_encoder(tokens, tones, input_lengths, m)

        for key, (ref, _) in reference_embeddings.items():
            s = ref.squeeze(1)
            d = model.predictor.text_encoder(t_en, s, input_lengths, m)

            # Predict a duration (in frames) per token; at least one frame each.
            x, _ = model.predictor.lstm(d)
            duration = model.predictor.duration_proj(x)
            pred_dur = torch.round(duration.squeeze()).clamp(min=1)

            # (tokens x frames) 0/1 alignment: row i is 1 over the consecutive
            # frames assigned to token i.
            pred_aln_trg = torch.zeros(num_tokens, int(pred_dur.sum()))
            c_frame = 0
            for i in range(pred_aln_trg.size(0)):
                n_frames = int(pred_dur[i])
                pred_aln_trg[i, c_frame:c_frame + n_frames] = 1
                c_frame += n_frames
            alignment = pred_aln_trg.unsqueeze(0).to(device)

            # encode prosody: expand per-token features to per-frame features
            en = d.transpose(-1, -2) @ alignment
            F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

            out = model.decoder(t_en @ alignment,
                                F0_pred, N_pred, ref.squeeze().unsqueeze(0))

            # Run the vocoder exactly once per reference.
            c = out.squeeze()
            y_g_hat = generator(c.unsqueeze(0))
            converted_samples[key] = y_g_hat.squeeze().cpu().numpy()

    # Open the metadata file once and close it deterministically so that every
    # appended line is flushed to disk.
    with open(metadata_path, "a", encoding="utf-8") as metadata_file:
        for key, wave in converted_samples.items():
            out_path = os.path.join(out_dir, key + "_" + name)
            sf.write(out_path, wave, 24000)
            metadata_file.write(out_path + "|" + phone + "|" + tone + "|" + str(0) + "\n")

In [16]:
# inference("/home2/giangnth/multi-speaker/phuthang/wav_normalized_3db_24k/0052926_01.wav", " lam-aXn v9j a-li-ba-ba zoj tson-thwpe ###", " 222$111 555 1$11$11$11 222 5555$55555 $$$")
#/home2/giangnth/multi-speaker/phuthang/wav_normalized_3db_24k/0052926_01.wav| lam-aXn v9j a-li-ba-ba zoj tson-thwpe ###| 222$111 555 1$11$11$11 222 5555$55555 $$$|0

In [17]:
import IPython.display as ipd
import soundfile as sf
# Play every synthesized sample, save it to disk, and play its reference next to it.
for key, wave in converted_samples.items():
    print('Synthesized: %s' % key)
    display(ipd.Audio(wave, rate=24000))
    sf.write(os.path.join("/home2/giangnth/multi-speaker/StyleTTS/Demo/test", key+".wav"), wave, 24000)
    try:
        print('Reference: %s' % key)
        # The last element of each reference_embeddings entry is the reference waveform.
        display(ipd.Audio(reference_embeddings[key][-1], rate=24000))
    except Exception as exc:
        # Keep going with the remaining samples, but say why the reference could
        # not be shown instead of swallowing the error silently. KeyboardInterrupt
        # is deliberately not caught.
        print('Reference unavailable for %s: %r' % (key, exc))

Synthesized: PT


Reference: PT


Synthesized: LJ


Reference: LJ


Synthesized: LT


Reference: LT


Synthesized: LT2


Reference: LT2


Synthesized: TTH


Reference: TTH


Synthesized: 1


Reference: 1


Synthesized: 008346


Reference: 008346


Synthesized: SGT2650


Reference: SGT2650


In [None]:
# Run `inference` for every line of the augmentation list.
# Each line holds at least three '|'-separated fields: wav path, phoneme string, tone string.
with open("/home2/giangnth/multi-speaker/StyleTTS/Demo/aug.txt", "r", encoding="utf-8") as f:
    for line_number, line in enumerate(f, start=1):
        # Remove only the line terminator: fields may legitimately start with a
        # space, so other whitespace must be preserved.
        line = line.rstrip("\r\n")
        if not line.strip():
            continue  # tolerate blank lines
        cols = line.split("|")
        if len(cols) < 3:
            raise ValueError(
                "aug.txt line %d: expected at least 3 '|'-separated fields, got %d"
                % (line_number, len(cols))
            )
        wave_path, phone, tone = cols[:3]
        inference(wave_path, phone, tone)
        