# Package installations

In [None]:
!apt-get update -y
!apt-get install ffmpeg -y

In [None]:
!pip install -r requirements.txt 

In [None]:
!pip install git+https://github.com/neel04/NeMo.git@dev#egg=nemo_toolkit[tts]

# Setup

These are optional debugging/runtime environment flags. You can turn them on when benchmarking performance (`PJRT_DEVICE=TPU` selects the TPU runtime):

```py
%env TORCH_WARN=1
%env PT_XLA_DEBUG_LEVEL=1
%env PJRT_DEVICE=TPU
%env TORCHDYNAMO_VERBOSE=0
```

In [None]:
import torch.nn.functional as F
import soundfile as sf

from nemo.collections.tts.models import FastPitchModel
from nemo.collections.tts.models import HifiGanModel

In [None]:
import logging
# Show INFO-level messages from the root logger, but raise NeMo's logger
# ("nemo_logger") to ERROR so its output does not flood the notebook.
logging.basicConfig(level = logging.INFO)
logging.getLogger("nemo_logger").setLevel(logging.ERROR)

In [None]:
import string
import random
import warnings
import time

import torch
import torch_xla
import torchvision
import torch_xla.debug.metrics as met
import torch_xla.core.xla_model as xm

from tqdm import tqdm
from typing import Optional
from IPython.display import Audio
from dataclasses import dataclass
from contextlib import ContextDecorator

# XLA device handle (the TPU in this notebook) that models and inputs are moved to.
device = torch_xla.device()

# Keep eager mode off so PyTorch/XLA records ops lazily; pending work is
# flushed explicitly (e.g. by xm.mark_step() in benchmark_xla).
torch_xla.experimental.eager_mode(False)

# Model definition

In [None]:
# Start every run from an empty on-disk compilation cache.
# NOTE(review): deleting the directory discards programs compiled in earlier
# sessions; remove this line to reuse the cache across runs.
!rm -rf ./compilation_cache/
import torch_xla.runtime as xr
# Persist compiled XLA programs under ./compilation_cache (writable).
# torch_xla expects this to be called before any XLA computation executes.
xr.initialize_cache('./compilation_cache', readonly=False)

In [None]:
# Pretrained FastPitch (text -> spectrogram) and HiFi-GAN vocoder
# (spectrogram -> waveform), both moved to the XLA device.
spec_generator = FastPitchModel.from_pretrained("nvidia/tts_en_fastpitch").to(device)
model = HifiGanModel.from_pretrained(model_name="nvidia/tts_hifigan").to(device)

Next, we convert the model weights to fp16/bf16. This utility function also keeps certain weights that are better in full precision in FP32 (parameters whose name contains `bn`).

In [None]:
def convert_to_dtype(model, dtype=torch.float16, keep_fp32=('bn',)):
    """Cast ``model`` to ``dtype`` while keeping selected parameters in FP32.

    Args:
        model: A ``torch.nn.Module``. ``Module.to`` converts it in place; it is
            also returned for convenience.
        dtype: Target dtype for the model (e.g. ``torch.float16`` or
            ``torch.bfloat16``).
        keep_fp32: Iterable of substrings. Any parameter whose qualified name
            contains one of them is cast back to FP32. The default ``('bn',)``
            targets norm layers. Pass a tuple/list, not a bare string; a bare
            string would be iterated character by character.

    Returns:
        The converted module.
    """
    model = model.to(dtype)

    # Keep norm layers (or any other matching parameters) in FP32.
    for name, param in model.named_parameters():
        if any(pattern in name for pattern in keep_fp32):
            param.data = param.data.float()
            # Parameters of a freshly loaded model have no gradient yet;
            # calling ``.float()`` on ``None`` would raise AttributeError.
            if param.grad is not None:
                param.grad = param.grad.float()

    return model

# Convert both models with convert_to_dtype's default dtype (FP16).
model = convert_to_dtype(model)
spec_generator = convert_to_dtype(spec_generator)

Setting the models to `eval` (inference) mode:

In [None]:
# Make sure both models live on the XLA device (presumably already true from
# the .to(device) calls at load time) before switching them to eval mode.
spec_generator = spec_generator.to(device)
model = model.to(device)

# Inference mode for both models (disables training-only behaviour).
spec_generator.eval()
model.eval()

## Helper functions

