# Export StyleTTS2 Model

See the [Colab Notebook](https://colab.research.google.com/github/yl4579/StyleTTS2/blob/main/Colab/StyleTTS2_Demo_LibriTTS.ipynb).

Currently, the official implementation of StyleTTS2 does not support ONNX export. See the [discussion](https://github.com/yl4579/StyleTTS2/issues/117).

## Install dependencies

In [None]:
!apt-get install espeak-ng --yes

In [None]:
!pip install SoundFile pydub pyyaml librosa nltk matplotlib phonemizer einops einops-exts tqdm typing-extensions

In [None]:
!apt-get install build-essential libssl-dev libffi-dev python3-dev --yes

In [None]:
!sudo apt install libpython3.10-dev --yes

In [None]:
!pip install git+https://github.com/resemble-ai/monotonic_align.git

In [None]:
!pip install transformers

In [None]:
!pip install onnx onnxscript onnxruntime

## Set up demo assets

In [None]:
!git clone https://huggingface.co/yl4579/StyleTTS2-LibriTTS

In [None]:
!mv StyleTTS2-LibriTTS/Models .

In [None]:
!mv StyleTTS2-LibriTTS/reference_audio.zip .

In [None]:
!unzip reference_audio.zip

In [None]:
!mv reference_audio Demo/reference_audio

In [None]:
import nltk

# `word_tokenize` needs the Punkt tokenizer data. Older NLTK releases look it
# up as "punkt", recent ones as "punkt_tab"; since nltk is installed unpinned
# above, fetch both so tokenization works whichever version pip resolved.
# An unknown resource name is only reported by `nltk.download`, not raised.
for nltk_resource in ('punkt', 'punkt_tab'):
    nltk.download(nltk_resource)

## Import packages

In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
# %matplotlib inline

In [None]:
import torch

torch.__version__

In [None]:
device = 'cuda' if torch.cuda.is_available() else 'cpu'
device

In [None]:
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio

torch.manual_seed(0)
torch.backends.cudnn.benchmark = False
torch.backends.cudnn.deterministic = True

import random
random.seed(0)

import numpy as np
np.random.seed(0)

# load packages
import yaml
import librosa
from nltk.tokenize import word_tokenize
import IPython.display as ipd

from models import load_ASR_models, load_F0_models
from text_utils import TextCleaner
textclenaer = TextCleaner()


def load_audio(path, sr=24000, top_db=30):
    """Load an audio file at a fixed sample rate and trim silent edges.

    Args:
        path: Location of the audio file, passed straight to ``librosa.load``.
        sr: Target sample rate in Hz (24000 by default, the rate used by the
            rest of this notebook when playing audio back).
        top_db: Threshold in dB handed to ``librosa.effects.trim``.

    Returns:
        The trimmed waveform returned by ``librosa.effects.trim``.
    """
    # librosa.load already resamples to `sr` while decoding, so the old
    # `if sr != 24000: librosa.resample(...)` follow-up could never run (and
    # its positional call style is rejected by recent librosa releases).
    wave, _ = librosa.load(path, sr=sr)
    audio, _ = librosa.effects.trim(wave, top_db=top_db)
    return audio

In [None]:
from nltk.tokenize import word_tokenize

# Maps a phonemized string to the list of token ids fed to the model.
from text_utils import TextCleaner
textclenaer = TextCleaner()


# load phonemizer (espeak backend, US English, keeping punctuation and stress marks)
import phonemizer
global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True,  with_stress=True)


# Porting notes for the C# front end:
# - start with out-of-the-box C# implementations of the word tokenizer and phonemizer
# - check how their results differ from the Python implementations
# - if there are differences, test how they affect the ONNX model's output
# - if results are good, keep the default C# implementations; otherwise implement our own functions

# - rewrite the text cleaner from a Python function into C#

## Set up StyleTTS2 model

In [None]:
# load phonemizer
import phonemizer
global_phonemizer = phonemizer.backend.EspeakBackend(language='en-us', preserve_punctuation=True, with_stress=True)

# Parse the model configuration. The context manager closes the file handle,
# which the previous `yaml.safe_load(open(...))` left open.
with open("Models/LibriTTS/config.yml") as config_file:
    config = yaml.safe_load(config_file)

# load pretrained ASR model (used as the text aligner)
ASR_config = config.get('ASR_config', False)
ASR_path = config.get('ASR_path', False)
text_aligner = load_ASR_models(ASR_path, ASR_config)

# load pretrained F0 model (pitch extractor)
F0_path = config.get('F0_path', False)
pitch_extractor = load_F0_models(F0_path)

# load BERT model
from Utils.PLBERT.util import load_plbert
BERT_path = config.get('PLBERT_dir', False)
plbert = load_plbert(BERT_path)

In [None]:
from models import (
    TextEncoder, ProsodyPredictor, StyleEncoder, StyleTransformer1d, 
    AudioDiffusionConditional, KDiffusion, LogNormalDistribution,
    MultiPeriodDiscriminator, MultiResSpecDiscriminator, WavLMDiscriminator)

def build_model(args, text_aligner, pitch_extractor, bert):
    """Instantiate every StyleTTS2 sub-network described by ``args``.

    Args:
        args: The ``model_params`` mapping from the config file. Keys read
            here: ``decoder``, ``hidden_dim``, ``style_dim``, ``n_mels``,
            ``n_layer``, ``n_token``, ``max_dur``, ``dropout``, ``dim_in``,
            ``multispeaker``, ``diffusion`` and ``slm``.
        text_aligner: Pretrained ASR model; stored unchanged in the result.
        pitch_extractor: Pretrained F0 model; stored unchanged in the result.
        bert: Pretrained PL-BERT model. Its ``config.hidden_size`` and
            ``config.max_position_embeddings`` size the BERT projection and
            the diffusion model.

    Returns:
        A dict mapping component names (``bert``, ``bert_encoder``,
        ``predictor``, ``decoder``, ``text_encoder``, ``predictor_encoder``,
        ``style_encoder``, ``diffusion``, ``text_aligner``,
        ``pitch_extractor``, ``mpd``, ``msd``, ``wd``) to modules.

    Raises:
        AssertionError: If ``args["decoder"]["type"]`` is neither
            ``'istftnet'`` nor ``'hifigan'``.
    """
    decoder_args = args["decoder"]
    assert decoder_args["type"] in ['istftnet', 'hifigan'], 'Decoder type unknown'

    # Both decoder flavours share these constructor arguments; the iSTFTNet
    # one additionally needs its inverse-STFT settings.
    decoder_kwargs = dict(
        dim_in=args["hidden_dim"], style_dim=args["style_dim"], dim_out=args["n_mels"],
        resblock_kernel_sizes=decoder_args["resblock_kernel_sizes"],
        upsample_rates=decoder_args["upsample_rates"],
        upsample_initial_channel=decoder_args["upsample_initial_channel"],
        resblock_dilation_sizes=decoder_args["resblock_dilation_sizes"],
        upsample_kernel_sizes=decoder_args["upsample_kernel_sizes"],
    )
    if decoder_args["type"] == "istftnet":
        from Modules.istftnet import Decoder
        decoder_kwargs.update(
            gen_istft_n_fft=decoder_args["gen_istft_n_fft"],
            gen_istft_hop_size=decoder_args["gen_istft_hop_size"],
        )
    else:
        from Modules.hifigan import Decoder
    decoder = Decoder(**decoder_kwargs)

    text_encoder = TextEncoder(channels=args["hidden_dim"], kernel_size=5, depth=args["n_layer"], n_symbols=args["n_token"])

    predictor = ProsodyPredictor(
        style_dim=args["style_dim"],
        d_hid=args["hidden_dim"], nlayers=args["n_layer"],
        max_dur=args["max_dur"], dropout=args["dropout"])

    style_encoder = StyleEncoder(
        dim_in=args["dim_in"], style_dim=args["style_dim"], max_conv_dim=args["hidden_dim"])  # acoustic style encoder
    predictor_encoder = StyleEncoder(
        dim_in=args["dim_in"], style_dim=args["style_dim"], max_conv_dim=args["hidden_dim"])  # prosodic style encoder

    # define diffusion model
    if args["multispeaker"]:
        transformer = StyleTransformer1d(channels=args["style_dim"]*2,
                                    context_embedding_features=bert.config.hidden_size,
                                    context_features=args["style_dim"]*2,
                                    **args["diffusion"]["transformer"])
    else:
        # Transformer1d is not part of the notebook-level `from models import ...`,
        # so this branch used to fail with NameError; import it where it is needed.
        # NOTE(review): assumes `models` exposes Transformer1d next to
        # StyleTransformer1d - confirm against the installed StyleTTS2 sources.
        from models import Transformer1d
        transformer = Transformer1d(channels=args["style_dim"]*2,
                                    context_embedding_features=bert.config.hidden_size,
                                    **args["diffusion"]["transformer"])

    diffusion = AudioDiffusionConditional(
        in_channels=1,
        embedding_max_length=bert.config.max_position_embeddings,
        embedding_features=bert.config.hidden_size,
        embedding_mask_proba=args["diffusion"]["embedding_mask_proba"],  # Conditional dropout of batch elements,
        channels=args["style_dim"]*2,
        context_features=args["style_dim"]*2,
    )

    diffusion.diffusion = KDiffusion(
        net=diffusion.unet,
        sigma_distribution=LogNormalDistribution(mean=args["diffusion"]["dist"]["mean"], std=args["diffusion"]["dist"]["std"]),
        sigma_data=args["diffusion"]["dist"]["sigma_data"],  # a placeholder, will be changed dynamically when start training diffusion model
        dynamic_threshold=0.0
    )
    # Swap the default denoising network for the transformer built above.
    diffusion.diffusion.net = transformer
    diffusion.unet = transformer

    nets = dict(
        bert=bert,
        bert_encoder=nn.Linear(bert.config.hidden_size, args["hidden_dim"]),

        predictor=predictor,
        decoder=decoder,
        text_encoder=text_encoder,

        predictor_encoder=predictor_encoder,
        style_encoder=style_encoder,
        diffusion=diffusion,

        text_aligner=text_aligner,
        pitch_extractor=pitch_extractor,

        mpd=MultiPeriodDiscriminator(),
        msd=MultiResSpecDiscriminator(),

        # slm discriminator head
        wd=WavLMDiscriminator(args["slm"]["hidden"], args["slm"]["nlayers"], args["slm"]["initial_channel"]),
    )

    return nets

In [None]:
from collections import OrderedDict

model_params = config['model_params']
nets = build_model(model_params, text_aligner, pitch_extractor, plbert)

# Inference only: switch every component to eval mode and move it to the device.
for net in nets.values():
    net.eval()
    net.to(device)

params_whole = torch.load("Models/LibriTTS/epochs_2nd_00020.pth", map_location='cpu')
params = params_whole['net']

for key, net in nets.items():
    if key not in params:
        continue
    state_dict = params[key]
    try:
        net.load_state_dict(state_dict)
    except RuntimeError:
        # load_state_dict reports mismatched keys as RuntimeError. Retry with
        # a leading `module.` stripped from each key; keys without that prefix
        # are kept as they are instead of losing their first 7 characters.
        stripped_state_dict = OrderedDict(
            (name.removeprefix('module.'), tensor) for name, tensor in state_dict.items()
        )
        # best effort, as before: tolerate missing/unexpected keys on the retry
        net.load_state_dict(stripped_state_dict, strict=False)
    print('%s loaded' % key)

for net in nets.values():
    net.eval()

In [None]:
from Modules.diffusion.sampler import DiffusionSampler, ADPM2Sampler, KarrasSchedule

# Build the diffusion sampler on top of the loaded diffusion network.
denoising_model = nets["diffusion"].diffusion
step_sampler = ADPM2Sampler()
noise_schedule = KarrasSchedule(sigma_min=0.0001, sigma_max=3.0, rho=9.0)  # empirical parameters

sampler = DiffusionSampler(
    denoising_model,
    sampler=step_sampler,
    sigma_schedule=noise_schedule,
    clamp=False,
)

In [None]:
def length_to_mask(lengths):
    """Return a boolean padding mask for a batch of sequence lengths.

    Args:
        lengths: 1-D integer tensor holding one length per batch item.

    Returns:
        Bool tensor of shape ``(len(lengths), lengths.max())`` that is True
        at every position lying at or beyond the item's length.
    """
    batch_size = lengths.shape[0]
    longest = lengths.max()
    # 0-based positions for every row, cast to the dtype/device of `lengths`
    positions = torch.arange(longest).unsqueeze(0).expand(batch_size, -1)
    positions = positions.type_as(lengths)
    # position p is padding when p + 1 exceeds the row's length
    return torch.gt(positions + 1, lengths.unsqueeze(1))


class Model(nn.Module):
    """Single-module StyleTTS2 inference pipeline, arranged for ONNX export.

    Bundles the text encoder, BERT, prosody predictor, decoder and the
    diffusion sampler so that one ``forward(tokens, style_ref)`` call goes
    from token ids plus a reference style vector to a waveform.
    """

    def __init__(self, nets: dict, model_params: dict, sampler):
        """Keep references to the sub-networks used at inference time.

        Args:
            nets: Component dict; the keys ``text_encoder``, ``bert``,
                ``bert_encoder``, ``style_encoder``, ``predictor_encoder``,
                ``predictor`` and ``decoder`` are read.
            model_params: Model configuration; ``["decoder"]["type"]`` is
                consulted in ``forward``.
            sampler: Callable diffusion sampler producing the style vector.
        """
        super().__init__()
        self.model_params = model_params
        self.text_encoder = nets["text_encoder"]
        self.bert = nets["bert"]
        self.bert_encoder = nets["bert_encoder"]
        self.style_encoder = nets["style_encoder"]
        self.predictor_encoder = nets["predictor_encoder"]
        self.predictor = nets["predictor"]
        self.decoder = nets["decoder"]
        self.sampler = sampler

    @staticmethod
    def _build_alignment(pred_dur, dtype):
        """Expand per-token durations into a (tokens, frames) 0/1 alignment matrix.

        Row ``i`` is 1 on the ``pred_dur[i]`` consecutive frames that follow
        the frames of all earlier tokens, and 0 elsewhere. Built with tensor
        ops only (no Python loop over tokens), so the number of tokens is not
        frozen into the graph when the module is traced for export.
        """
        pred_dur = pred_dur.reshape(-1)  # also covers a single token (0-dim after squeeze)
        frame_ends = torch.cumsum(pred_dur, dim=0)
        frame_starts = frame_ends - pred_dur
        frame_ids = torch.arange(pred_dur.sum(), device=pred_dur.device).unsqueeze(0)
        inside = (frame_ids >= frame_starts.unsqueeze(1)) & (frame_ids < frame_ends.unsqueeze(1))
        return inside.to(dtype)

    @staticmethod
    def _shift_right_one_frame(features):
        """Delay ``features`` by one step along the last axis, repeating frame 0."""
        return torch.cat([features[:, :, :1], features[:, :, :-1]], dim=2)

    def forward(self, tokens, style_ref):  # ref_tokens
        """Synthesize a waveform for ``tokens`` in the style of ``style_ref``.

        Args:
            tokens: Integer tensor of token ids; its last dimension is taken
                as the sequence length (the notebook passes shape ``(1, T)``).
            style_ref: Reference style tensor; columns ``[:128]`` are blended
                into the decoder style and columns ``[128:]`` into the
                predictor style.

        Returns:
            The squeezed decoder output with its last 50 samples removed.
        """
        device = tokens.device
        alpha = 0.3  # weight of the sampled style in the decoder-style blend
        beta = 0.7  # weight of the sampled style in the predictor-style blend
        diffusion_steps = 5
        embedding_scale = 1
        is_hifigan = self.model_params["decoder"]["type"] == "hifigan"

        input_lengths = torch.LongTensor([tokens.shape[-1]]).to(device)
        text_mask = length_to_mask(input_lengths).to(device)

        t_en = self.text_encoder(tokens, input_lengths, text_mask)
        bert_dur = self.bert(tokens, attention_mask=(~text_mask).int())
        d_en = self.bert_encoder(bert_dur).transpose(-1, -2)

        # Sample a style vector conditioned on the BERT embedding and the reference style.
        s_pred = self.sampler(
            noise=torch.randn((1, 256)).unsqueeze(1).to(device),
            embedding=bert_dur,
            embedding_scale=embedding_scale,
            features=style_ref,  # reference from the same speaker as the embedding
            num_steps=diffusion_steps,
        ).squeeze(1)

        s = s_pred[:, 128:]
        ref = s_pred[:, :128]

        # Blend the sampled style with the reference style.
        ref = alpha * ref + (1 - alpha) * style_ref[:, :128]
        s = beta * s + (1 - beta) * style_ref[:, 128:]

        d = self.predictor.text_encoder(d_en, s, input_lengths, text_mask)

        x, _ = self.predictor.lstm(d)
        duration = self.predictor.duration_proj(x)

        duration = torch.sigmoid(duration).sum(axis=-1)
        # at least one frame per token
        pred_dur = torch.round(duration.squeeze()).clamp(min=1).to(torch.int64)

        pred_aln_trg = self._build_alignment(pred_dur, d.dtype).unsqueeze(0)

        # encode prosody
        en = d.transpose(-1, -2) @ pred_aln_trg
        if is_hifigan:
            en = self._shift_right_one_frame(en)

        F0_pred, N_pred = self.predictor.F0Ntrain(en, s)

        asr = t_en @ pred_aln_trg
        if is_hifigan:
            asr = self._shift_right_one_frame(asr)

        out = self.decoder(asr, F0_pred, N_pred, ref.squeeze().unsqueeze(0))

        return out.squeeze()[..., :-50]  # weird pulse at the end of the model, need to be fixed later

model = Model(nets, model_params, sampler)

In [None]:
class StyleModel(nn.Module):
    """Turn a reference waveform into the style vector consumed by the TTS model.

    The waveform is converted to a normalized log-mel spectrogram, which is
    fed to both style encoders; their outputs are concatenated.
    """

    def __init__(self, nets: dict):
        """Create the mel front end and pick the two style encoders out of ``nets``."""
        super().__init__()
        mel_settings = dict(n_mels=80, n_fft=2048, win_length=1200, hop_length=300)
        self.to_mel = torchaudio.transforms.MelSpectrogram(**mel_settings)
        self.style_encoder = nets["style_encoder"]
        self.predictor_encoder = nets["predictor_encoder"]

    def preprocess(self, wave):
        """Return the normalized log-mel spectrogram of ``wave`` with a new leading axis."""
        log_mel_mean = -4
        log_mel_std = 4
        mel = self.to_mel(wave).unsqueeze(0)
        log_mel = torch.log(1e-5 + mel)  # small offset keeps log() finite on silence
        return (log_mel - log_mel_mean) / log_mel_std

    def forward(self, audio):
        """Encode ``audio`` with both encoders and join the results along dim 1."""
        mel = self.preprocess(audio)
        encoders = (self.style_encoder, self.predictor_encoder)
        # each encoder receives the spectrogram with an extra axis inserted at dim 1
        styles = [encoder(mel.unsqueeze(1)) for encoder in encoders]
        return torch.cat(styles, dim=1)


style_model = StyleModel(nets).to(device)


def compute_style(path):
    """Compute the reference style vector for the audio file at ``path``.

    Args:
        path: Audio file to take the style from; read with ``load_audio``.

    Returns:
        The tensor produced by ``style_model`` for that audio, computed
        without gradient tracking.
    """
    # The previous version never read `path` and relied on a global `audio`
    # that is not defined anywhere in the notebook.
    audio = load_audio(path)
    with torch.no_grad():
        audio_tensor = torch.tensor(audio).to(device)
        style_ref = style_model(audio_tensor)
    return style_ref

## Style Transfer Demo

In [None]:
# Reference recording whose style is transferred to the synthesized speech.
path = "Demo/reference_audio/1221-135767-0014.wav"
style_ref = compute_style(path)

# reference texts to sample styles (currently disabled)
# ref_texts = {}
# ref_texts['Happy'] = "We are happy to invite you to join us on a journey to the past, where we will visit the most amazing monuments ever built by human hands."
# ref_texts['Sad'] = "I am sorry to say that we have suffered a severe setback in our efforts to restore prosperity and confidence."
# ref_texts['Angry'] = "The field of astronomy is a joke! Its theories are based on flawed observations and biased interpretations!"
# ref_texts['Surprised'] = "I can't believe it! You mean to tell me that you have discovered a new species of bacteria in this pond?"

# Text front end: raw text -> phonemes -> space-joined word tokens -> token ids.
text = "Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech."
text = text.strip()
ps = global_phonemizer.phonemize([text])
ps = word_tokenize(ps[0])
ps = ' '.join(ps)
tokens = textclenaer(ps)
# Prepend token id 0 (presumably the pad/start symbol - confirm against TextCleaner).
tokens.insert(0, 0)
# Add a leading batch axis: the model receives shape (1, number_of_tokens).
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

# Run the PyTorch model without tracking gradients and bring the audio back as a NumPy array.
with torch.no_grad():
    wav = model(tokens, style_ref).cpu().numpy()

# Play back at 24 kHz without letting the widget rescale the amplitude.
display(ipd.Audio(wav, rate=24000, normalize=False))

# for style, ref_text in ref_texts.items():
#     # wav = STinference(text, s_ref, v, diffusion_steps=10, alpha=0.5, beta=0.9, embedding_scale=1.5)

#     ref_text = ref_text.strip()
#     ps = global_phonemizer.phonemize([ref_text])
#     ps = word_tokenize(ps[0])
#     ps = ' '.join(ps)
    
#     ref_tokens = textclenaer(ps)
#     ref_tokens.insert(0, 0)
#     ref_tokens = torch.LongTensor(ref_tokens).to(device).unsqueeze(0)

#     with torch.no_grad():
#         wav = model(tokens, ref_tokens, s_ref).cpu().numpy()
    
#     print(f"{style}: ")
#     display(ipd.Audio(wav, rate=24000, normalize=False))

## ONNX Export

In [None]:
model.eval()

# Export is inference-only: make sure no parameter asks for gradients.
for param in model.parameters():
    param.requires_grad = False

# `tokens` and `style_ref` from the demo cell above serve as example inputs.
torch.onnx.export(
    model,
    args=(tokens, style_ref),
    f="StyleTTS2.onnx",
    verbose=False,
    opset_version=16,
    training=torch.onnx.TrainingMode.EVAL,
    input_names=["tokens", "style_ref"],
    output_names=["output"],
    dynamic_axes={
        # number of input tokens varies with the text
        "tokens": {1: "seq_length"},
        # the waveform length follows the predicted durations, so it must not
        # be recorded as the fixed length of the example output
        "output": {0: "num_samples"},
    },
)

## Test ONNX model

In [None]:
import onnx

# Structural validation of the exported graph; raises if the model is malformed.
onnx_model = onnx.load("StyleTTS2.onnx")
onnx.checker.check_model(onnx_model)


import onnxruntime as ort

# Name the execution provider explicitly: onnxruntime builds that ship extra
# providers (e.g. the GPU package) refuse to create a session without it,
# while the CPU-only build behaves exactly as before.
ort_session = ort.InferenceSession('StyleTTS2.onnx', providers=['CPUExecutionProvider'])

In [None]:
# Re-run the demo sentence through the exported ONNX model for comparison
# with the PyTorch output above.
path = "Demo/reference_audio/1221-135767-0014.wav"
style_ref = compute_style(path)

# Same text front end as the PyTorch demo: text -> phonemes -> token ids.
text = "Yea, his honourable worship is within, but he hath a godly minister or two with him, and likewise a leech."
text = text.strip()
ps = global_phonemizer.phonemize([text])
ps = word_tokenize(ps[0])
ps = ' '.join(ps)
tokens = textclenaer(ps)
tokens.insert(0, 0)
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

# onnxruntime takes NumPy arrays keyed by the input names chosen at export time.
outputs = ort_session.run(None, {
    'tokens': tokens.cpu().numpy(), 
    "style_ref": style_ref.cpu().numpy(),
})
# Passing None requests all outputs; the graph's only output is the waveform.
wav = outputs[0]

display(ipd.Audio(wav, rate=24000, normalize=False))

In [None]:
# Second ONNX run with a different sentence than the one used as the example
# input during export, i.e. a different number of tokens.
path = "Demo/reference_audio/1221-135767-0014.wav"
style_ref = compute_style(path)

# Same text front end as before: text -> phonemes -> token ids.
text = "Hello hello, this is test. How are you doing?"
text = text.strip()
ps = global_phonemizer.phonemize([text])
ps = word_tokenize(ps[0])
ps = ' '.join(ps)
tokens = textclenaer(ps)
tokens.insert(0, 0)
tokens = torch.LongTensor(tokens).to(device).unsqueeze(0)

# onnxruntime takes NumPy arrays keyed by the input names chosen at export time.
outputs = ort_session.run(None, {
    'tokens': tokens.cpu().numpy(), 
    "style_ref": style_ref.cpu().numpy(),
})
wav = outputs[0]

display(ipd.Audio(wav, rate=24000, normalize=False))