In [None]:
import datetime as dt
import random
from pathlib import Path

import IPython.display as ipd
import numpy as np
import soundfile as sf
import torch
from tqdm import tqdm
from vocos import Vocos

# Matcha imports
from matcha.models.matcha_tts import MatchaTTS
from matcha.text import sequence_to_text, text_to_sequence
from matcha.utils.model import denormalize
from matcha.utils.utils import get_user_data_dir, intersperse

# Prefer the CUDA GPU when PyTorch can see one; otherwise run everything on CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

## Load helpers for Matcha-TTS: text front-end and Vocos vocoder

In [None]:
@torch.inference_mode()
def process_text(text: str, cleaners=("english_cleaners2",), target_device=None):
    """Turn raw text into the batched symbol-ID input used by ``model.synthesise``.

    Args:
        text: Input sentence.
        cleaners: Names of the ``matcha.text`` cleaners to apply. Defaults to
            ``english_cleaners2``, the cleaner this notebook always used. A single
            string is accepted and treated as one cleaner name.
        target_device: Device for the returned tensors. ``None`` (default) uses
            the notebook-level ``device``.

    Returns:
        dict with keys
            'x_orig'    - the input text, unchanged;
            'x'         - LongTensor of shape (1, T): the symbol IDs from
                          ``text_to_sequence`` interspersed with 0;
            'x_lengths' - LongTensor of shape (1,) holding T;
            'x_phones'  - ``x`` decoded back to text with ``sequence_to_text``,
                          handy for inspecting the phonemisation.
    """
    if target_device is None:
        target_device = device
    if isinstance(cleaners, str):
        # Guard against list("english_cleaners2") splitting into characters.
        cleaners = [cleaners]

    # text_to_sequence returns a tuple whose first item is the ID sequence.
    symbol_ids = text_to_sequence(text, list(cleaners))[0]
    x = torch.tensor(intersperse(symbol_ids, 0), dtype=torch.long, device=target_device)[None]
    x_lengths = torch.tensor([x.shape[-1]], dtype=torch.long, device=target_device)
    x_phones = sequence_to_text(x.squeeze(0).tolist())
    return {
        'x_orig': text,
        'x': x,
        'x_lengths': x_lengths,
        'x_phones': x_phones,
    }

# Pretrained Vocos mel vocoder; every playback/save below uses 24 kHz audio.
# (`Vocos` is imported in the top import cell.) It stays on the CPU, which is
# why the synthesis loops hand it `output['mel'].cpu()`.
VOCOS_CHECKPOINT = "charactr/vocos-mel-24khz"
vocoder = Vocos.from_pretrained(VOCOS_CHECKPOINT)
@torch.inference_mode()
def to_waveform(mel, vocoder):
    """Decode a mel-spectrogram into a waveform.

    Args:
        mel: Mel-spectrogram tensor from ``model.synthesise`` (presumably
            (1, n_mels, frames) - TODO confirm). It may live on any device; it
            is moved to the vocoder's device before decoding.
        vocoder: Object exposing ``decode(mel)`` (e.g. ``Vocos``). When it is an
            ``nn.Module`` with parameters, their device decides where decoding runs.

    Returns:
        The decoded audio as a CPU tensor with singleton dimensions squeezed out.
    """
    params = getattr(vocoder, "parameters", None)
    first_param = next(iter(params()), None) if callable(params) else None
    if first_param is not None:
        # Prevent device-mismatch errors (e.g. GPU vocoder fed a CPU mel).
        mel = mel.to(first_param.device)
    audio = vocoder.decode(mel)
    return audio.cpu().squeeze()

## Synthesis

### LJ Speech (single speaker)

In [None]:
# Synthesise the LJ Speech test filelist with the single-speaker SFM checkpoint.
checkpoint_path = "./logs/MatchaTTS-sfm-ljspeech.ckpt"
filelist_path = "./data/LJSpeech/ljs_audio_text_test_filelist.txt"

model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)
model.eval()

# Sampling settings forwarded to `model.synthesise`.
n_timesteps = 10
length_scale = 1.0
temperature = 0.667
alpha = 2.5
solver = "euler"

# Run size / output. Defaults give a quick look: first utterance only, nothing
# written to disk. Set MAX_UTTERANCES = None to process the whole filelist and
# OUTPUT_DIR to a folder path to save PCM_24 wavs there (same-named files are
# overwritten; nothing else in the folder is deleted).
MAX_UTTERANCES = 1
OUTPUT_DIR = None
SAMPLE_RATE = 24000  # playback/save rate, matching the 24 kHz Vocos vocoder