Here, we expose some functions for debugging and optimization. Depending on the architecture, use case and performance, one may need to compile only parts of the model, so we expose some utilities that we modify for better performance on TPU hardware.

In [None]:
@torch.compile(backend="openxla", fullgraph=True)
def log_to_duration(log_dur, min_dur, max_dur, mask):
    dur = torch.clamp(torch.exp(log_dur) - 1.0, min_dur, max_dur)
    dur *= mask.squeeze(2)
    return dur
    
@torch.compile(backend='openxla', fullgraph=True)
def compiled_device_transfer(x, device):
    return x.to(device)
    
def tensor_to_int(x: torch.Tensor) -> int:
    '''
    Bring a single-element tensor to the CPU and return it as a Python int.

    The value is cast to int64 first, so fractional values are truncated
    toward zero. Raises if ``x`` holds more than one element (``.item()``).
    '''
    as_long = x.detach().cpu().long()
    return as_long.item()

This is a simple context manager (also usable as a decorator) for easy benchmarking of PyTorch/XLA code.

In [None]:
from typing import Any


@dataclass
class BenchmarkResult:
    """Timing (and optional payload) recorded by ``benchmark_xla``."""

    time_ms: float  # elapsed wall-clock time in milliseconds, rounded to 3 decimals
    time_sec: float  # the same duration in seconds, rounded to 3 decimals
    output: Optional[Any] = None  # whatever the caller assigned to ``benchmark_xla.output``


class benchmark_xla(ContextDecorator):
    """Context manager / decorator that times a block of PyTorch/XLA code.

    On exit, pending XLA work is flushed with ``xm.mark_step()`` and
    ``xm.wait_device_ops()`` waits for the device before the clock stops, so
    the measurement covers the device work issued inside the block.

    Usage::

        with benchmark_xla("forward") as bench:
            bench.output = run()
        print(bench.result.time_ms)

    Args:
        name: Label used in the printed summary.
        silent: If True, do not print the summary.
    """

    def __init__(self, name: str = "Benchmark", silent: bool = False):
        self.name = name
        self.silent = silent
        self.result: Optional[BenchmarkResult] = None
        self.output: Any = None  # callers may assign here; copied into ``result``
        self.start_time: Optional[float] = None

    def __enter__(self):
        # perf_counter is monotonic and high-resolution; time.time() can jump
        # when the system clock is adjusted.
        self.start_time = time.perf_counter()
        return self

    def __exit__(self, *exc):
        # Flush lazily recorded ops and wait for the device so the timing
        # includes the actual execution, not just graph construction.
        xm.mark_step()
        xm.wait_device_ops()

        time_taken = time.perf_counter() - self.start_time

        self.result = BenchmarkResult(
            time_ms=round(time_taken * 1000, 3),
            time_sec=round(time_taken, 3),
            output=self.output,
        )

        if not self.silent:
            print(f"{self.name} - Time taken: {self.result.time_ms} ms ({self.result.time_sec} seconds)")

        return False  # never suppress exceptions raised inside the block

## Forward pass

This function isn't `JIT`-compiled on its own (it is traced as part of `fp_prediction`), but it does use a static maximum length. Make sure outputs aren't clipped by that static length: a maximum length of `768-1024` frames should be enough for under 10 s of audio, which suits most use cases.

In [None]:
Fastpitch = spec_generator.fastpitch # convenience alias for the underlying FastPitch module
# Fixed number of decoder frames static_regulate_len produces (static shape for XLA).
STATIC_MAX_LENGTH: int = 768
# Fixed width parse_text pads (or truncates) every token sequence to.
PARSED_PADLEN: int = 128

