# StyleTTS Demo (adapted from the LJSpeech demo: GameTTS checkpoint + BigVGAN vocoder)


### Utils

In [36]:
# load packages
import os
from collections import OrderedDict

import IPython.display as ipd
import librosa
import torch
import torchaudio
import yaml
from munch import Munch

from meldataset import mel_spectrogram, MAX_WAV_VALUE
from text_utils import TextCleaner

from models import *
from utils import *

%matplotlib inline

In [2]:
# Inference device, pinned to CPU. To use a GPU instead, set this to
# 'cuda' (or e.g. 'cuda:1') when torch.cuda.is_available().
device = 'cpu'

In [3]:
# Converts a phoneme string into a list of integer token ids (the tokenize cell
# extends that list and wraps it in a LongTensor).
text_cleaner = TextCleaner()

In [33]:
def length_to_mask(lengths):
    """Build a boolean padding mask from a batch of sequence lengths.

    Args:
        lengths: 1-D integer tensor holding one length per sequence.

    Returns:
        Bool tensor of shape (len(lengths), lengths.max()) that is True at
        padded positions, i.e. where the position index is >= the row's length.
    """
    max_len = lengths.max()
    # type_as puts the position indices on the same dtype/device as `lengths`.
    positions = torch.arange(max_len).type_as(lengths).unsqueeze(0)  # (1, max_len)
    # Broadcasting against (batch, 1) yields the full (batch, max_len) mask.
    return (positions + 1) > lengths.unsqueeze(1)

def compute_style(audio_path):
    """Encode a reference recording into a StyleTTS style embedding.

    The clip is loaded at 22.05 kHz and leading/trailing silence is trimmed
    (top_db=30). It is then converted to an 80-bin mel spectrogram and passed
    through the global ``model.style_encoder`` without gradients.

    Args:
        audio_path: path to an audio file readable by ``librosa.load``.

    Returns:
        The style encoder output for the trimmed clip.
    """
    target_sr = 22050
    # librosa.load already resamples to target_sr, so no separate resample
    # step is needed.
    wave, _ = librosa.load(audio_path, sr=target_sr)
    # Encode the *trimmed* signal. The previous code computed this but then
    # used the untrimmed `wave`, so the silence fed into the style encoder.
    audio, _ = librosa.effects.trim(wave, top_db=30)

    wave_tensor = torch.from_numpy(audio).float()
    mel_tensor = mel_spectrogram(
        wave_tensor.unsqueeze(0), 2048, 80, target_sr, 256, 1024, 0, None, False
    ).to(device)

    with torch.no_grad():
        ref = model.style_encoder(mel_tensor.unsqueeze(1))

    return ref

### Load models

In [6]:
import json
from BigVGAN.models import BigVGAN as Generator
from BigVGAN.env import AttrDict

def get_mel(x):
    """Mel spectrogram of waveform `x` using the BigVGAN config `h`.

    Reads the module-level AttrDict `h` (built from the vocoder's
    config.json), so `h` must exist by the time this is called.
    """
    stft_args = (h.n_fft, h.num_mels, h.sampling_rate,
                 h.hop_size, h.win_size, h.fmin, h.fmax)
    return mel_spectrogram(x, *stft_args)

# BigVGAN vocoder checkpoint + config. Set BIGVGAN_MODEL_DIR to load them from
# elsewhere without editing the notebook; the default is the original path.
bigvgan_model_dir = os.environ.get(
    "BIGVGAN_MODEL_DIR", "/home/alexander/Projekte/StyleTTS/TTS/BigVGAN/Models"
)
model_file = os.path.join(bigvgan_model_dir, "g_00350000")
config_file = os.path.join(bigvgan_model_dir, "config.json")

with open(config_file) as f:
    json_config = json.load(f)
h = AttrDict(json_config)

generator = Generator(h).to(device)

# NOTE(review): torch.load unpickles the file; only load trusted checkpoints.
checkpoint_dict = torch.load(model_file, map_location=device)
generator.load_state_dict(checkpoint_dict['generator'])

generator.eval()
generator.remove_weight_norm()

Removing weight norm...


In [27]:
# load StyleTTS
# Checkpoint + config directory. Set STYLETTS_MODEL_DIR to load from elsewhere
# without editing the notebook; the default is the original path.
styletts_model_dir = os.environ.get(
    "STYLETTS_MODEL_DIR", "/home/alexander/Projekte/StyleTTS/TTS/Models/GameTTS"
)
model_path = os.path.join(styletts_model_dir, "epoch_2nd_00006.pth")
model_config_path = os.path.join(styletts_model_dir, "config.yml")

# `with` closes the config file (a bare open() leaks the handle).
with open(model_config_path) as f:
    config = yaml.safe_load(f)

# load pretrained ASR model
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# load pretrained F0 model
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

model = build_model(Munch(config['model_params']), text_aligner, pitch_extractor)

