In [None]:
import matplotlib.pyplot as plt
import IPython.display as ipd

import os
import json
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import commons
import utils
from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence

from scipy.io.wavfile import write


def get_text(text, hps):
    """Turn a text string into the model's input id tensor.

    Args:
        text: Raw input text.
        hps: Hyperparameter object. ``hps.data.text_cleaners`` is passed to
            ``text_to_sequence``; ``hps.data.add_blank`` decides whether the
            blank id 0 is interleaved with the symbol ids.

    Returns:
        A 1-D ``torch.LongTensor`` of symbol ids.
    """
    symbol_ids = text_to_sequence(text, hps.data.text_cleaners)
    # With add_blank, id 0 surrounds every symbol (compare the x_tst output
    # further down, which starts/ends with 0 and alternates 0/id in between).
    symbol_ids = commons.intersperse(symbol_ids, 0) if hps.data.add_blank else symbol_ids
    return torch.LongTensor(symbol_ids)

In [None]:
# Data / model / training hyperparameters for the LJSpeech base configuration.
HPARAMS_PATH = "./configs/ljs_base.json"
hps = utils.get_hparams_from_file(HPARAMS_PATH)

In [None]:
# Build the generator from the config and load the pretrained LJSpeech weights.
# The model is kept on the CPU (no .cuda()), matching the CPU dummy tensors
# used for the ONNX export further down.
n_vocab = len(symbols)
# presumably the number of STFT frequency bins (n_fft // 2 + 1)
spec_channels = hps.data.filter_length // 2 + 1
segment_frames = hps.train.segment_size // hps.data.hop_length

net_g = SynthesizerTrn(n_vocab, spec_channels, segment_frames, **hps.model)
_ = net_g.eval()

_ = utils.load_checkpoint("./pretrained_ljs.pth", net_g, None)

In [None]:
# Total number of parameter elements in the generator (buffers excluded).
sum(map(torch.numel, net_g.parameters()))

36321072

In [None]:
# Synthesize a sample sentence with the pretrained model and play it back.
# Inputs are sent to whatever device the model's parameters live on instead of
# a hard-coded .cuda(), so the cell runs on a fresh kernel whether or not the
# model was moved to a GPU.
stn_tst = get_text("VITS is Awesome!", hps)
device = next(net_g.parameters()).device
with torch.no_grad():
    x_tst = stn_tst.to(device).unsqueeze(0)  # add a batch axis of size 1
    x_tst_lengths = torch.tensor([stn_tst.size(0)], dtype=torch.long, device=device)
    output = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)
    # output[0] is the waveform; [0, 0] selects the first batch item and channel.
    audio = output[0][0, 0].cpu().float().numpy()

ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))

In [None]:
# Full tuple returned by net_g.infer for the sample sentence: element 0 is the
# waveform played above, element 1 is what the export cell names "attention".
output

