# 訓練済みVITSモデルによる推論

In [None]:
import numpy as np
import os
import torch
from TTS.tts.models.vits import Vits, VitsAudioConfig
from TTS.tts.configs.vits_config import VitsArgs, VitsConfig
from TTS.utils.audio import AudioProcessor
from TTS.tts.utils.text.tokenizer import TTSTokenizer
from TTS.tts.utils.speakers import SpeakerManager

## モデルのロード

In [None]:
# Audio front-end settings for feature extraction / synthesis. These should
# match whatever the checkpoint was trained with (presumably the VCTK VITS
# recipe settings — confirm against the training config).
audio_config = VitsAudioConfig(
    sample_rate=22050,
    win_length=1024,
    hop_length=256,
    num_mels=80,
    mel_fmin=0,
    mel_fmax=None,
)

In [None]:
# Model-architecture arguments: enable the speaker-embedding table so the
# model can be conditioned on an integer speaker id (passed later through
# aux_input["speaker_ids"]).
vitsArgs = VitsArgs(use_speaker_embedding=True)

# Full VITS config. Many of these fields (batching, loaders, epochs, logging,
# mixed precision) appear to be training-only settings; they presumably mirror
# the training run so the config lines up with the checkpoint.
config = VitsConfig(
    model_args=vitsArgs,
    audio=audio_config,
    run_name="vits_vctk",
    # batching / data loading
    batch_size=16,
    eval_batch_size=16,
    batch_group_size=5,
    num_loader_workers=4,
    num_eval_loader_workers=4,
    # training schedule, logging and performance knobs
    run_eval=True,
    test_delay_epochs=-1,
    epochs=1000,
    print_step=25,
    print_eval=False,
    mixed_precision=True,
    cudnn_benchmark=False,
    # text front-end
    text_cleaner="english_cleaners",
    use_phonemes=True,
    phoneme_language="en",
    compute_input_seq_cache=True,
    # change this if you have more than 16 GB of VRAM
    max_text_len=325,
)

In [None]:
# Audio processor built from the config; used later for ap.load_wav and
# ap.sample_rate in the voice-conversion cells.
ap = AudioProcessor.init_from_config(config)

In [None]:
# Display the audio section of the config for inspection.
config.audio

In [None]:
# init_from_config returns the tokenizer together with a config object;
# config is rebound so later cells use the returned (possibly updated) one.
tokenizer, config = TTSTokenizer.init_from_config(config)

In [None]:
# Load the speaker-name -> integer-id table saved with the training run.
id_file_path = (
    "../recipes/vctk/vits/vits_vctk-November-18-2022_12+10PM-05b4ee16"
    "/speakers.pth"
)
speaker_manager = SpeakerManager(speaker_id_file_path=id_file_path)

# Inspect what was loaded: speaker count, names, and the name -> id mapping.
print(speaker_manager.num_speakers)
print(speaker_manager.speaker_names)
print(speaker_manager.get_speakers())

In [None]:
# Pick the inference device. Prefer the second GPU (cuda:1) when it exists,
# otherwise fall back to the default GPU, and finally to the CPU, so the
# notebook also runs on single-GPU or CPU-only machines.
if torch.cuda.device_count() > 1:
    device = torch.device("cuda:1")
elif torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")
model = Vits(config, ap, tokenizer, speaker_manager).to(device)

In [None]:
# Restore the trained weights into the model built above; eval=True
# presumably switches the model to evaluation mode for inference.
checkpoint_path = (
    "../recipes/vctk/vits/vits_vctk-November-18-2022_12+10PM-05b4ee16"
    "/checkpoint_230000.pth"
)
model.load_checkpoint(config, checkpoint_path, eval=True)

## テキストからの推論

In [None]:
raw_text = "This cake is great. It's so delicious and moist."
# Shorter alternative input for quick checks: "Many animals"

# Text -> token ids, then a (1, T) int64 batch created directly on the model
# device. torch.tensor with an explicit dtype avoids the legacy torch.Tensor
# constructor, which builds a float32 CPU tensor first and then casts/copies.
token_ids = torch.tensor(
    tokenizer.text_to_ids(raw_text), dtype=torch.long, device=device
).unsqueeze(0)
token_ids.shape

In [None]:
# Look up the integer ids of the two speakers used in this notebook.
speaker2id = speaker_manager.get_speakers()
for speaker_name in ("VCTK_p260", "VCTK_p310"):
    print(speaker2id[speaker_name])

In [None]:
# Target speaker for synthesis as a (1,) int64 tensor on the model device.
# torch.tensor with an explicit dtype avoids the float32 round-trip of the
# legacy torch.Tensor(...).long() form.
speaker_ids = torch.tensor([speaker2id["VCTK_p310"]], dtype=torch.long, device=device)
speaker_ids

In [None]:
# Synthesize speech for the chosen speaker; the speaker id is passed to the
# model through the aux_input dictionary.
aux_input = {"speaker_ids": speaker_ids}
outputs = model.inference(token_ids, aux_input=aux_input)

In [None]:
# Print the shape of every entry returned by model.inference.
for key in outputs:
    print(key, outputs[key].shape)

## 合成音

In [None]:
import matplotlib.pyplot as plt
# IPython magic (not plain Python): render matplotlib figures inline in the notebook.
%matplotlib inline
# display/Audio are used to embed playable audio players in cell output.
from IPython.display import display, Audio

