In [1]:
import sys

# Add the parent directory to the import path (presumably where the `pflow`
# package lives — confirm against the repo layout). The membership check keeps
# the cell idempotent: re-running it no longer appends duplicate entries.
if '..' not in sys.path:
    sys.path.append('..')

In [2]:
import datetime as dt
from pathlib import Path

import IPython.display as ipd
import numpy as np
import soundfile as sf
import torch
from tqdm.auto import tqdm

from pflow.models.pflow_tts import pflowTTS
from pflow.text import sequence_to_text, text_to_sequence
from pflow.utils.model import denormalize
from pflow.utils.utils import get_user_data_dir, intersperse

from pflow.hifigan.config import v4
from pflow.hifigan.denoiser import Denoiser
from pflow.hifigan.env import AttrDict
from pflow.hifigan.models import Generator as HiFiGAN

In [3]:
# Prefer the GPU when one is available, otherwise fall back to the CPU so the
# notebook still runs (slowly) on machines without CUDA. Kept as a plain string
# because later cells pass it to `map_location=` / `.to()` / `device=`.
device = "cuda" if torch.cuda.is_available() else "cpu"

In [4]:
def load_model(checkpoint_path):
    """Restore a pflowTTS model from a checkpoint and put it in eval mode.

    Args:
        checkpoint_path: Path of the checkpoint handed to
            ``pflowTTS.load_from_checkpoint``.

    Returns:
        The loaded model, with weights mapped onto the global ``device`` and
        ``eval()`` already applied.
    """
    tts = pflowTTS.load_from_checkpoint(
        checkpoint_path,
        map_location=device,
    )
    # Inference only: switch off training-time behaviour before returning.
    tts.eval()
    return tts

# Load the trained acoustic model once. `model` is read as a global by
# `synthesise`, and its `mel_mean` / `mel_std` are used to normalise the prompt.
# NOTE(review): this is a machine-specific absolute path — adjust for your setup.
model = load_model("/root/pflowtts_pytorch/logs/train/en_au_dean2zak/runs/2024-03-28_06-22-33/checkpoints/last.ckpt")

In [5]:
def load_vocoder(checkpoint_path):
    """Build a HiFi-GAN generator (v4 config) and load its weights.

    Args:
        checkpoint_path: Path of a checkpoint readable by ``torch.load`` that
            holds the generator weights under the ``'generator'`` key.

    Returns:
        The generator on the global ``device``, in eval mode, with weight
        normalisation removed.
    """
    generator = HiFiGAN(AttrDict(v4)).to(device)
    # NOTE(review): torch.load unpickles the file — only use trusted checkpoints.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    generator.load_state_dict(checkpoint['generator'])
    generator.eval()
    generator.remove_weight_norm()
    return generator

# Instantiate the vocoder and a Denoiser built from it. `vocoder` is passed
# explicitly to `to_waveform` in the synthesis loop; `denoiser` is read there
# as a global.
# NOTE(review): this is a machine-specific absolute path — adjust for your setup.
vocoder = load_vocoder("/root/hifi-gan/cp_hifigan/g_02400000")
denoiser = Denoiser(vocoder, mode='zeros')



Removing weight norm...


In [6]:
# Inference settings. They are read as globals by `synthesise`, so re-run this
# cell before synthesising if you change any of them.
n_timesteps = 10     # number of ODE solver steps
length_scale = 1.0   # changes the speaking rate
temperature = 0.667  # sampling temperature

In [7]:
@torch.inference_mode()
def process_text(text: str):
    """Turn raw text into the tensors consumed by the acoustic model.

    The text is cleaned with ``english_cleaners3``, converted to a symbol-id
    sequence, interspersed with id ``0`` and given a leading batch dimension.

    Args:
        text: The sentence to synthesise.

    Returns:
        A dict with:
            ``x_orig``    – the input text, unchanged;
            ``x``         – long tensor of ids on ``device`` with a leading
                            batch dimension of 1;
            ``x_lengths`` – one-element long tensor holding ``x.shape[-1]``;
            ``x_phones``  – ``sequence_to_text`` applied to the ids of ``x``.
    """
    token_ids = intersperse(text_to_sequence(text, ['english_cleaners3']), 0)
    batch = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)
    lengths = torch.tensor([batch.shape[-1]], dtype=torch.long, device=device)
    phones = sequence_to_text(batch.squeeze(0).tolist())
    return dict(x_orig=text, x=batch, x_lengths=lengths, x_phones=phones)


