# StyleTTS 2 Demo (LJSpeech)


### Utils

#### Load Packages

In [1]:
import os

# phonemizer locates the eSpeak NG shared library through the PHONEMIZER_ESPEAK_LIBRARY
# environment variable. Only fall back to the default Windows install location when the
# variable is not already set and the DLL is really there, so that the notebook also runs
# on machines where eSpeak lives elsewhere (or is discoverable without this hint).
_DEFAULT_ESPEAK_LIBRARY = r"C:\Program Files\eSpeak NG\libespeak-ng.dll"  # <-- adjust if different
if "PHONEMIZER_ESPEAK_LIBRARY" not in os.environ and os.path.isfile(_DEFAULT_ESPEAK_LIBRARY):
    os.environ["PHONEMIZER_ESPEAK_LIBRARY"] = _DEFAULT_ESPEAK_LIBRARY

# Reproducibility setup: every RNG used below is seeded once, here. Re-running later cells
# without re-running this one continues the RNG streams instead of restarting them.
import torch # Deep Learning Framework
from torch import nn
import torch.nn.functional as F
torch.manual_seed(0) # Seed PyTorch's global RNG (used for the diffusion noise drawn later)
torch.backends.cudnn.benchmark = False # Disable cuDNN auto-tuning of convolution algorithms
torch.backends.cudnn.deterministic = True # Restrict cuDNN to deterministic algorithms

import random # Python's built-in RNG
random.seed(0) # Seed Python's built-in RNG

import numpy as np # Numerical Computing Library
np.random.seed(0) # Seed NumPy's legacy global RNG

import torchaudio # Loading/saving waveforms, resampling, transforms
import librosa # Python library for audio analysis
import soundfile as sf
from munch import Munch # Turns dictionaries into objects with attribute-style access
from nltk.tokenize import word_tokenize # Tokenizers divide strings into lists of substrings
import time # Used for timing operations
import yaml # Required for config.yml to load model hyperparameters and paths
import pprint
import os

%cd ..

from models import *
from utils import *
from text_utils import TextCleaner
textcleaner = TextCleaner() # Callable that turns a phoneme string into a list of integer token IDs (see its use in inference())

import phonemizer # Splits words into phonemes (symbols that represent how words are pronounced)
# Shared eSpeak-backed phonemizer; inference() calls global_phonemizer.phonemize([text]).
global_phonemizer = phonemizer.backend.EspeakBackend(
    language='en-us',
    preserve_punctuation=True, # Keeps Punctuation such as , . ? !
    with_stress=True # Adds stress marks to vowels
)

# All models and input tensors are moved to this device.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

# Tell matplotlib to, if used, render inline in the notebook outputs
%matplotlib inline

C:\Users\Yanis Wilbrand\PycharmProjects\StyleTTS2
177


#### Define Values

In [2]:
# adversarial_audio_generation selects what inference() does with the text-encoder output (t_en):
# 0 = generate and save t_en as the target latent        (latents/h_text_target.npy)
# 1 = generate and save t_en as the ground-truth latent  (latents/h_text_ground_truth.npy)
# 2 = generate and save the ground truth, load the saved target, blend both
# 3 = load the saved ground truth and the saved target, blend both
# Any other value leaves t_en untouched and saves nothing.

adversarial_audio_generation = 1

# Weight of the target latent in the blend used by modes 2 and 3. Despite the name this is a
# fraction, not a percentage: 0.0 keeps the ground truth, 1.0 uses only the target.
interpolation_percentage = 0.5

# Sentence to synthesize; inference() strips surrounding whitespace and removes double quotes.
text = ''' That is a banana. '''

#### Helper Functions

In [3]:
# Mel-spectrogram transform. NOTE(review): not referenced by any cell visible in this notebook —
# presumably kept for reference-audio preprocessing; confirm before removing.
to_mel = torchaudio.transforms.MelSpectrogram(
    n_mels=80, # Number of Mel-frequency bins
    n_fft=2048, # Length of FFT
    win_length=1200, # Size of window in samples
    hop_length=300 # Step size between windows
)

# Normalization constants. NOTE(review): unused in the visible cells — presumably meant to
# center/scale log-mel values; TODO confirm against the preprocessing code.
mean, std = -4, 4