# NOTE(review): torch.load unpickles the checkpoint; only load trusted files.
params = torch.load(model_path, map_location='cpu')['net']

# Key prefix added when the model was saved from a wrapped (e.g. DataParallel)
# module.
DP_PREFIX = 'module.'

for key in model:
    # Skip submodules missing from the checkpoint, and the discriminators
    # (presumably training-only).
    if key not in params or 'discriminator' in key:
        continue
    print('%s loaded' % key)
    # Strip the prefix only where present. Slicing unconditionally (k[7:])
    # corrupts keys of checkpoints saved without the wrapper.
    new_state_dict = OrderedDict(
        (k[len(DP_PREFIX):] if k.startswith(DP_PREFIX) else k, v)
        for k, v in params[key].items()
    )
    model[key].load_state_dict(new_state_dict)

# Put every submodule in inference mode on the target device.
for key in model:
    model[key].eval()
    model[key].to(device)

predictor loaded
decoder loaded
pitch_extractor loaded
text_encoder loaded
style_encoder loaded
text_aligner loaded


### Synthesize speech

In [46]:
import glob
import os
import random

# Pick NUM_REFS random reference clips (basename -> path) to use as style
# prompts. Set REF_WAV_DIR to use another speaker folder.
REF_WAV_DIR = os.environ.get(
    "REF_WAV_DIR", "/home/alexander/Projekte/TTS_Data/en/WoW/troll/male"
)
NUM_REFS = 2
SEED = 0  # fixed so a fresh "Run All" picks the same references

# glob order depends on the filesystem; sort first so the seeded shuffle is
# reproducible. A private Random leaves the global `random` state untouched.
wav_files = sorted(glob.glob(os.path.join(REF_WAV_DIR, "**", "*.wav"), recursive=True))
random.Random(SEED).shuffle(wav_files)

ref_dicts = {}
num_refs = 0

for file_path in wav_files[:NUM_REFS]:
    # splitext only drops the final extension; str.replace would also remove
    # any ".wav" inside the name.
    filename = os.path.splitext(os.path.basename(file_path))[0]
    print(filename)
    ref_dicts[filename] = file_path
    num_refs += 1

trollmale_err_mustequipitem02
trollmale_err_noenergy02


In [44]:
# Phonemized (IPA-style) input texts; only text_en is synthesized below.
text_de = "zoː fiːl ɡɛlt haːbə ɪç nɔx niː bəkɔmən. ɛː, ɪç maɪ̯nə. haːp daŋk. ɪç hɔfə, eːɐ̯ ɪst diːɐ̯ fɔn nʊt͡sn̩."
text_en = "aɪ kænt dɹɪŋk ɛnimɔː jɛt."

In [45]:
# tokenize: surround the id sequence with id 0 on both sides (presumably the
# pad/boundary symbol of TextCleaner; confirm), then add a batch dimension.
token_ids = [0, *text_cleaner(text_en), 0]
tokens = torch.LongTensor(token_ids).to(device).unsqueeze(0)

In [51]:
with torch.no_grad():
    # Encode the text once; it is shared by every reference style.
    input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
    m = length_to_mask(input_lengths).to(device)
    t_en = model.text_encoder(tokens, input_lengths, m)

    for key, audio_path in ref_dicts.items():
        ref = compute_style(audio_path)
        s = ref.squeeze(1)

        # Predict per-token durations: rounded, and at least 1 column each.
        d = model.predictor.text_encoder(t_en, s, input_lengths, m)
        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        pred_dur = torch.round(duration.squeeze()).clamp(min=1)

        # Hard monotonic alignment: token i covers pred_dur[i] consecutive
        # target columns.
        n_tokens = int(input_lengths.item())
        n_frames = int(pred_dur.sum().item())
        pred_aln_trg = torch.zeros(n_tokens, n_frames)
        c_frame = 0
        for i in range(n_tokens):
            dur_i = int(pred_dur[i].item())
            pred_aln_trg[i, c_frame:c_frame + dur_i] = 1
            c_frame += dur_i
        aln = pred_aln_trg.unsqueeze(0).to(device)  # moved to device once, reused below

        # encode prosody
        en = d.transpose(-1, -2) @ aln
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        out = model.decoder(t_en @ aln, F0_pred, N_pred, s)

        c = out.squeeze()
        y_g_hat = generator(c.unsqueeze(0))
        y_out = y_g_hat.squeeze()

        print('Synthesized: %s' % key)
        # The vocoder's output rate comes from its own config, not a hardcoded
        # constant.
        display(ipd.Audio(y_out.cpu().numpy(), rate=h.sampling_rate))
        print('Reference: %s' % key)
        # For a file path, Audio reads the rate from the file itself.
        display(ipd.Audio(audio_path))

Synthesized: trollmale_err_mustequipitem02


Reference: trollmale_err_mustequipitem02


Synthesized: trollmale_err_noenergy02


Reference: trollmale_err_noenergy02
