In [None]:
import torch
import commons
import utils
import time
import onnxruntime
import numpy as np

from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence
from scipy.io.wavfile import write

# Hyper-parameter file, generator checkpoint, and the sentence that every
# backend synthesizes during the benchmark.
CONFIG_PATH = "./configs/vits2_ljs_nosdp.json"
MODEL_PATH = "./logs/G_64000.pth"
TEXT = "VITS-2 is Awesome!"

hps = utils.get_hparams_from_file(CONFIG_PATH)

# The config flag decides how many channels the posterior encoder receives:
# a fixed 80 when the mel flag is set, otherwise filter_length // 2 + 1.
uses_mel_posterior = (
    "use_mel_posterior_encoder" in hps.model.keys()
    and hps.model.use_mel_posterior_encoder == True
)
if not uses_mel_posterior:
    print("Using lin posterior encoder for VITS1")
    posterior_channels = hps.data.filter_length // 2 + 1
    hps.data.use_mel_posterior_encoder = False
else:
    print("Using mel posterior encoder for VITS2")
    posterior_channels = 80  # vits2
    hps.data.use_mel_posterior_encoder = True

# Both generators are built from the same positional arguments; the only
# difference is that the first copy is moved to CUDA.
synth_args = (
    len(symbols),
    posterior_channels,
    hps.train.segment_size // hps.data.hop_length,
)

net_g_cuda = SynthesizerTrn(*synth_args, **hps.model).cuda()
_ = net_g_cuda.eval()

net_g_cpu = SynthesizerTrn(*synth_args, **hps.model)
_ = net_g_cpu.eval()

# Load the same checkpoint into each copy (third argument left as None).
_ = utils.load_checkpoint(MODEL_PATH, net_g_cuda, None)
_ = utils.load_checkpoint(MODEL_PATH, net_g_cpu, None)

# ONNX Runtime session over the exported model, using default session options.
sess_options = onnxruntime.SessionOptions()
onnx_model = onnxruntime.InferenceSession(
    "onnx/vits2.onnx",
    sess_options=sess_options,
)


def get_text(text, hps):
    """Turn a raw string into a 1-D ``torch.LongTensor`` of symbol ids.

    Args:
        text: The sentence to convert.
        hps: Hyper-parameter object; ``hps.data.text_cleaners`` is forwarded
            to ``text_to_sequence`` and ``hps.data.add_blank`` controls
            whether a 0 id is interspersed between symbols.

    Returns:
        ``torch.LongTensor`` holding the (optionally interspersed) id sequence.
    """
    sequence = text_to_sequence(text, hps.data.text_cleaners)
    ids = commons.intersperse(sequence, 0) if hps.data.add_blank else sequence
    return torch.LongTensor(ids)


def inference(model, text):
    """Synthesize ``text`` with a PyTorch model and return the audio as NumPy.

    The input tensors are placed on the same device as the model's parameters,
    so the function works for both the CPU and the CUDA copy of the network.

    Args:
        model: A module exposing ``parameters()`` and
            ``infer(x, x_lengths, noise_scale=..., noise_scale_w=..., length_scale=...)``.
        text: The sentence to synthesize.

    Returns:
        ``numpy.ndarray`` of float32 samples taken from ``infer(...)[0][0, 0]``.
    """
    # Use the caller's text; the module-level TEXT constant is only a default
    # that callers may choose to pass in.
    stn_tst = get_text(text, hps)

    # Follow the model's device instead of hard-coding CUDA: feeding CUDA
    # tensors to a CPU-resident model raises a device-mismatch error.
    device = next(model.parameters()).device

    with torch.no_grad():
        x_tst = stn_tst.to(device).unsqueeze(0)  # add a leading batch dimension
        x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
        audio = (
            model.infer(
                x_tst, x_tst_lengths, noise_scale=0.667, noise_scale_w=0.8, length_scale=1
            )[0][0, 0]
            .data.cpu()
            .float()
            .numpy()
        )
    return audio

def onnx_inference(model, text):
    """Synthesize ``text`` through an ONNX Runtime session.

    Args:
        model: Object with a ``run(output_names, input_feed)`` method, e.g. an
            ``onnxruntime.InferenceSession``.
        text: The sentence to synthesize.

    Returns:
        The first output of ``model.run`` with axes 0 and 1 squeezed away.
    """
    token_ids = np.array(get_text(text, hps), dtype=np.int64)
    batched_ids = np.expand_dims(token_ids, 0)  # add a leading batch dimension

    feed = {
        "input": batched_ids,
        "input_lengths": np.array([batched_ids.shape[1]], dtype=np.int64),
        # NOTE(review): presumably ordered (noise_scale, length_scale,
        # noise_scale_w) to mirror the PyTorch call — confirm against the export.
        "scales": np.array([0.667, 1.0, 0.8], dtype=np.float32),
        # No speaker id is supplied.
        "sid": None,
    }

    outputs = model.run(None, feed)
    return outputs[0].squeeze((0, 1))

def benchmark(model, text, model_name, num_iterations=1000, num_warmup=10):
    """Time repeated synthesis of ``text`` and print the mean latency.

    Args:
        model: A PyTorch generator or an ONNX Runtime session.
        text: The sentence to synthesize on every iteration.
        model_name: Label printed with the result. If it contains ``"ONNX"``
            the model is driven through ``onnx_inference``; otherwise through
            ``inference``.
        num_iterations: Number of timed runs (default 1000).
        num_warmup: Number of untimed warm-up runs executed first (default 10).

    Raises:
        ValueError: If ``num_iterations`` is not positive.
    """
    if num_iterations <= 0:
        raise ValueError("num_iterations must be a positive integer")

    # Resolve the backend once rather than re-checking the label on every run.
    run_once = onnx_inference if "ONNX" in model_name else inference

    with torch.no_grad():
        # Warm-up runs are excluded from the measurement.
        for _ in range(num_warmup):
            run_once(model, text)

        # perf_counter is monotonic and high-resolution, unlike time.time(),
        # which can jump when the system clock is adjusted.
        total_time = 0.0
        for _ in range(num_iterations):
            start_time = time.perf_counter()
            run_once(model, text)
            total_time += time.perf_counter() - start_time

    avg_inference_time = total_time / num_iterations
    print(f"Model: {model_name}")
    print(f"Average Inference Time per Iteration: {avg_inference_time:.4f} seconds")


# Run the same benchmark against each backend, in a fixed order.
for model, label in (
    (net_g_cpu, "Pytorch CPU"),
    (net_g_cuda, "Pytorch GPU"),
    (onnx_model, "ONNX CPU"),
):
    benchmark(model, TEXT, label)

In [6]:
# Ask the installed onnxruntime build which device it reports; the recorded
# output below is 'CPU'.
import onnxruntime as ort
ort.get_device()


'CPU'