def length_to_mask(lengths):
    """Build a boolean padding mask from a 1-D tensor of sequence lengths.

    Args:
        lengths: 1-D integer tensor holding one length per batch element.

    Returns:
        Boolean tensor of shape ``(len(lengths), lengths.max())`` where entry
        ``[i, j]`` is True when position ``j`` is padding (``j >= lengths[i]``)
        and False when it is a valid position. Callers invert it (``~mask``)
        to obtain an attention mask.
    """
    batch_size = lengths.shape[0]
    # Row vector [0, 1, ..., max_len - 1], repeated once per batch element and cast to
    # the type of `lengths` so the comparison below is well-typed.
    positions = torch.arange(lengths.max()).unsqueeze(0).expand(batch_size, -1).type_as(lengths)
    # 1-based position > length  <=>  the position lies beyond the end of the sequence.
    return torch.gt(positions + 1, lengths.unsqueeze(1))

### Load models

In [4]:
# Read the model configuration. The file is opened in a `with` block so the handle is
# closed as soon as parsing finishes instead of being left to the garbage collector.
with open("Models/LJSpeech/config.yml") as config_file:
    config = yaml.safe_load(config_file) # Settings plus the locations of the pretrained ASR, F0 and PL-BERT models

# load pretrained ASR (Automatic Speech Recognition) model, used as the text aligner
ASR_config = config.get('ASR_config', False) # Value of the 'ASR_config' key (False when absent)
ASR_path = config.get('ASR_path', False) # Value of the 'ASR_path' key (False when absent)
text_aligner = load_ASR_models(ASR_path, ASR_config) # Load PyTorch model

# load pretrained F0 model (Extracts Pitch Features from Audio, How Pitch Changes over time)
F0_path = config.get('F0_path', False) # Value of the 'F0_path' key (False when absent)
pitch_extractor = load_F0_models(F0_path)

# load BERT model (encodes input text with prosodic cues)
# BERT = Bidirectional Encoder Representations from Transformers
# Represent text as a sequence of vectors
from Utils.PLBERT.util import load_plbert
BERT_path = config.get('PLBERT_dir', False) # Value of the 'PLBERT_dir' key (False when absent)
plbert = load_plbert(BERT_path)

  from .autonotebook import tqdm as notebook_tqdm


In [5]:
# Assemble the full StyleTTS 2 model from its hyperparameters and the pretrained helpers.
model = build_model(
    recursive_munch(config['model_params']), # Allows attribute-style access to keys of model_params,
    text_aligner, # Automatic Speech Recognition model
    pitch_extractor, # F0 model
    plbert # BERT model
)

# Switch every sub-module to evaluation mode ...
for module_name in model:
    model[module_name].eval()

# ... and then move every sub-module to the selected device.
for module_name in model:
    model[module_name].to(device)

  WeightNorm.apply(module, name, dim)


In [6]:
# Load the training checkpoint onto the CPU; its weights are copied into the modules of
# `model` by the next cell.
# NOTE(review): torch.load may unpickle arbitrary objects (depending on the PyTorch version's
# `weights_only` default) — only load checkpoints obtained from a trusted source.
params_whole = torch.load("Models/LJSpeech/epoch_2nd_00100.pth", map_location='cpu')
params = params_whole['net'] # Per-module state dicts, looked up by module name in the next cell

In [7]:
from collections import OrderedDict

# Copy the checkpoint weights into every sub-module that has an entry in `params`.
for key in model:
    if key not in params:
        continue
    try:
        model[key].load_state_dict(params[key])
    except RuntimeError:
        # load_state_dict raises RuntimeError when parameter names do not match. Some
        # checkpoints prefix every name with `module.`; strip that prefix (only where it is
        # actually present) and retry non-strictly.
        prefix = 'module.'
        new_state_dict = OrderedDict(
            (k[len(prefix):] if k.startswith(prefix) else k, v)
            for k, v in params[key].items()
        )
        load_result = model[key].load_state_dict(new_state_dict, strict=False)
        # strict=False ignores mismatches silently, so report them instead of hiding them.
        if load_result.missing_keys or load_result.unexpected_keys:
            print('%s: %d missing and %d unexpected parameter names were ignored' % (
                key, len(load_result.missing_keys), len(load_result.unexpected_keys)))
    # Report success only after the weights have really been loaded.
    print('%s loaded' % key)

