# TTS Inference

In [None]:

import torch

from models.glow_tts_with_pitch import GlowTTSModel
from utils.data import load_speaker_emb

from nemo.collections.tts.models import HifiGanModel

In [None]:
def infer(
    spec_gen_model,
    vocoder_model,
    str_input,
    noise_scale=0.0,
    length_scale=1.0,
    speaker=None,
    speaker_embeddings=None,
    stoch_dur_noise_scale=0.8,
    stoch_pitch_noise_scale=1.0,
    pitch_scale=0.0,
):
    """Synthesize one input string: text -> spectrogram -> waveform.

    Args:
        spec_gen_model: Object exposing ``parse`` and ``generate_spectrogram``
            (the GlowTTS model in this notebook).
        vocoder_model: Object exposing ``convert_spectrogram_to_audio``
            (the HiFi-GAN model in this notebook).
        str_input: Text to synthesize; tokenized with ``spec_gen_model.parse``.
        noise_scale, length_scale, speaker, speaker_embeddings,
        stoch_dur_noise_scale, stoch_pitch_noise_scale, pitch_scale:
            Forwarded unchanged as keyword arguments to
            ``spec_gen_model.generate_spectrogram``.

    Returns:
        ``(spectrogram, audio)``. Any result that is a ``torch.Tensor`` is
        moved to the CPU and converted to a NumPy array. A 3-D spectrogram is
        reduced to its first entry along the leading axis (presumably the
        batch axis). A ``None`` spectrogram is returned unchanged.
    """
    sampling_options = dict(
        noise_scale=noise_scale,
        length_scale=length_scale,
        speaker=speaker,
        speaker_embeddings=speaker_embeddings,
        stoch_dur_noise_scale=stoch_dur_noise_scale,
        stoch_pitch_noise_scale=stoch_pitch_noise_scale,
        pitch_scale=pitch_scale,
    )

    # Both model calls run with autograd disabled; only inference is needed.
    with torch.no_grad():
        token_ids = spec_gen_model.parse(str_input)
        spec = spec_gen_model.generate_spectrogram(tokens=token_ids, **sampling_options)
        wav = vocoder_model.convert_spectrogram_to_audio(spec=spec)

    if isinstance(spec, torch.Tensor):
        spec = spec.to("cpu").numpy()
    # Keep only the first entry of a 3-D spectrogram (leading axis dropped).
    if spec is not None and len(spec.shape) == 3:
        spec = spec[0]

    if isinstance(wav, torch.Tensor):
        wav = wav.to("cpu").numpy()

    return spec, wav

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(device)

In [None]:
# Load the pitch-conditioned GlowTTS spectrogram generator from a trained checkpoint.
# TODO: set this to the path of your trained GlowTTS checkpoint before running.
checkpoint_path = "path/to/glow_tts_with_pitch.ckpt"
spec_gen = GlowTTSModel.load_from_checkpoint(checkpoint_path=checkpoint_path)
# Switch to inference mode and move the weights to the selected device.
spec_gen = spec_gen.eval().to(device)

In [None]:
# Load the HiFi-GAN vocoder that turns spectrograms into waveforms.
# TODO: set this to the path of your trained HiFi-GAN checkpoint before running.
vocoder_checkpoint_path = "path/to/hifigan.ckpt"
vocoder = HifiGanModel.load_from_checkpoint(checkpoint_path=vocoder_checkpoint_path)
# Switch to inference mode and move the weights to the selected device.
vocoder = vocoder.eval().to(device)

In [None]:
# Load precomputed speaker embeddings used to condition the spectrogram generator.
# The result is looked up with `.get(...)` later on, so it is expected to be a
# mapping (presumably audio-file stem -> embedding; verify against load_speaker_emb).
# TODO: set this to the location of your speaker-embedding file before running.
spk_emb_path = "path/to/speaker_embeddings"
speaker_emb_dict = load_speaker_emb(spk_emb_path)

## Inference

Now that everything is set up, let's provide the inputs for synthesis: a reference speaker embedding and the text we want our models to speak.

In [None]:
# Extract speaker embedding from file
import os

audio_path = "common_voice_en_18498899.wav"
# Embeddings are keyed by the file name without its extension. splitext removes
# only the final extension; split(".")[0] would cut at the first dot and break
# paths like "./clip.wav" or "speaker.v2.wav".
audio_path_wo = os.path.splitext(audio_path)[0]

speaker_embeddings = speaker_emb_dict.get(audio_path_wo)
# Check for a missing entry *before* converting: torch.from_numpy(None) raises.
if speaker_embeddings is None:
    print("Could not load speaker embedding")
else:
    # Reshape to a single row, shape (1, N), and move it to the model's device.
    speaker_embeddings = torch.from_numpy(speaker_embeddings).reshape(1, -1).to(device)

## Run inference

In [None]:
# Inference hyperparameters, passed to `infer` and the audio player below.
sr = 16000                     # playback sample rate; should match the vocoder output (TODO confirm)
noise_scale = 0.667
length_scale = 1.0
stoch_dur_noise_scale = 0.8    # expected range 0.0-1.0
stoch_pitch_noise_scale = 0.8
pitch_scale = 0.0
speaker = None                 # no speaker id; conditioning comes from speaker_embeddings

In [None]:
from nemo_text_processing.text_normalization.normalize import Normalizer

# English text normalizer; the input is treated as cased text.
normalizer = Normalizer(
    lang="en",
    input_case="cased",
)

In [None]:
# Normalize before synthesis so that numerals, dates and abbreviations in the
# input are written out as words.
text_to_generate = normalizer.normalize(
    "A look of fear crossed his face, but he regained his serenity immediately."
)
print(text_to_generate)

In [None]:

# Run the full pipeline: text -> spectrogram (GlowTTS) -> waveform (HiFi-GAN).
log_spec, audio = infer(
    spec_gen,
    vocoder,
    text_to_generate,
    speaker=speaker,
    speaker_embeddings=speaker_embeddings,
    noise_scale=noise_scale,
    length_scale=length_scale,
    pitch_scale=pitch_scale,
    stoch_dur_noise_scale=stoch_dur_noise_scale,
    stoch_pitch_noise_scale=stoch_pitch_noise_scale,
)


In [None]:
# `ipd` was never imported anywhere in this notebook; bring it into scope here.
import IPython.display as ipd

# Render an inline audio player for the synthesized waveform at sample rate `sr`.
ipd.Audio(audio, rate=sr)