In [None]:
# Generated audio: squeeze away singleton axes (presumably batch/channel —
# confirm) and move it to the CPU as a NumPy array.
model_audio = outputs["model_outputs"]
waveform = model_audio.squeeze().cpu().numpy()
waveform.shape

In [None]:
# The trailing ';' suppresses the returned Line2D list repr so only the plot is shown.
plt.plot(waveform);

In [None]:
# Embed a player for the synthesized speech at the configured sample rate.
player = Audio(data=waveform, rate=config.audio.sample_rate)
display(player)

## アラインメントの可視化

In [None]:
from TTS.tts.utils.visual import plot_alignment

In [None]:
# Alignment tensor returned by inference; display its shape.
alignments = outputs["alignments"]
alignments.shape

In [None]:
# Take the first (only) batch item, detach it from autograd, move it to the
# CPU and transpose it into the orientation expected by plot_alignment
# (axis meaning not shown here — confirm against plot_alignment's docs).
first_alignment = alignments[0].detach()
align_img = first_alignment.cpu().numpy().T
align_img.shape

In [None]:
# Visualize the (transposed) alignment matrix.
plot_alignment(align_img, output_fig=False)

## 時間長

In [None]:
# Shape of the input token batch, for comparison with the durations tensor.
token_ids.shape

In [None]:
# Shape of the predicted durations tensor.
outputs["durations"].shape

In [None]:
# Per-token durations with singleton axes squeezed away.
outputs["durations"].squeeze()

In [None]:
# Number of output frames = waveform samples / hop length. Derived from the
# actual waveform and config so it tracks whatever text/speaker was
# synthesized, instead of a sample count that is only valid for one run.
len(waveform) / config.audio.hop_length

In [None]:
# durations are measured in frames;
# summed, they should equal the length of the generated audio in frames.
outputs["durations"].squeeze().sum()

In [None]:
# Map each token id of the (single) input sequence back to its symbol, so it
# can be labelled on the waveform plot.
id_to_char = tokenizer.characters.id_to_char
tokens = list(map(id_to_char, token_ids[0].cpu().numpy()))
len(tokens)

In [None]:
# Convert per-token durations from frames to audio samples.
frames_per_token = outputs["durations"].squeeze().cpu().numpy()
durations = frames_per_token * config.audio.hop_length
len(durations)

In [None]:
# Total duration in samples; should equal len(waveform) if the per-token
# durations cover the whole output.
sum(durations)

In [None]:
# Cumulative sample offset at which each token ends.
token_end_samples = np.cumsum(durations)
positions = list(token_end_samples)

In [None]:
# Both lengths should match so each boundary can be paired with its token.
len(positions), len(tokens)

In [None]:
# Overlay token boundaries on the waveform: a red vertical line at the sample
# where each token ends (positions), with the token symbol drawn just to its
# left. The "<BLNK>" token (presumably the blank inserted between symbols)
# is abbreviated to "B" to keep the plot readable.
plt.figure(figsize=(24, 8))
plt.plot(waveform)
for x, token in zip(positions, tokens):
    label = "B" if token == "<BLNK>" else token
    plt.axvline(x, color="r")
    plt.text(x - 250, 0.0, label)

## 訓練データ内の話者間での音声変換

- 参考（VITS著者によるデモページの音声変換サンプル）: https://jaywalnut310.github.io/vits-demo/index.html#vc

In [None]:
# One utterance each from VCTK speakers p260 and p310.
# NOTE(review): the directory is named wav48_*, while ap.sample_rate comes
# from audio_config (22050); confirm the files were resampled beforehand or
# that load_wav resamples them.
p260_wav = ap.load_wav("../recipes/vctk/VCTK/wav48_silence_trimmed/p260/p260_040_mic1.flac")
p310_wav = ap.load_wav("../recipes/vctk/VCTK/wav48_silence_trimmed/p310/p310_020_mic1.flac")

# Embed a player for each source clip, p260 first.
for clip in (p260_wav, p310_wav):
    display(Audio(data=clip, rate=ap.sample_rate))

In [None]:
# Integer ids of the source (p260) and target (p310) speakers.
speaker2id = speaker_manager.get_speakers()
for speaker_name in ("VCTK_p260", "VCTK_p310"):
    print(speaker2id[speaker_name])

In [None]:
# Voice conversion p260 => p310.
# In coqui-tts the "reference" arguments describe the SOURCE utterance (the
# one being converted); speaker_id is the TARGET speaker.
# unsqueeze(0) adds a leading batch axis to the waveform.
reference_wav = torch.from_numpy(p260_wav).float().unsqueeze(0).to(device)
# torch.tensor with an explicit dtype creates the int64 ids directly on the
# device, instead of the legacy float32 torch.Tensor(...).long() round-trip.
speaker_id = torch.tensor([speaker2id["VCTK_p310"]], dtype=torch.long, device=device)
reference_speaker_id = torch.tensor([speaker2id["VCTK_p260"]], dtype=torch.long, device=device)

converted_wav = model.inference_voice_conversion(
    reference_wav,
    speaker_id=speaker_id,
    reference_speaker_id=reference_speaker_id,
)
converted_wav = converted_wav.squeeze().cpu().numpy()
converted_wav.shape

In [None]:
# Plot the converted waveform and embed a player for it. No ';' is needed on
# the plot call: only a cell's final bare expression is echoed, and the last
# statement here is display(), which returns None.
plt.plot(converted_wav)
display(Audio(data=converted_wav, rate=ap.sample_rate))