# Make sure every sub-module is in evaluation mode after loading.
for key in model:
    model[key].eval()

bert loaded
bert_encoder loaded
predictor loaded
decoder loaded
text_encoder loaded
predictor_encoder loaded
style_encoder loaded
diffusion loaded
text_aligner loaded
pitch_extractor loaded
mpd loaded
msd loaded
wd loaded


In [8]:
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule

In [9]:
# Diffusion sampler wrapped around the model's diffusion network. inference() calls it with a
# noise tensor, a text embedding, the number of steps and an embedding scale, and splits the
# result into two style vectors.
sampler = DiffusionSampler(
    model.diffusion.diffusion,
    sampler=ADPM2Sampler(),
    sigma_schedule=KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0), # empirical parameters
    clamp=False
)

### Synthesize speech

In [10]:
def _blend_with_saved_target(h_source, source_label):
    """Blend ``h_source`` with the saved target text latent and save the result.

    The target is read from ``latents/h_text_target.npy``, linearly resampled along its last
    axis to the length of ``h_source`` and mixed in with weight ``interpolation_percentage``
    (module-level setting). The blend is written to ``latents/h_text_interpolated.npy``.

    Args:
        h_source: text-encoder output to blend into (3-D, as required by ``F.interpolate``
            in ``'linear'`` mode).
        source_label: name printed in front of the shape of ``h_source``.

    Returns:
        The blended latent, same shape as ``h_source``.
    """
    h_target = torch.tensor(np.load("latents/h_text_target.npy"), dtype=h_source.dtype, device=device)

    print(f"{source_label} shape:", h_source.shape)
    print("t_en_target shape:", h_target.shape)

    # The target may stem from a text of different length; resample it to match.
    h_target = F.interpolate(h_target, size=h_source.size(-1), mode='linear', align_corners=False)

    h_blend = (1 - interpolation_percentage) * h_source + interpolation_percentage * h_target
    np.save("latents/h_text_interpolated.npy", h_blend.detach().cpu().numpy())
    return h_blend


def _apply_latent_mode(t_en):
    """Save, load and/or blend the text latent according to ``adversarial_audio_generation``.

    Mode 0 saves ``t_en`` as target, mode 1 saves it as ground truth, mode 2 saves it as
    ground truth and blends it with the saved target, mode 3 blends the saved ground truth
    with the saved target. Any other value returns ``t_en`` unchanged.

    Raises:
        ValueError: in mode 3, when the saved ground truth has a different length than the
            current text (it could not be aligned with the predicted durations).
    """
    mode = adversarial_audio_generation
    if mode not in (0, 1, 2, 3):
        return t_en

    # np.save does not create directories; make sure the output folder exists.
    os.makedirs("latents", exist_ok=True)

    if mode == 0:
        np.save("latents/h_text_target.npy", t_en.detach().cpu().numpy())
        return t_en

    if mode == 1:
        np.save("latents/h_text_ground_truth.npy", t_en.detach().cpu().numpy())
        return t_en

    if mode == 2:
        np.save("latents/h_text_ground_truth.npy", t_en.detach().cpu().numpy())
        return _blend_with_saved_target(t_en, "t_en")

    # mode == 3: start from the ground truth stored by an earlier run.
    t_en_gt = torch.tensor(np.load("latents/h_text_ground_truth.npy"), device=device)
    if t_en_gt.size(-1) != t_en.size(-1):
        raise ValueError(
            "saved ground-truth latent has length %d but the current text has length %d; "
            "regenerate it with adversarial_audio_generation = 1 or 2 for this text"
            % (t_en_gt.size(-1), t_en.size(-1))
        )
    return _blend_with_saved_target(t_en_gt, "t_en_gt")