In [None]:
def static_regulate_len(
    durations,
    enc_out,
    pace: float = 1.0,
    group_size: int = 1,
    dur_lens: torch.Tensor | None = None,
    max_allowed_len: int = STATIC_MAX_LENGTH,
):
    """XLA-friendly length regulator with a fixed output length.

    Repeats every token of ``enc_out`` ``round(duration / pace)`` times, like
    FastPitch's ``regulate_len``. Instead of sizing the output by the longest
    sequence (a data-dependent shape), the output always has
    ``max_allowed_len`` frames, so the traced graph has a static shape.

    Args:
        durations: Per-token durations, presumably ``(batch, tokens)``. TODO confirm.
        enc_out: Encoder output. It is left-multiplied by a
            ``(batch, max_allowed_len, tokens)`` selection matrix, so its
            token axis must be second-to-last.
        pace: Speaking-rate divisor applied to the durations.
        group_size: Only ``1`` is supported. Other values raise instead of
            being silently ignored.
        dur_lens: Accepted for signature compatibility; unused.
        max_allowed_len: Static number of output frames. Frames past it are
            dropped, but the returned lengths are NOT clipped.

    Returns:
        ``(enc_rep, dec_lens, dec_lens)``:
        - ``enc_rep``: the expanded encoder output, with frames at or beyond
          each sequence's length zeroed.
        - ``dec_lens``: per-sequence sums of the rounded durations, returned
          twice.

    Raises:
        ValueError: if ``group_size != 1``.
    """
    if group_size != 1:
        raise ValueError(f"static_regulate_len only supports group_size=1, got {group_size}")

    dtype = enc_out.dtype
    device = enc_out.device

    # Round each (pace-scaled) duration to a whole number of repeats.
    reps = (durations.float() / pace + 0.5).floor()
    dec_lens = reps.sum(dim=1)

    # Cumulative frame boundaries: token j covers frames [cumsum[j], cumsum[j + 1]).
    # NOTE(review): in float16, cumulative sums above 2048 are no longer exact.
    padded_reps = torch.nn.functional.pad(reps, (1, 0, 0, 0), value=0.0)
    reps_cumsum = torch.cumsum(padded_reps, dim=1)[:, None, :]
    reps_cumsum = reps_cumsum.to(dtype=dtype, device=device)

    # Fixed-size frame index instead of arange(dec_lens.max()), keeping the
    # output shape static for XLA.
    static_range = torch.arange(max_allowed_len, device=device)[None, :, None]

    # mult[b, f, j] == 1 exactly when output frame f is produced by token j.
    mult = (reps_cumsum[:, :, :-1] <= static_range) & (reps_cumsum[:, :, 1:] > static_range)
    mult = mult.to(dtype)

    enc_rep = torch.matmul(mult, enc_out)

    # Zero every frame at or beyond each sequence's decoded length.
    length_mask = torch.arange(enc_rep.shape[1], device=enc_rep.device)[None, :] < dec_lens[:, None]
    enc_rep = enc_rep * length_mask.unsqueeze(-1)

    return enc_rep, dec_lens, dec_lens

`fp_prediction` is the fused inference function: it consumes tokenized text (`parsed: Tensor`) and returns the raw (padded) audio waveform.

We also propagate a `real_len` output (a tensor of per-sequence decoded lengths), which can be used to work out how much of the raw audio waveform to slice off to remove padding.

In [None]:
@torch.compile(backend='openxla', fullgraph=True)
def fp_prediction(parsed: torch.Tensor, conditioning=None):
    """Fused FastPitch + HiFi-GAN forward pass, compiled with the openxla backend.

    Args:
        parsed: Token ids, presumably the ``(batch, PARSED_PADLEN)`` tensor
            returned by ``parse_text``. TODO confirm.
        conditioning: Passed unchanged to each FastPitch sub-module call below.

    Returns:
        ``(audio, real_len, durs_predicted)``:
        - ``audio``: padded waveform from ``model.convert_spectrogram_to_audio``,
          moved to the CPU.
        - ``real_len``: per-sequence decoded frame counts from
          ``static_regulate_len``. It is NOT moved, so it stays on the XLA device.
        - ``durs_predicted``: predicted per-token durations, moved to the CPU.
    """
    with torch.no_grad():
        parsed = parsed.to(device)     
        enc_out, enc_mask = Fastpitch.encoder(input=parsed, conditioning=conditioning)
        # NOTE(review): FP16 is hard-coded to match convert_to_dtype's default;
        # change these casts if the models are converted to another dtype.
        enc_out, enc_mask = enc_out.to(torch.float16), enc_mask.to(torch.float16)
        
        # Duration prediction
        log_durs_predicted = Fastpitch.duration_predictor(enc_out, enc_mask, conditioning=conditioning)
        durs_predicted = log_to_duration(
            log_dur=log_durs_predicted, 
            min_dur=Fastpitch.min_token_duration, 
            max_dur=Fastpitch.max_token_duration, 
            mask=enc_mask
        )
        
        # Pitch prediction
        pitch_predicted = Fastpitch.pitch_predictor(enc_out, enc_mask, conditioning=conditioning)
        pitch_emb = Fastpitch.pitch_emb(pitch_predicted.unsqueeze(1))
        
        # Combine encoder output and pitch embedding
        enc_out = enc_out + pitch_emb.transpose(1, 2)
        
        # Expand each token to its predicted number of frames inside a fixed
        # STATIC_MAX_LENGTH-frame tensor (static shape for XLA).
        len_regulated, dec_lens, real_len = static_regulate_len(durs_predicted, enc_out, pace=1.0)
        dec_out, _ = Fastpitch.decoder(input=len_regulated, seq_lens=dec_lens, conditioning=conditioning)
        dec_out = dec_out.to(torch.float16)
        
        spect = Fastpitch.proj(dec_out).transpose(1, 2) # Obtain spectrogram

        # Vocode the spectrogram into a waveform with HiFi-GAN.
        audio = model.convert_spectrogram_to_audio(spec=spect)

        # Only the waveform and the durations are copied back to the host.
        audio, durs_predicted = audio.to('cpu'), durs_predicted.to('cpu')

        return audio, real_len, durs_predicted