(tensor([[[-5.2680e-05, -3.5746e-04, -6.3595e-04,  ...,  5.2548e-04,
            5.0303e-04,  5.3441e-04]]], device='cuda:0'),
 tensor([[[[1., 0., 0.,  ..., 0., 0., 0.],
           [0., 1., 0.,  ..., 0., 0., 0.],
           [0., 1., 0.,  ..., 0., 0., 0.],
           ...,
           [0., 0., 0.,  ..., 0., 0., 1.],
           [0., 0., 0.,  ..., 0., 0., 1.],
           [0., 0., 0.,  ..., 0., 0., 1.]]]], device='cuda:0'),
 tensor([[[1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.,
           1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1., 1.]]],
        device='cuda:0'),
 (tensor([[[ 0.3055,  0.4209,  0.7191,  ..., -0.2389, -0.8463, -0.3579],
      

In [None]:
# Waveform tensor shape; the last axis is the sample count
# (axis order presumably (batch, channels, samples)).
output[0].size()

torch.Size([1, 1, 25856])

In [None]:
# Attention tensor shape. The last axis matches x_tst_lengths (input ids); the
# third presumably indexes decoder frames (frames * 256 == sample count here,
# i.e. a hop of 256 — confirm against hps.data.hop_length).
output[1].size()

torch.Size([1, 1, 101, 33])

In [None]:
# Encoded input ids for the sample sentence; the interleaved 0s are the blanks
# inserted by get_text when hps.data.add_blank is set.
x_tst

tensor([[  0,  64,   0, 156,   0, 102,   0,  62,   0,  61,   0,  16,   0, 102,
           0,  68,   0,  16,   0, 156,   0,  76,   0, 158,   0,  61,   0, 138,
           0,  55,   0,   5,   0]], device='cuda:0')

In [None]:
# Number of ids in the single input sequence (includes the inserted blanks).
x_tst_lengths

tensor([33], device='cuda:0')

In [None]:
def onnx_inference(text, text_lengths, noise_scale, length_scale, noise_scale_w, sid=None):
    """Inference entry point installed as ``net_g.forward`` for ONNX export.

    ``torch.onnx.export`` traces a module by calling it, so routing the call
    through this function exports ``net_g.infer`` with the three scale
    factors as graph inputs.

    Args:
        text: Batch of symbol-id sequences (shaped ``(batch, phonemes)`` in
            the export cell).
        text_lengths: Length of each sequence in ``text``.
        noise_scale, length_scale, noise_scale_w: Scale factors forwarded
            unchanged to ``net_g.infer``.
        sid: Optional speaker id, forwarded as ``infer``'s third positional
            argument.

    Returns:
        The first two elements of ``net_g.infer``'s output, exported as
        ``audio`` and ``attention``.
    """
    # Scale factors are bound by keyword (the same names the synthesis cell
    # uses) so they cannot be silently swapped if infer's positional order
    # differs from the order of this function's parameters.
    return net_g.infer(
        text,
        text_lengths,
        sid,
        noise_scale=noise_scale,
        length_scale=length_scale,
        noise_scale_w=noise_scale_w,
    )[:2]

In [None]:
# Make calling net_g run onnx_inference (an instance attribute shadows the
# class's forward), so torch.onnx.export — which calls the module — traces
# net_g.infer instead of SynthesizerTrn's own forward.
# NOTE: the patch persists on this instance for the rest of the kernel session.
net_g.forward = onnx_inference

In [None]:
# Dummy inputs used only to trace the model for ONNX export.
# A fixed-seed local generator keeps the traced example reproducible without
# touching the global RNG, and every tensor is placed on the model's device so
# the export works whether net_g was left on the CPU or moved to a GPU.
model_device = next(net_g.parameters()).device
dummy_input_length = 100
dummy_rng = torch.Generator().manual_seed(0)
sequences = torch.randint(
    low=0,
    high=len(symbols),
    size=(1, dummy_input_length),
    dtype=torch.long,
    generator=dummy_rng,
).to(model_device)
sequence_lengths = torch.LongTensor([sequences.size(1)]).to(model_device)
speaker_id = None  # no speaker input; the LJSpeech config is presumably single-speaker
noise_scale = torch.FloatTensor([1.0]).to(model_device)
length_scale = torch.FloatTensor([1.0]).to(model_device)
noise_scale_w = torch.FloatTensor([1.0]).to(model_device)

# Order must match onnx_inference's parameter order.
dummy_input = (sequences, sequence_lengths, noise_scale, length_scale, noise_scale_w, speaker_id)

In [None]:
# Export the patched model (net_g.forward -> onnx_inference) to ONNX.
# Only tensor arguments become graph inputs: a None speaker_id is not traced,
# so its name must not be listed, otherwise the surplus name does not
# correspond to any real input of the exported graph.
onnx_input_names = ["input", "input_lengths", "noise_scale", "length_scale", "noise_scale_w"]
if speaker_id is not None:
    onnx_input_names.append("speaker_id")

torch.onnx.export(
    model=net_g,
    args=dummy_input,
    f="vits.onnx",
    opset_version=15,
    input_names=onnx_input_names,
    output_names=["audio", "attention"],
    # Axis 1 of both outputs is a singleton (see the shapes printed above), so
    # only the genuinely variable axes are declared dynamic. The attention's
    # last axis shares the "phonemes" name with the input length axis.
    dynamic_axes={
        "input": {0: "batch_size", 1: "phonemes"},
        "input_lengths": {0: "batch_size"},
        "audio": {0: "batch_size", 2: "samples"},
        "attention": {0: "batch_size", 2: "frames", 3: "phonemes"},
    },
)

  assert t_s == t_t, "Relative attention is only available for self-attention."
  pad_length = max(length - (self.window_size + 1), 0)
  slice_start_position = max((self.window_size + 1) - length, 0)
  if pad_length > 0:
  if torch.min(inputs) < left or torch.max(inputs) > right:
  if min_bin_width * num_bins > 1.0:
  if min_bin_height * num_bins > 1.0:
  assert (discriminant >= 0).all()
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(


verbose: False, log level: Level.ERROR



  _C._jit_pass_onnx_graph_shape_type_inference(


In [None]:
# Confirm the exported file exists and check its size.
!ls -lh vits.onnx

-rw-r--r-- 1 root root 109M Jul 25 04:41 vits.onnx