def inference(text, noise, diffusion_steps=5, embedding_scale=1):
    """Synthesize a waveform for ``text``.

    Relies on the module-level ``global_phonemizer``, ``textcleaner``, ``model``, ``sampler``
    and ``device``, and on the settings ``adversarial_audio_generation`` and
    ``interpolation_percentage`` (which may save/load latents under ``latents/``).

    Args:
        text: sentence to speak; surrounding whitespace and double quotes are removed.
        noise: noise tensor handed to the diffusion sampler (callers in this notebook pass
            ``torch.randn(1, 1, 256)`` on ``device``).
        diffusion_steps: number of sampler steps.
        embedding_scale: embedding scale forwarded to the sampler.

    Returns:
        The decoder output as a NumPy array (``out.squeeze().cpu().numpy()``); the notebook
        plays and saves it at 24000 Hz.
    """
    text = text.strip() # Removes whitespaces from beginning and end of string
    text = text.replace('"', '') # removes " to prevent unpredictable behavior

    ps = global_phonemizer.phonemize([text]) # text -> list with one phoneme string
    print("1. List of phonemes: ", ps)
    ps = word_tokenize(ps[0]) # Split into individual tokens
    print("2. String of phonemes: ", ps)
    ps = ' '.join(ps) # Join tokens together, separated by a single space
    print("3. Final string of phonemes: ", ps)

    tokens = textcleaner(ps) # Look up numeric ID per phoneme
    print("4. ID of phonemes: ", tokens)
    tokens.insert(0, 0) # Insert leading 0 to mark start
    print("5. ID with leading 0: ", tokens)
    tokens = torch.LongTensor(tokens).to(device).unsqueeze(0) # IDs -> tensor with a leading batch axis of size 1
    print("6. Pytorch Tensor Dimension: ", tokens.shape)

    with torch.no_grad(): # Inference only, no gradients needed
        num_tokens = tokens.shape[-1]
        input_lengths = torch.LongTensor([num_tokens]).to(tokens.device)
        text_mask = length_to_mask(input_lengths).to(tokens.device) # True marks padding positions
        print("Text Mask: ", text_mask)

        # Encode the phoneme tokens, then save / replace / blend the latent as configured.
        t_en = model.text_encoder(tokens, input_lengths, text_mask)
        t_en = _apply_latent_mode(t_en)

        bert_dur = model.bert(tokens, attention_mask=(~text_mask).int())
        d_en = model.bert_encoder(bert_dur).transpose(-1, -2)

        s_pred = sampler(noise,
              embedding=bert_dur[0].unsqueeze(0), num_steps=diffusion_steps,
              embedding_scale=embedding_scale).squeeze(0)

        s = s_pred[:, 128:] # Second part of the sampled style: fed to the predictor below
        ref = s_pred[:, :128] # First 128 entries of the sampled style: fed to the decoder below

        d = model.predictor.text_encoder(d_en, s, input_lengths, text_mask)

        x, _ = model.predictor.lstm(d)
        duration = model.predictor.duration_proj(x)
        duration = torch.sigmoid(duration).sum(axis=-1)
        # reshape(-1) instead of squeeze(): keeps a 1-D vector even for a single token, so
        # the indexing below keeps working.
        pred_dur = torch.round(duration.reshape(-1)).clamp(min=1)

        pred_dur[-1] += 5 # Lengthen the last token by 5 frames

        # Hard alignment matrix (tokens x frames): row i is 1 on the frames assigned to token i.
        pred_aln_trg = torch.zeros(num_tokens, int(pred_dur.sum().data))
        c_frame = 0
        for i in range(pred_aln_trg.size(0)):
            pred_aln_trg[i, c_frame:c_frame + int(pred_dur[i].data)] = 1
            c_frame += int(pred_dur[i].data)

        h_aligned = t_en @ pred_aln_trg.unsqueeze(0).to(device)  # (B, D_text, T_frames)

        # NOTE(review): this jitter is applied after pred_aln_trg has been built and pred_dur
        # is not read again, so it does not change the durations; it only advances torch's
        # RNG. Kept so seeded runs stay unchanged — either move it above the alignment
        # construction (if jitter is intended) or delete it.
        jitter_strength = 0.3  # try 0.1–0.5
        pred_dur = torch.round(pred_dur.float() + torch.randn_like(pred_dur.float()) * jitter_strength).clamp(min=1).long()

        # encode prosody
        en = (d.transpose(-1, -2) @ pred_aln_trg.unsqueeze(0).to(device))
        F0_pred, N_pred = model.predictor.F0Ntrain(en, s)

        out = model.decoder(
            h_aligned,
            F0_pred,
            N_pred,
            ref.squeeze().unsqueeze(0)
        )

    return out.squeeze().cpu().numpy()