@torch.inference_mode()
def synthesise(text, prompt):
    """Run the acoustic model on ``text`` conditioned on ``prompt``.

    Uses the global ``model`` together with the global inference settings
    ``n_timesteps``, ``temperature`` and ``length_scale``.

    Args:
        text: The sentence to synthesise.
        prompt: Prompt forwarded unchanged to ``model.synthesise``.

    Returns:
        The dict returned by ``model.synthesise``, extended in place with
        ``start_t`` (timestamp taken just before the model call) and every
        entry produced by ``process_text``.
    """
    inputs = process_text(text)
    # Timestamp taken after text processing, right before the model runs.
    started_at = dt.datetime.now()
    result = model.synthesise(
        inputs['x'],
        inputs['x_lengths'],
        n_timesteps=n_timesteps,
        temperature=temperature,
        length_scale=length_scale,
        prompt=prompt,
    )
    # Fold the timing and text-side entries into the model's output dict.
    result['start_t'] = started_at
    result.update(inputs)
    return result

@torch.inference_mode()
def to_waveform(mel, vocoder, denoiser_strength=0.00025):
    """Vocode a mel-spectrogram and denoise the resulting audio.

    Args:
        mel: Mel-spectrogram passed straight to ``vocoder``
            (presumably shaped (batch, n_mels, frames) — confirm).
        vocoder: Callable mapping ``mel`` to a waveform tensor.
        denoiser_strength: Strength forwarded to the global ``denoiser``.
            Defaults to the previously hard-coded ``0.00025``.

    Returns:
        The denoised waveform as a CPU tensor with all singleton dimensions
        squeezed out.
    """
    # Clamp to the valid audio range before denoising.
    audio = vocoder(mel).clamp(-1, 1)
    audio = denoiser(audio.squeeze(0), strength=denoiser_strength)
    # One transfer and one squeeze are enough; the original repeated both.
    return audio.cpu().squeeze()

In [8]:
from pflow.data.text_mel_datamodule import mel_spectrogram
import torchaudio
import glob
# Reference recordings for the speaker prompt; the first file (sorted by path)
# is used.
wav_files = sorted(glob.glob("/root/pflowtts_pytorch/en_au_dean2zak/audio/*.wav")) ## fill in the path to your prompt audio folder
if not wav_files:
    # Fail with an actionable message instead of an opaque IndexError below.
    raise FileNotFoundError(
        "No .wav files matched /root/pflowtts_pytorch/en_au_dean2zak/audio/*.wav; "
        "point the glob at a folder containing prompt audio."
    )
wav, sr = torchaudio.load(wav_files[0])

# The mel settings below and the playback rate used later in the notebook both
# assume 44100 Hz audio, so bring the prompt to that rate if the file differs.
if sr != 44100:
    wav = torchaudio.functional.resample(wav, sr, 44100)
    sr = 44100

# Positional arguments after `wav` are presumably n_fft, n_mels, sample rate,
# hop size, window size, f_min and f_max — confirm against `mel_spectrogram`.
mel = mel_spectrogram(
    wav,
    2048,
    80,
    44100,
    512,
    2048,
    0,
    11025,
    center=False,
)

In [12]:
# Move the prompt mel-spectrogram onto the inference device. The synthesis
# cell crops and normalises it before handing it to the model.
# NOTE(review): the execution counts show this cell was run out of order;
# verify the notebook with Restart & Run All.
prompt = mel.to(device)

In [10]:
# Sentences to synthesise; the synthesis loop processes each entry in turn.
texts = [
    # "Hello world, how are you doing?",
    "The quick brown fox jumps over the lazy dog, while the phoneme sounds of pheasants, quails and crickets chirp in the background."
]

