In [1]:
import matplotlib.pyplot as plt
import IPython.display as ipd

import os
import json
import math
import torch
from torch import nn
from torch.nn import functional as F
from torch.utils.data import DataLoader

import commons
import utils
from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate
from models import SynthesizerTrn
from text.symbols import symbols
from text import text_to_sequence

from scipy.io.wavfile import write


def get_text(text, hps):
    """Turn a raw text string into a 1-D LongTensor of symbol ids.

    Args:
        text: The input string to synthesize.
        hps: Hyper-parameter object; ``hps.data.text_cleaners`` is forwarded to
            ``text_to_sequence`` and ``hps.data.add_blank`` toggles blank
            insertion.

    Returns:
        A ``torch.LongTensor`` holding the id sequence.
    """
    symbol_ids = text_to_sequence(text, hps.data.text_cleaners)
    # When enabled, id 0 is woven into the sequence by commons.intersperse.
    symbol_ids = commons.intersperse(symbol_ids, 0) if hps.data.add_blank else symbol_ids
    return torch.LongTensor(symbol_ids)



In [2]:
# Hyper-parameters for this voice are read from its JSON config file.
config_path = "./configs/en_au_dean2zak_base.json"
hps = utils.get_hparams_from_file(config_path)

In [3]:
# Sizes handed to the synthesizer constructor, all derived from the config.
n_symbols = len(symbols)
# presumably the number of spectrogram frequency bins — TODO confirm against models.py
spec_channels = hps.data.filter_length // 2 + 1
# presumably the training segment length expressed in frames — TODO confirm
segment_frames = hps.train.segment_size // hps.data.hop_length

# No .cuda() call is applied here, so the model stays on its default device.
net_g = SynthesizerTrn(n_symbols, spec_channels, segment_frames, **hps.model)
net_g.eval()  # evaluation mode before loading weights and running inference

# Third argument is passed as None; the return value is discarded into `_`.
_ = utils.load_checkpoint("./logs/vits-base-en-AU-Dean2Zak/G_800000.pth", net_g, None)

In [4]:
# Total number of scalar elements summed over every parameter tensor of net_g.
sum(p.numel() for p in net_g.parameters())

36434992

In [5]:
# Synthesize one test sentence and play it back inline.
stn_tst = get_text("VITS is Awesome!", hps)

# Follow the device the model weights live on, so this cell runs unchanged
# whether or not net_g has been moved to a GPU (replaces the hand-toggled
# `.cuda()` comments).
infer_device = next(net_g.parameters()).device

# unsqueeze(0) adds a leading axis of size 1; the lengths tensor carries the
# number of ids in the single sequence.
x_tst = stn_tst.unsqueeze(0).to(infer_device)
x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(infer_device)

with torch.no_grad():
    output = net_g.infer(x_tst, x_tst_lengths, noise_scale=.667, noise_scale_w=0.8, length_scale=1)

# Take the [0, 0] slice of the first returned tensor and convert it to a
# float32 NumPy array on the CPU. detach() is the supported replacement for
# the legacy `.data` accessor.
audio = output[0][0, 0].detach().cpu().float().numpy()

# normalize=False plays the samples as-is instead of letting IPython rescale
# them; NOTE(review): IPython then requires samples within [-1, 1] — confirm
# the model output respects that range.
ipd.display(ipd.Audio(audio, rate=hps.data.sampling_rate, normalize=False))

[W NNPACK.cpp:64] Could not initialize NNPACK! Reason: Unsupported hardware.


In [6]:
# Shape of the first tensor returned by net_g.infer (the one the audio array was sliced from).
output[0].shape

torch.Size([1, 1, 65024])

In [7]:
# Shape of the second tensor returned by net_g.infer (the export cell labels this output "attention").
output[1].shape

torch.Size([1, 1, 127, 41])

In [8]:
# Inspect the symbol-id tensor (with its added leading axis) that was fed to the model.
x_tst

tensor([[ 0, 28,  0, 50,  0,  7,  0, 47,  0,  7,  0, 25,  0, 50,  0,  7,  0, 58,
          0, 24,  0,  7,  0, 59,  0, 30,  0,  7,  0, 55,  0, 24,  0, 37,  0, 20,
          0,  7,  0,  5,  0]])

In [9]:
# Inspect the length tensor that was passed alongside x_tst.
x_tst_lengths

tensor([41])

In [10]:
def onnx_inference(text, text_lengths, noise_scale, length_scale, noise_scale_w, sid=None):
    """Inference entry point with the argument order used for ONNX export.

    Forwards to the module-level ``net_g.infer``, passing ``sid`` in third
    position followed by the three scale arguments, and keeps only the first
    two items of what ``infer`` returns.

    Args:
        text: Symbol-id tensor.
        text_lengths: Length tensor matching ``text``.
        noise_scale: Forwarded to ``infer`` unchanged.
        length_scale: Forwarded to ``infer`` unchanged.
        noise_scale_w: Forwarded to ``infer`` unchanged.
        sid: Optional speaker id; ``None`` by default.

    Returns:
        The first two elements of ``net_g.infer``'s result.
    """
    infer_outputs = net_g.infer(text, text_lengths, sid, noise_scale, length_scale, noise_scale_w)
    # Anything past the second element is dropped.
    return infer_outputs[:2]

In [11]:
# Point this model instance's forward at the inference wrapper, so that calling
# the module (as the exporter does) runs onnx_inference instead of the
# original forward.
net_g.forward = onnx_inference

In [12]:
# set dummy inputs
dummy_input_length = 100
sequences = torch.randint(low=0, high=len(symbols), size=(1, dummy_input_length), dtype=torch.long)
sequence_lengths = torch.LongTensor([sequences.size(1)])
speaker_id = None
noise_scale = torch.FloatTensor([1.0])
length_scale = torch.FloatTensor([1.0])
noise_scale_w = torch.FloatTensor([1.0])

dummy_input = (sequences, sequence_lengths, noise_scale, length_scale, noise_scale_w, speaker_id)

In [13]:
# export to ONNX
torch.onnx.export(
    model=net_g,
    args=dummy_input,
    opset_version=15,
    f="./vits-en-AU-Dean2Zak.onnx",
    # verbose=verbose,
    input_names=["input", "input_lengths", "noise_scale", "length_scale", "noise_scale_w", "speaker_id"],
    output_names=["audio", "attention"],
    dynamic_axes={
        "input": {0: "batch_size", 1: "phonemes"},
        "input_lengths": {0: "batch_size"},
        "audio": {0: "batch_size", 1: "time1", 2: "time2"},
        "attention": {0: "batch_size", 1: "time1", 2: "frames", 3: "phonemes"}
    },
)

  assert t_s == t_t, "Relative attention is only available for self-attention."
  pad_length = max(length - (self.window_size + 1), 0)
  slice_start_position = max((self.window_size + 1) - length, 0)
  if pad_length > 0:
  if torch.min(inputs) < left or torch.max(inputs) > right:
  if min_bin_width * num_bins > 1.0:
  if min_bin_height * num_bins > 1.0:
  assert (discriminant >= 0).all()


  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


verbose: False, log level: Level.ERROR



In [14]:
# Confirm the exported file exists and show its size on disk.
!ls -lh vits-en-AU-Dean2Zak.onnx

-rw-r--r-- 1 root root 109M Mar 18 03:39 vits-en-AU-Dean2Zak.onnx
