## MB-iSTFT-VITS2 inference

In [1]:
%matplotlib inline
import matplotlib.pyplot as plt
import IPython.display as ipd
import librosa

import os
import json
import math

import requests
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import commons
import utils
from data_utils_multispeker_multitone import TextAudioSpeakerToneLoader, TextAudioSpeakerToneCollate
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence

from scipy.io.wavfile import write
import re

In [2]:
# Device selection: use the first CUDA GPU when one is available, otherwise the CPU.
device = "cuda:0" if torch.cuda.is_available() else "cpu"

In [3]:
def get_text(text, hps):
    """Convert raw text into a ``torch.LongTensor`` of symbol ids.

    The text is run through ``text_to_sequence`` with the cleaners listed in
    ``hps.data.text_cleaners``. When ``hps.data.add_blank`` is truthy, the
    resulting sequence is passed through ``commons.intersperse`` with ``0``
    as the inserted item.
    """
    token_ids = text_to_sequence(text, hps.data.text_cleaners)
    if hps.data.add_blank:
        token_ids = commons.intersperse(token_ids, 0)
    return torch.LongTensor(token_ids)

def vcms(inputstr, sid, tid, speed=1.0, noise_scale=0.667, noise_scale_w=0.8):  # multi
    """Synthesize ``inputstr`` with ``net_g`` and display an inline audio player.

    Relies on the notebook-level globals ``hps``, ``net_g`` and ``device``;
    they must be defined before this function is called.

    Args:
        inputstr: Text to synthesize. The bracket characters ``[](){}`` are
            stripped before the text is converted to symbol ids.
        sid: Integer forwarded to ``net_g.infer`` as ``sid``
            (presumably the speaker id -- confirm against the model).
        tid: Integer forwarded to ``net_g.infer`` as ``tid``
            (presumably the tone id -- confirm against the model).
        speed: Positive rate factor; ``net_g.infer`` receives
            ``length_scale = 1 / speed``. Defaults to the previously
            hard-coded value of 1.
        noise_scale: Forwarded unchanged to ``net_g.infer``.
        noise_scale_w: Forwarded unchanged to ``net_g.infer``.

    Raises:
        ValueError: If ``speed`` is not strictly positive.

    Returns:
        None. The audio is shown through ``IPython.display``.
    """
    # Reject invalid rates early: 0 would raise ZeroDivisionError below and a
    # negative value would produce a negative length scale.
    if speed <= 0:
        raise ValueError(f"speed must be positive, got {speed!r}")

    fltstr = re.sub(r"[\[\]\(\)\{\}]", "", inputstr)
    # fltstr = langdetector(fltstr)  # optional for cjke/cjks type cleaners
    stn_tst = get_text(fltstr, hps)

    with torch.no_grad():
        # Add a batch dimension of size 1 and move all inputs to the model device.
        x_tst = stn_tst.to(device).unsqueeze(0)
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        # Separate names so the caller-supplied ids are not rebound to tensors.
        sid_tensor = torch.LongTensor([sid]).to(device)
        tid_tensor = torch.LongTensor([tid]).to(device)
        output = net_g.infer(
            x_tst,
            x_tst_lengths,
            sid=sid_tensor,
            tid=tid_tensor,
            noise_scale=noise_scale,
            noise_scale_w=noise_scale_w,
            length_scale=1 / speed,
        )
        # First returned item, first batch entry, first channel -> NumPy float array.
        audio = output[0][0, 0].data.cpu().float().numpy()

    # NOTE(review): with normalize=False IPython expects samples within [-1, 1]
    # and raises otherwise -- confirm the model output stays in that range.
    ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))

In [4]:
# Locations of the training config and the generator checkpoint to load.
# NOTE(review): these are absolute, machine-specific paths; adjust before running elsewhere.
path_to_config, path_to_model = (
    "/mnt/d/VITS100/mbank/config.json",
    "/mnt/d/VITS100/mbank/G_64000.pth",
)

In [5]:
# Read hyper-parameters from the JSON config.
hps = utils.get_hparams_from_file(path_to_config)

# Record on hps.data whether the model config enables the mel posterior
# encoder; the flag is True only when the key exists and equals True.
hps.data.use_mel_posterior_encoder = (
    "use_mel_posterior_encoder" in hps.model.keys()
    and hps.model.use_mel_posterior_encoder == True
)

# Pick the posterior encoder input width that matches the flag.
if hps.data.use_mel_posterior_encoder:
    print("Using mel posterior encoder for VITS2")
    posterior_channels = 80  # vits2
else:
    print("Using lin posterior encoder for VITS1")
    posterior_channels = hps.data.filter_length // 2 + 1

# Build the generator on the selected device and switch it to eval mode.
net_g = SynthesizerTrn(
    len(symbols),
    posterior_channels,
    hps.train.segment_size // hps.data.hop_length,
    **hps.model,
).to(device)
net_g.eval()

# Load checkpoint weights into net_g (no optimizer is passed); binding the
# result to `_` keeps the return value out of the cell output.
_ = utils.load_checkpoint(path_to_model, net_g, None)

Using mel posterior encoder for VITS2
Multi-band iSTFT VITS2


  WeightNorm.apply(module, name, dim)
  checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')


INFO:root:Loaded checkpoint '/mnt/d/VITS100/mbank/G_64000.pth' (iteration 75)


In [18]:
# Text to synthesize, followed by the synthesis call (sid=0, tid=1).
# NOTE(review): the name `input` shadows Python's builtin input(); it is kept
# unchanged here in case other cells refer to it.
input = "бишкек "
vcms(inputstr=input, sid=0, tid=1)