with open(filelist_path, "r", encoding="utf-8") as f:
    test_data = f.readlines()
selected = test_data[:MAX_UTTERANCES]  # slicing with None keeps every line

if OUTPUT_DIR is not None:
    Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Reseed right before sampling so this cell is reproducible on its own.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

tps = 0.
sigma_ps = 0.
output = None
for d in tqdm(selected):
    # rstrip("\r\n") also handles CRLF filelists, which strip("\n") did not.
    fields = d.rstrip("\r\n").split("|")
    filename = fields[0].split("/")[-1]
    text = fields[-1]
    text_processed = process_text(text)
    output, tp, sigma_p = model.synthesise(
        text_processed['x'],
        text_processed['x_lengths'],
        n_timesteps=n_timesteps,
        temperature=temperature,
        alpha=alpha,
        solver=solver,
        spks=None,
        length_scale=length_scale,
    )
    output['waveform'] = to_waveform(output['mel'].cpu(), vocoder)
    # NOTE(review): assumes tp / sigma_p are scalars (as the original averaging did).
    tps += float(tp)
    sigma_ps += float(sigma_p)
    if OUTPUT_DIR is not None:
        sf.write(str(Path(OUTPUT_DIR) / filename), np.asarray(output['waveform']), SAMPLE_RATE, 'PCM_24')

if selected:
    # Average over the utterances actually synthesised, not the whole filelist.
    tps = round(tps / len(selected), 8)
    sigma_ps = round(sigma_ps / len(selected), 8)
    print(tps, sigma_ps)

if output is not None:
    ipd.display(ipd.Audio(output['waveform'], rate=SAMPLE_RATE))

### VCTK (multi-speaker; speaker ID read from the filelist)

In [None]:
# Synthesise the VCTK test filelist with the multi-speaker SFM checkpoint.
checkpoint_path = "./logs/MatchaTTS-sfm-vctk.ckpt"
filelist_path = "./data/VCTK/vctk_audio_sid_text_test_filelist.txt"

model = MatchaTTS.load_from_checkpoint(checkpoint_path, map_location=device)
model.eval()

# Sampling settings forwarded to `model.synthesise`.
n_timesteps = 10
length_scale = 1.0
temperature = 0.667
alpha = 3.5
solver = "euler"

# Run size / output. Defaults give a quick look: first utterance only, nothing
# written to disk. Set MAX_UTTERANCES = None to process the whole filelist and
# OUTPUT_DIR to a folder path to save PCM_24 wavs there (same-named files are
# overwritten; nothing else in the folder is deleted).
MAX_UTTERANCES = 1
OUTPUT_DIR = None
SAMPLE_RATE = 24000  # playback/save rate, matching the 24 kHz Vocos vocoder

with open(filelist_path, "r", encoding="utf-8") as f:
    test_data = f.readlines()
selected = test_data[:MAX_UTTERANCES]  # slicing with None keeps every line

if OUTPUT_DIR is not None:
    Path(OUTPUT_DIR).mkdir(parents=True, exist_ok=True)

# Reseed right before sampling so this cell is reproducible on its own.
SEED = 1234
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed(SEED)
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False

tps = 0.
sigma_ps = 0.
output = None
for d in tqdm(selected):
    # Rows are parsed as "<audio path>|<speaker id>|<text>"; rstrip("\r\n")
    # also handles CRLF filelists, which strip("\n") did not.
    fields = d.rstrip("\r\n").split("|")
    filename = fields[0].split("/")[-1]
    spk = int(fields[1])
    text = fields[-1]
    text_processed = process_text(text)
    output, tp, sigma_p = model.synthesise(
        text_processed['x'],
        text_processed['x_lengths'],
        n_timesteps=n_timesteps,
        temperature=temperature,
        alpha=alpha,
        solver=solver,
        spks=torch.tensor([spk], device=device, dtype=torch.long),
        length_scale=length_scale,
    )
    output['waveform'] = to_waveform(output['mel'].cpu(), vocoder)
    # NOTE(review): assumes tp / sigma_p are scalars (as the original averaging did).
    tps += float(tp)
    sigma_ps += float(sigma_p)
    if OUTPUT_DIR is not None:
        sf.write(str(Path(OUTPUT_DIR) / filename), np.asarray(output['waveform']), SAMPLE_RATE, 'PCM_24')

if selected:
    # Average over the utterances actually synthesised, not the whole filelist.
    tps = round(tps / len(selected), 8)
    sigma_ps = round(sigma_ps / len(selected), 8)
    print(tps, sigma_ps)

if output is not None:
    ipd.display(ipd.Audio(output['waveform'], rate=SAMPLE_RATE))