In [None]:
def parse_text(text: str | list[str], padlen: int = PARSED_PADLEN) -> tuple[torch.Tensor, torch.Tensor]:
    """Tokenize text on the CPU into a fixed-width batch of token ids.

    Args:
        text: A single string (treated as a batch of one) or a list of strings.
        padlen: Output width. Shorter sequences are right-padded with 0; longer
            ones are truncated to ``padlen`` tokens and a warning is issued.

    Returns:
        ``(tokens, pre_pad_lens)``:
        - ``tokens``: a ``(batch, padlen)`` int64 CPU tensor of token ids.
        - ``pre_pad_lens``: a float tensor holding each sequence's un-padded
          token count. This can exceed ``padlen`` when the input was truncated.

    Note:
        Calls ``torch.set_default_device('cpu')``, which changes the default
        device globally.
    """
    torch.set_default_device('cpu')

    # Normalize to a list BEFORE sizing anything: len() of a bare string would
    # count characters rather than sequences.
    texts = [text] if isinstance(text, str) else text

    pre_pad_lens = torch.zeros(len(texts))
    out = torch.zeros(len(texts), padlen, dtype=torch.long)  # preallocate the output

    for index, t in enumerate(texts):
        parsed = spec_generator.parse(t).cpu()
        pre_pad_len = parsed.shape[1]

        if pre_pad_len > padlen:
            # warnings (not print) so callers can filter it, e.g. during warmup.
            warnings.warn(
                f'Input of {pre_pad_len} tokens exceeds the padding length of {padlen} '
                'and will be truncated. Recommend doubling the padding length.'
            )
            parsed = parsed[:, :padlen]  # explicit truncation

        out[index, :] = F.pad(parsed, (0, padlen - parsed.shape[1]), value=0).long()
        pre_pad_lens[index] = pre_pad_len

    return out, pre_pad_lens

# Inference

`infer_e2e` is the exposed function. It consumes raw text (`text: list[str]`) on the CPU, performs the model forward pass on the TPU, and returns the raw (padded) `audio: Tensor` waveform on the CPU. The benchmarking numbers therefore account for data movement, which is often the bottleneck for latency-critical applications.

In [None]:
def infer_e2e(
    text: str | list[str] | None,
    parsed: torch.Tensor | None = None,
) -> tuple[torch.Tensor, torch.Tensor | None, torch.Tensor, torch.Tensor]:
    '''
    Text-to-waveform inference: tokenize on the CPU, then run ``fp_prediction``.

    This function can optionally consume a directly parsed tensor
    for debugging purposes, in which case ``text`` is ignored.

    Args:
        text: A list of strings. A single string is treated as a batch of one.
        parsed: Optional pre-tokenized ids; when given, ``parse_text`` is skipped.

    Returns:
        ``(audio, pre_pad_len, real_len, durs_predicted)``. ``pre_pad_len``
        holds the un-padded token counts from ``parse_text``, or is ``None``
        when ``parsed`` was supplied directly.

    Raises:
        TypeError: if ``parsed`` is None and ``text`` is neither a str nor a list.
    '''
    pre_pad_len = None

    if parsed is None:
        if isinstance(text, str):
            text = [text]
        # An explicit exception rather than ``assert``, which is stripped under -O.
        if not isinstance(text, list):
            raise TypeError(f'Invalid input type. Got: {type(text)} expected a list')
        parsed, pre_pad_len = parse_text(text)

    audio, real_len, durs_predicted = fp_prediction(parsed.to(device))

    return audio, pre_pad_len, real_len, durs_predicted

