In [6]:
import torch
import phonemizer

from models.mixer_tts import MixerTTSModel
from vocoder.vocos.pretrained import Vocos
from utils.lj_dataset import symbols_to_id

import IPython

# Shared English (US) espeak phonemizer, configured to keep punctuation and emit stress marks.
global_phonemizer_en = phonemizer.backend.EspeakBackend(
    language='en-us',
    preserve_punctuation=True,
    with_stress=True,
)
def text_to_ids_en(text: str):
    """Convert English text into a 1-D LongTensor of phoneme-symbol ids.

    The text is phonemized with `global_phonemizer_en`, padded with one space on
    each side (presumably as a boundary token — confirm against training data),
    and every character is mapped through `symbols_to_id`. Characters that have
    no entry in `symbols_to_id` are silently dropped.
    """
    ipa = global_phonemizer_en.phonemize([text])[0]
    padded = ' ' + ipa + ' '
    ids = [symbols_to_id[ch] for ch in padded if ch in symbols_to_id]
    return torch.LongTensor(ids)

def pitch_trf(mul: float = 1, add: float = 0):
    """Return a pitch-transform callback computing `mul * pitch_pred + add`.

    The returned callable keeps the `(pitch_pred, enc_mask_sum, mean, std)`
    signature passed as `pitch_transform=` to `model.infer`; only `pitch_pred`
    is used, the remaining arguments are accepted and ignored.
    """
    def _pitch_trf(pitch_pred, enc_mask_sum, mean, std):
        scaled = mul * pitch_pred
        shifted = scaled + add
        return shifted

    return _pitch_trf

# Run on the GPU when one is visible to torch, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = 'cuda'
else:
    device = 'cpu'

In [11]:
# Pretrained Vocos mel vocoders, keyed by output sample rate:
#   22.05kHz: https://huggingface.co/BSC-LT/vocos-mel-22khz
#   44.1kHz:  https://huggingface.co/patriotyk/vocos-mel-hifigan-compat-44100khz
VOCOS_REPOS = {
    22050: "BSC-LT/vocos-mel-22khz",
    44100: "patriotyk/vocos-mel-hifigan-compat-44100khz",
}

# Selected vocoder rate; it is also the playback rate used when rendering audio.
sample_rate = [22050, 44100][1]

# Fail loudly here instead of leaving `vocos` undefined for a later cell.
if sample_rate not in VOCOS_REPOS:
    raise ValueError(
        f"unsupported sample_rate {sample_rate}; expected one of {sorted(VOCOS_REPOS)}"
    )

vocos = Vocos.from_pretrained(VOCOS_REPOS[sample_rate]).to(device)


In [12]:
# Model width to load; expects ./pretrained/mixer_lj_{dim}.pth to exist.
dim = (384, 128, 80)[2]
ckpt_path = f'./pretrained/mixer_lj_{dim}.pth'

# weights_only=True restricts unpickling to tensors/plain containers.
ckpt = torch.load(ckpt_path, map_location='cpu', weights_only=True)

# Build the network from the config stored in the checkpoint, then restore its weights.
model = MixerTTSModel(**ckpt['net_config'])
model = model.to(device)
model.load_state_dict(ckpt['model'])
model.eval();

n_params = sum(param.numel() for param in model.parameters())
print(f'loaded Mixer-TTS dim: {dim} nparams: {n_params:,.0f}')


loaded Mixer-TTS dim: 80 nparams: 1,742,803


In [13]:
# Example sentences to synthesize; choose one with `sentence_idx`.
sentences = [
    "This paper describes Mixer-TTS, a non-autoregressive model for mel-spectrogram generation.",
    "The model is based on the MLP-Mixer architecture adapted for speech synthesis.",
    "The basic Mixer-TTS contains pitch and duration predictors, with the latter being trained with an unsupervised TTS alignment framework."
    ]
sentence_idx = 0

phonemes = text_to_ids_en(sentences[sentence_idx])
phonemes_len = torch.LongTensor([len(phonemes)])

# Pure inference: disable autograd so no graph is built (model.eval() alone does
# not do this) and the resulting tensors can be converted for playback.
with torch.inference_mode():
    # NOTE(review): a previous commented-out unpack suggested infer returned
    # (mel_out, dec_lens, dur_pred, pitch_pred, energy_pred); here the result is
    # used directly as the mel tensor — confirm against MixerTTSModel.infer.
    mel_spec = model.infer(phonemes[None, :].to(device),
                           phonemes_len.to(device),
                           pace=1,
                           pitch_transform=pitch_trf(mul=1, add=0),
                           )

    # The vocoder is fed the mel with its last two axes swapped
    # (assumes mel_spec is (batch, frames, n_mels) — TODO confirm).
    wave = vocos.decode_mel(mel_spec.transpose(1, 2), denoise=0.003)

    # Peak-normalize to [-1, 1]; the clamp avoids dividing by zero on silent output.
    wave = wave / wave.abs().max().clamp_min(1e-8)

IPython.display.Audio(data=0.7 * wave.cpu(), rate=sample_rate, normalize=False)