# 訓練済みVITSモデルによる推論

In [None]:
import numpy as np
import os
import torch
from TTS.tts.models.vits import Vits, VitsAudioConfig
from TTS.tts.configs.shared_configs import CharactersConfig
from TTS.tts.configs.vits_config import VitsConfig
from TTS.utils.audio import AudioProcessor
from TTS.tts.utils.text.tokenizer import TTSTokenizer

## モデルのロード

In [None]:
# Audio settings for the VITS model: 22.05 kHz sample rate, 1024-sample
# window, 256-sample hop and 80 mel bands with no upper mel frequency limit.
audio_config = VitsAudioConfig(
    sample_rate=22050,
    win_length=1024,
    hop_length=256,
    num_mels=80,
    mel_fmin=0,
    mel_fmax=None,
)

In [None]:
# Symbol inventory handed to the tokenizer. These look like OpenJTalk-style
# phoneme labels (e.g. "pau", "cl", "N") -- TODO confirm against the
# phonemizer's actual output. Order is preserved exactly because it presumably
# determines the symbol-to-ID mapping of the trained model.
phoneme_symbols = [
    "pau", "I", "N", "U",
    "a", "b", "by", "ch", "cl",
    "d", "dy", "e", "f", "g", "gy",
    "h", "hy", "i", "j", "k", "ky",
    "m", "my", "n", "ny", "o",
    "p", "py", "r", "ry",
    "s", "sh", "t", "ts",
    "u", "v", "w", "y", "z",
]

# Character configuration: VITS character class, explicit padding symbol,
# the symbol inventory above, and the punctuation marks to keep.
characters = CharactersConfig(
    characters_class="TTS.tts.models.vits.VitsCharacters",
    pad="<PAD>",
    characters=phoneme_symbols,
    punctuations=".?!",
)

In [None]:
# Model configuration for the JSUT VITS run: Japanese text cleaning,
# phoneme-based input via pyopenjtalk, blank tokens inserted between symbols
# (add_blank=True), and the character set defined earlier in this notebook.
# The audio settings declared in `audio_config` are passed explicitly so that
# inference uses them rather than whatever VitsConfig would default to.
config = VitsConfig(
    audio=audio_config,
    run_name="vits_jsut",
    text_cleaner="japanese_cleaners",
    use_phonemes=True,
    add_blank=True,
    phoneme_language="ja-jp",
    phonemizer="pyopenjtalk",
    characters=characters,
)

In [None]:
# Build the audio processor from the model configuration.
ap = AudioProcessor.init_from_config(config)

In [None]:
# Show the audio settings held by the configuration as the cell output.
config.audio

In [None]:
# Build the text tokenizer from the configuration. init_from_config also
# returns a config object, which replaces `config` from here on.
tokenizer, config = TTSTokenizer.init_from_config(config)

In [None]:
# Use the first CUDA GPU when one is available; otherwise fall back to the CPU
# so the notebook still runs on machines without a GPU.
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
# Instantiate the VITS model (no speaker manager is supplied) and move it to
# the selected device.
model = Vits(config, ap, tokenizer, speaker_manager=None).to(device)

In [None]:
# Relative path to the trained checkpoint file to load.
checkpoint_path = "../recipes/jsut/vits_tts/vits_jsut-November-22-2022_03+14PM-0705a45a/checkpoint_230000.pth"
# Load the checkpoint into the model; eval=True is passed so the model is
# prepared for inference (presumably switching it to evaluation mode).
model.load_checkpoint(config, checkpoint_path, eval=True)

## テキストからの推論

In [None]:
# Japanese sentence to synthesize.
raw_text = "ごめんなさいね、昨日は娘が突然お世話になったみたいで。"
# Convert the text to token IDs with the tokenizer.
token_ids = tokenizer.text_to_ids(raw_text)
# Build the int64 tensor directly on the target device (no float32
# intermediate, no separate host-to-device copy) and add a leading batch
# dimension so the model receives a batch of one sequence.
token_ids = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)
token_ids