In [13]:
from pflow.utils.model import normalize

# `outputs` is initialised here but not populated by this cell.
outputs, rtfs = [], []
rtfs_w = []

# Crop and normalise the prompt ONCE, into a new name. The original did this
# inside the loop and wrote the result back to `prompt`, so a second text (or a
# re-run of this cell) normalised an already-normalised prompt.
# 264 is the crop length along the last axis (presumably mel frames — confirm).
prompt_norm = normalize(prompt[:, :, :264], model.mel_mean, model.mel_std)

for i, text in enumerate(tqdm(texts)):
    output = synthesise(text, prompt_norm) #, torch.tensor([15], device=device, dtype=torch.long).unsqueeze(0))
    output['waveform'] = to_waveform(output['mel'], vocoder)

    # Compute Real Time Factor (RTF) with HiFi-GAN.
    # `start_t` is taken inside `synthesise` just before the acoustic model
    # runs, so this span covers acoustic model + vocoder + denoiser.
    t = (dt.datetime.now() - output['start_t']).total_seconds()
    rtf_w = t * 44_100 / (output['waveform'].shape[-1])

    ## Pretty print
    print(f"{'*' * 53}")
    print(f"Input text - {i}")
    print(f"{'-' * 53}")
    print(output['x_orig'])
    print(f"{'*' * 53}")
    print(f"Phonetised text - {i}")
    print(f"{'-' * 53}")
    print(output['x_phones'])
    print(f"{'*' * 53}")
    # `rtf` is the figure reported by `model.synthesise` itself.
    print(f"RTF:\t\t{output['rtf']:.6f}")
    print(f"RTF Waveform:\t{rtf_w:.6f}")
    rtfs.append(output['rtf'])
    rtfs_w.append(rtf_w)

    # Display the synthesised waveform
    ipd.display(ipd.Audio(output['waveform'], rate=44_100))

    ## Save the generated waveform
#     save_to_folder(i, output, OUTPUT_FOLDER)

print(f"Number of ODE steps: {n_timesteps}")
print(f"Mean RTF:\t\t\t\t{np.mean(rtfs):.6f} ± {np.std(rtfs):.6f}")
print(f"Mean RTF Waveform (incl. vocoder):\t{np.mean(rtfs_w):.6f} ± {np.std(rtfs_w):.6f}")

  0%|          | 0/1 [00:00<?, ?it/s]

*****************************************************
Input text - 0
-----------------------------------------------------
The quick brown fox jumps over the lazy dog, while the phoneme sounds of pheasants, quails and crickets chirp in the background.
*****************************************************
Phonetised text - 0
-----------------------------------------------------
_ð_ə_ _k_w_ˈ_ɪ_k_ _b_ɹ_ˈ_a_ʊ_n_ _f_ˈ_ɑ_k_s_ _d_͡_ʒ_ˈ_ʌ_m_p_s_ _ˈ_o_ʊ_v_ɚ_ _ð_ə_ _l_ˈ_e_ɪ_z_i_ _d_ˈ_ɔ_ɡ_,_ _w_ˈ_a_ɪ_l_ _ð_ə_ _f_ˈ_o_ʊ_n_i_m_ _s_ˈ_a_ʊ_n_d_z_ _ə_v_ _f_ˈ_ɛ_z_ə_n_t_s_,_ _k_w_ˈ_e_ɪ_l_z_ _ˈ_æ_n_d_ _k_ɹ_ˈ_ɪ_k_ɪ_t_s_ _t_͡_ʃ_ˈ_ɚ_p_ _ˈ_ɪ_n_ _ð_ə_ _b_ˈ_æ_k_ɡ_ɹ_ˌ_a_ʊ_n_d_._
*****************************************************
RTF:		0.014966
RTF Waveform:	0.023574


Number of ODE steps: 10
Mean RTF:				0.014966 ± 0.000000
Mean RTF Waveform (incl. vocoder):	0.023574 ± 0.000000
