In [134]:
import IPython.display as ipd
import torch
import commons
import utils
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence
from scipy.io.wavfile import write, read
import numpy as np
import os

class vits():
    """Inference wrapper around a VITS ``SynthesizerTrn`` generator checkpoint.

    Loads hyper-parameters from a config file, builds the generator, restores
    its weights and exposes :meth:`infer`, which synthesizes speech, writes the
    waveform to WAV files and plays it inline in the notebook.
    """

    def __init__(self, checkpoint_path, config_path, device=None):
        """Build the generator and load its weights.

        Args:
            checkpoint_path: path of the generator checkpoint (e.g. ``G_51000.pth``).
            config_path: path of the config file read by ``utils.get_hparams_from_file``.
            device: torch device (string or ``torch.device``) to run on. ``None``
                selects CUDA when ``torch.cuda.is_available()``, otherwise CPU.
        """
        if device is None:
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = torch.device(device)
        self.hps = utils.get_hparams_from_file(config_path)
        self.spk_count = self.hps.data.n_speakers
        self.net_g = SynthesizerTrn(
            len(symbols),
            self.hps.data.filter_length // 2 + 1,
            self.hps.train.segment_size // self.hps.data.hop_length,
            n_speakers=self.hps.data.n_speakers,
            **self.hps.model).to(self.device)
        self.net_g.eval()
        utils.load_checkpoint(checkpoint_path, self.net_g, None)

    def get_text(self, text, hps, pause_duration=5):
        """Convert ``text`` into a 1-D LongTensor of symbol ids.

        Text wrapped in ``{...}`` is encoded one character at a time, with
        ``pause_duration`` pause tokens (id 0) appended after every character.
        Everything outside the braces is encoded normally. A ``{`` without a
        matching ``}`` is dropped and the text after it is encoded normally.
        Unmatched ``}`` characters are dropped.

        Args:
            text: input text, optionally containing ``{...}`` spans.
            hps: hyper-parameters; ``hps.data.text_cleaners`` and
                ``hps.data.add_blank`` are read.
            pause_duration: number of pause tokens appended after each braced
                character.

        Returns:
            ``torch.LongTensor`` of symbol ids.
        """
        cleaners = hps.data.text_cleaners
        sequence = []
        for idx, part in enumerate(text.split('{')):
            # Segment 0 precedes any '{', so it can never be a pause span.
            if idx > 0 and '}' in part:
                # Only the first '}' closes the span.
                spaced, rest = part.split('}', 1)
                # Add a long pause after each character inside the braces.
                for char in spaced:
                    sequence.extend(text_to_sequence(char, cleaners))
                    # NOTE(review): id 0 doubles as the pause token and as the
                    # add_blank filler below — confirm it maps to a silent symbol.
                    sequence.extend([0] * pause_duration)
            else:
                rest = part
            # Drop unmatched '}' so it never reaches the text cleaners.
            sequence.extend(text_to_sequence(rest.replace('}', ''), cleaners))

        if hps.data.add_blank:
            sequence = commons.intersperse(sequence, 0)

        return torch.LongTensor(sequence)

    def infer(self, text, spk_id=0, pause_duration=5, noise_scale=.667,
              noise_scale_w=0.8, length_scale=1, max_duration=5, out_dir='infer'):
        """Synthesize ``text``, save it as WAV files and play it inline.

        The waveform is split into chunks of at most ``max_duration`` seconds.
        Each chunk is written to ``<out_dir>/test1.wav``, ``test2.wav``, ...
        and displayed. The full waveform is then written to
        ``<out_dir>/combined.wav`` and displayed last. The cell's previous
        output is cleared first.

        Args:
            text: text to synthesize (see :meth:`get_text` for ``{...}`` spans).
            spk_id: speaker index passed to the generator as ``sid``.
            pause_duration: pause tokens per braced character, forwarded to
                :meth:`get_text`.
            noise_scale, noise_scale_w, length_scale: forwarded unchanged to
                ``self.net_g.infer``.
            max_duration: chunk length in seconds for the per-chunk files.
            out_dir: directory for the WAV files; created if missing.
        """
        ipd.clear_output()
        stn_tst = self.get_text(text, self.hps, pause_duration)
        with torch.no_grad():
            x_tst = stn_tst.to(self.device).unsqueeze(0)
            x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(self.device)
            sid = torch.LongTensor([spk_id]).to(self.device)
            audio = self.net_g.infer(
                x_tst, x_tst_lengths, sid=sid, noise_scale=noise_scale,
                noise_scale_w=noise_scale_w, length_scale=length_scale,
            )[0][0, 0].data.cpu().float().numpy()

        sampling_rate = self.hps.data.sampling_rate
        # Keep the chunk size positive; range() rejects a zero step.
        max_samples = max(1, int(max_duration * sampling_rate))
        # scipy's wavfile.write does not create directories.
        os.makedirs(out_dir, exist_ok=True)

        for idx, start in enumerate(range(0, len(audio), max_samples), start=1):
            segment = audio[start:start + max_samples]
            write(os.path.join(out_dir, f'test{idx}.wav'), sampling_rate, segment)
            ipd.display(ipd.Audio(segment, rate=sampling_rate, normalize=False))

        # The chunks are contiguous slices of `audio`, so the combined track is
        # the full waveform itself; no need to re-read the chunk files.
        write(os.path.join(out_dir, 'combined.wav'), sampling_rate, audio)
        ipd.display(ipd.Audio(audio, rate=sampling_rate, normalize=False))

In [135]:

# Text to synthesize; the `{...}` span receives special pause handling in vits.get_text.
text: str = '학습은 잘 마치셨나요? 좋은 결과가 있길 바래요. {1 1 라 2 3 4 5}번님'
# Pause length setting: number of pause tokens, meant for the `pause_duration` argument of tts.infer.
pause_duration: int = 50

In [136]:
# Load the trained generator and synthesize `text` with speaker 0, forwarding the
# pause length configured in the previous cell.
tts = vits('checkpoints/lasttry/G_51000.pth', 'checkpoints/lasttry/config.json')
tts.infer(text, 0, pause_duration=pause_duration)