## Warmup

We do a **compilation warmup** here to reduce the chances of compilation cache misses. Re-compilations take roughly $800\text{–}1200\ \text{ms}$, so we want to avoid them as much as possible. 
Note that there may still be some recompilations (which would be obvious from the latency), but with a good enough warmup plus continued usage, those issues should be alleviated as the cache fills up.

The cache resides in the directory `./compilation_cache` and should be cleared periodically.

In [None]:
def generate_random_string(length):
    """Return ``length`` random characters drawn from letters, digits and punctuation."""
    alphabet = string.ascii_letters + string.digits + string.punctuation
    picked = [random.choice(alphabet) for _ in range(length)]
    return "".join(picked)

def warmup_parse_text(max_length: int = 128):
    """Run ``parse_text`` on random strings of length ``0 .. max_length - 1``.

    This is a best-effort warmup. Inputs that raise are skipped instead of
    aborting the loop, and the number skipped is reported rather than being
    silently swallowed. Warnings raised during the loop are suppressed.

    Returns:
        The number of inputs that raised and were skipped.
    """
    skipped = 0
    with warnings.catch_warnings():
        warnings.filterwarnings("ignore")
        for length in tqdm(range(max_length)):
            text = generate_random_string(length)
            try:
                parse_text(text)
            except Exception:  # best-effort: random strings may not be parseable
                skipped += 1

    if skipped:
        print(f'warmup_parse_text: skipped {skipped}/{max_length} inputs that failed to parse.')
    return skipped

_warmup_skipped = warmup_parse_text()

In [None]:
# Four sample sentences used to warm up the batched inference path.
text_batch = [
    "The bees decided to have a mutiny against their queen, But they forgot to collect the honey.",
    "When he had to picnic on the beach, he purposely put sand in other people’s food.",
    "The gruff old man sat in the back of the bait shop grumbling to himself as he scooped out a handful of worms.",
    "The Tsunami wave crashed against the raised houses and broke the pilings as if they were toothpicks."
]

In [None]:
# Warmup 1: random token ids, bypassing parse_text. pre_pad_len is None here,
# so the `[:, None:]` slice below keeps every position.
# NOTE(review): assumes token ids below 500 are valid for the FastPitch
# tokenizer; confirm against its vocabulary size.
_parsed = torch.randint(0, 500, (1, PARSED_PADLEN)).long().to(device)
raw_audio, pre_pad_len, real_len, durs_predicted = infer_e2e(text = None, parsed = _parsed)
# Valid sample count: decoded frames minus frames predicted for padding tokens,
# times 256 (presumably the vocoder hop length in samples; confirm).
pred_noise_duration = (real_len - durs_predicted[:, pre_pad_len:].sum()) * 256

# Warmup 2: a single sentence through the full text path.
with benchmark_xla('Warmup benchmark') as e2e_bench:
    text = "This is the ritual to lead you on; your friends would meet you when your gone."
    raw_audio, pre_pad_len, real_len, durs_predicted = infer_e2e(text = [text])
    pred_noise_duration = (real_len - durs_predicted[:, pre_pad_len.long():].sum()) * 256


# Warmup 3: the batched path. Build one valid-sample count per batch item,
# iterating over the actual batch instead of a hard-coded 4.
with benchmark_xla('Warmup Benchmark #2') as e2e_bench:
    raw_audio, pre_pad_len, real_len, durs_predicted = infer_e2e(text_batch)

    pred_noise_duration = torch.tensor([
        (real_len[i] - durs_predicted[i, pre_pad_len[i].long():].sum()) * 256
        for i in range(len(text_batch))
    ], device=durs_predicted.device)

## End-to-end inference

This is the `text` we wish to convert to audio:

In [None]:
# Single sentence used for the end-to-end benchmark below.
text = "The employee wanted to request his employers for a raise, but was unable to do so because he feared his immediate expulsion."