In [None]:
# Run inference on the batched token IDs. The result is a mapping that the
# following cells index by string key (e.g. "model_outputs", "alignments",
# "durations").
outputs = model.inference(token_ids)

In [None]:
# Print the name and shape of every entry returned by inference.
for name in outputs:
    print(name, outputs[name].shape)

## 合成音

In [None]:
import matplotlib.pyplot as plt
%matplotlib inline
from IPython.display import display, Audio

In [None]:
# Take the synthesized audio from the inference result and drop singleton
# dimensions, then copy it to host memory as a NumPy array.
audio_tensor = outputs["model_outputs"].squeeze()
waveform = audio_tensor.cpu().numpy()
waveform.shape

In [None]:
# Plot the waveform samples; the trailing semicolon keeps the notebook from
# echoing the return value of plt.plot.
plt.plot(waveform);

In [None]:
# Embed an audio player for the synthesized waveform, played back at the
# sample rate stored in the configuration.
player = Audio(waveform, rate=config.audio.sample_rate)
display(player)

## アラインメントの可視化

In [None]:
from TTS.tts.utils.visual import plot_alignment

In [None]:
# Alignment entry of the inference result; its shape is shown as the cell output.
alignments = outputs["alignments"]
alignments.shape

In [None]:
# Take the first item of the batch (the batch holds a single sequence), detach
# it from autograd with .detach() rather than the legacy .data attribute, copy
# it to host memory as a NumPy array, and transpose it for plotting.
align_img = alignments[0].detach().cpu().numpy().T
align_img.shape

In [None]:
# Draw the alignment matrix with the TTS plotting helper (output_fig=False).
plot_alignment(align_img, output_fig=False)

## 時間長

In [None]:
# Shape of the input token tensor (a leading batch dimension was added with unsqueeze(0)).
token_ids.shape

In [None]:
# Shape of the durations entry returned by inference.
outputs["durations"].shape

In [None]:
# Duration values with singleton dimensions removed.
outputs["durations"].squeeze()

In [None]:
# Number of frames in the synthesized audio: waveform sample count divided by
# the hop length. Derived from the actual waveform and configuration so the
# value stays correct when the input text or the audio settings change
# (the literal 200960 / 256 only held for one particular run).
waveform.shape[-1] / config.audio.hop_length

In [None]:
# Durations are expressed in frames, so their total
# matches the length of the synthesized audio in frames.
frame_durations = outputs["durations"].squeeze()
frame_durations.sum()

In [None]:
# Map every token ID of the single input sequence back to its symbol string.
id_to_symbol = tokenizer.characters.id_to_char
tokens = [id_to_symbol(token_id) for token_id in token_ids[0].cpu().numpy()]
len(tokens)

In [None]:
# Convert the durations from frames to samples by scaling with the hop length.
frames_per_token = outputs["durations"].squeeze().cpu().numpy()
durations = frames_per_token * config.audio.hop_length
len(durations)

In [None]:
# Total duration in samples (durations were scaled by the hop length in the previous cell).
sum(durations)

In [None]:
# Running total of the durations: the sample position at which each token ends.
positions = list(durations.cumsum())

In [None]:
# Compare the number of boundary positions with the number of tokens; they are
# paired one-to-one when plotting below.
len(positions), len(tokens)

In [None]:
# Overlay the token boundaries on the waveform: one red vertical line at the
# end position of each token, labelled with the token's symbol.
plt.figure(figsize=(16, 8))
plt.plot(waveform)
for boundary, token in zip(positions, tokens):
    # Abbreviate the blank token so the labels stay short.
    label = "B" if token == "<BLNK>" else token
    plt.axvline(boundary, color="r")
    # Place the label 250 samples to the left of its boundary line.
    plt.text(boundary - 250, 0.0, label)