#### Basic synthesis (5 diffusion steps)

In [13]:
import IPython.display as ipd

# Synthesize with 5 diffusion steps and play the result inline.
noise = torch.randn(1,1,256).to(device)
wav = inference(text, noise, diffusion_steps=5, embedding_scale=1)

display(ipd.Audio(wav, rate=24000))

# Output file name per mode; modes 2 and 3 (and any other value) produce the interpolated audio.
output_names = {0: "target.wav", 1: "ground_truth.wav"}
name = "audio/" + output_names.get(adversarial_audio_generation, "ground_truth_interpolated.wav")

# soundfile does not create missing directories, so make sure the output folder exists.
os.makedirs("audio", exist_ok=True)
sf.write(name, wav, samplerate=24000)

1. List of phonemes:  ['ðæt ɪz ɐ bɐnˈænə. ']
2. String of phonemes:  ['ðæt', 'ɪz', 'ɐ', 'bɐnˈænə', '.']
3. Final string of phonemes:  ðæt ɪz ɐ bɐnˈænə .
4. ID of phonemes:  [81, 72, 62, 16, 102, 68, 16, 70, 16, 44, 70, 56, 156, 72, 56, 83, 16, 4]
5. ID with leading 0:  [0, 81, 72, 62, 16, 102, 68, 16, 70, 16, 44, 70, 56, 156, 72, 56, 83, 16, 4]
6. Pytorch Tensor Dimension:  torch.Size([1, 19])


Traceback (most recent call last):
  File "C:\Program Files\JetBrains\PyCharm 2025.2.4\plugins\python-ce\helpers\pydev\_pydevd_bundle\pydevd_comm.py", line 736, in make_thread_stack_str
    append('file="%s" line="%s">' % (make_valid_xml_value(my_file), lineno))
  File "C:\Program Files\JetBrains\PyCharm 2025.2.4\plugins\python-ce\helpers\pydev\_pydevd_bundle\pydevd_xml.py", line 36, in make_valid_xml_value
    return s.replace("&", "&amp;").replace('<', '&lt;').replace('>', '&gt;').replace('"', '&quot;')
AttributeError: 'tuple' object has no attribute 'replace'


KeyboardInterrupt: 

#### With higher diffusion steps (more diverse)
Since the sampler is ancestral, using more diffusion steps yields more diverse samples, at the cost of slower synthesis.

In [12]:
# Time one synthesis with 10 diffusion steps. perf_counter() is a monotonic, high-resolution
# clock intended for measuring intervals (time.time() can jump when the system clock changes).
start = time.perf_counter()
noise = torch.randn(1,1,256).to(device)
wav = inference(text, noise, diffusion_steps=10, embedding_scale=1)
# Real-time factor: processing time divided by the duration of the generated audio (24000 Hz).
rtf = (time.perf_counter() - start) / (len(wav) / 24000)
print(f"RTF = {rtf:.5f}") # five decimals (the former `:5f` set a field width, not a precision)
import IPython.display as ipd
display(ipd.Audio(wav, rate=24000))

1. List of phonemes:  ['ðæt ɪz ɐ bɐnˈænə. ']
2. String of phonemes:  ['ðæt', 'ɪz', 'ɐ', 'bɐnˈænə', '.']
3. Final string of phonemes:  ðæt ɪz ɐ bɐnˈænə .
4. ID of phonemes:  [81, 72, 62, 16, 102, 68, 16, 70, 16, 44, 70, 56, 156, 72, 56, 83, 16, 4]
5. ID with leading 0:  [0, 81, 72, 62, 16, 102, 68, 16, 70, 16, 44, 70, 56, 156, 72, 56, 83, 16, 4]
6. Pytorch Tensor Dimension:  torch.Size([1, 19])
RTF = 0.088002