In [None]:
with benchmark_xla('End-to-end forward pass') as e2e_bench:
    raw_audio, pre_pad_len, real_len, durs_predicted = infer_e2e([text])
    # Valid (non-padding) sample count: decoded frames minus frames predicted
    # for padding tokens, times 256 (presumably the vocoder hop length).
    pred_noise_duration = (real_len - durs_predicted[:, pre_pad_len.long():].sum()) * 256

One aspect we wish to draw attention to is that parsing the text itself (performed by `parse_text`) is often the bottleneck, taking up the majority ($50\text{–}60\%$) of the total time. The actual forward pass therefore takes only a fraction (roughly $40\text{–}50\%$) of the time reported above. With further optimizations (e.g. moving the parsing into a more performant systems language like Rust or C++), we can reduce the latency even more if required.

In [None]:
# Time tokenization alone for the same single sentence.
with benchmark_xla('Text Parsing') as text_parsing:
    out = parse_text(text)

Here, we provide a utility to convert this raw audio waveform into an IPython-embedded widget for easy playback in the browser:

In [None]:
# Trim the padded waveform to the valid sample count and play it back.
# NOTE(review): 22050 Hz is assumed to be the vocoder's sample rate; confirm.
audio = raw_audio.cpu().detach().numpy()[:, :tensor_to_int(pred_noise_duration)]
Audio(audio, rate=22050)

## Batched Inference

We also support batched inference with arbitrary batch sizes:

In [None]:
# Batch of sentences for the batched-inference benchmark.
text = [
    "No matter how beautiful the sunset, it saddened her knowing she was one day older.",
    "As time wore on, simple dog commands turned into full paragraphs explaining why the dog couldn’t do something.",
    "She found it strange that people use their cellphones to actually talk to one another.",
    "His ultimate dream fantasy consisted of being content and sleeping eight hours in a row.",
]

In [None]:
with benchmark_xla('End-to-end forward pass') as e2e_bench:
    raw_audio, pre_pad_len, real_len, durs_predicted = infer_e2e(text)

    # Valid (non-padding) sample count per batch item. Iterate over the actual
    # batch rather than a hard-coded 4 so any batch size works.
    pred_noise_duration = torch.tensor([
        (real_len[i] - durs_predicted[i, pre_pad_len[i].long():].sum()) * 256
        for i in range(len(text))
    ], device=durs_predicted.device)

Again, we measure how much time text parsing takes in the batched case:

In [None]:
# Time tokenization of the whole batch alone.
with benchmark_xla('Text Parsing') as text_parsing:
    out = parse_text(text)

In [None]:
index: int = 0 # select the index of the audio that would be played 

# Trim the selected item's padding and play it.
# NOTE(review): 22050 Hz is assumed to be the vocoder's sample rate; confirm.
audio = raw_audio.cpu().detach().numpy()[index, :tensor_to_int(pred_noise_duration[index])]
Audio(audio, rate=22050)

# Performance metrics

In [None]:
SAMPLE_RATE = 22050  # the rate used for audio playback in this notebook

# Measure the audio actually produced (sum of valid samples per batch item)
# instead of assuming a fixed 8 s per sentence.
AUDIO_LENGTH_SECONDS = tensor_to_int(pred_noise_duration.sum()) / SAMPLE_RATE

time_taken = e2e_bench.result.time_ms
e2e_parsing_time_taken = text_parsing.result.time_ms
rtfx = (AUDIO_LENGTH_SECONDS / time_taken) * 1000

# Parsing was timed in a separate run, so this is only an estimate. Guard
# against a non-positive remainder when parsing dominates the measurement.
compute_time_taken = time_taken - e2e_parsing_time_taken
if compute_time_taken > 0:
    rtfx_without_parsing = (AUDIO_LENGTH_SECONDS / compute_time_taken) * 1000
else:
    rtfx_without_parsing = float('nan')

print("\n" + "="*50)
print(f"  📊 Performance Metrics | Batch size: {len(text)}")
print("="*50)
print(f" 🎧  Audio length     : {AUDIO_LENGTH_SECONDS:>8.2f} s")
print(f" ⏱️  Total Time       : {time_taken:>8.2f} ms")
print(f" 🚀  RTFx             : {rtfx:>8.2f}x")
print(f" 🔥  RTFx (no parse)  : {rtfx_without_parsing:>8.2f}x")
print("="*50 + "\n")