In [3]:
# import sys
# import os
# sys.path.append(os.getcwd())
# print(os.getcwd())

import sys
import os
sys.path.append(os.path.expanduser('~/mms-onnx/MMS'))
sys.path.append(os.path.expanduser('~/mms-onnx/MMS/vits'))
# print("sys.path: \n", sys.path)
for i in sys.path:
    print(i)

/home/ubuntu/miniconda3/envs/onnx-conda-py36/lib/python38.zip
/home/ubuntu/miniconda3/envs/onnx-conda-py36/lib/python3.8
/home/ubuntu/miniconda3/envs/onnx-conda-py36/lib/python3.8/lib-dynload

/home/ubuntu/miniconda3/envs/onnx-conda-py36/lib/python3.8/site-packages
/home/ubuntu/mms-onnx/MMS
/home/ubuntu/mms-onnx/MMS
/home/ubuntu/mms-onnx/MMS/vits


In [5]:
import collections
import os
from typing import Any, Dict

import onnx
import torch
from vits import commons, utils
from vits.models import SynthesizerTrn

class OnnxModel(torch.nn.Module):
    def __init__(self, model: SynthesizerTrn):
        super().__init__()
        self.model = model

    def forward(
        self,
        x,
        x_lengths,
        noise_scale=0.667,
        length_scale=1.0,
        noise_scale_w=0.8,
    ):
        return self.model.infer(
            x=x,
            x_lengths=x_lengths,
            noise_scale=noise_scale,
            length_scale=length_scale,
            noise_scale_w=noise_scale_w,
        )[0]

In [6]:
def add_meta_data(filename: str, meta_data: Dict[str, Any]):
    """Add meta data to an ONNX model. It is changed in-place.

    Args:
      filename:
        Filename of the ONNX model to be changed.
      meta_data:
        Key-value pairs.
    """
    model = onnx.load(filename)
    for key, value in meta_data.items():
        meta = model.metadata_props.add()
        meta.key = key
        meta.value = str(value)

    onnx.save(model, filename)

In [25]:
def load_vocab():
    return [
        x.replace("\n", "") for x in open("vocab.txt", encoding="utf-8").readlines()
    ]


@torch.no_grad()
def main():
    hps = utils.get_hparams_from_file("config.json")
    is_uroman = hps.data.training_files.split(".")[-1] == "uroman"
    if is_uroman:
        raise ValueError("We don't support uroman!")

    symbols = load_vocab()

    # Now generate tokens.txt
    all_upper_tokens = [i.upper() for i in symbols]
    duplicate = set(
        [
            item
            for item, count in collections.Counter(all_upper_tokens).items()
            if count > 1
        ]
    )

    print("generate tokens.txt")

    with open("tokens.txt", "w", encoding="utf-8") as f:
        for idx, token in enumerate(symbols):
            f.write(f"{token} {idx}\n")

            # both upper case and lower case correspond to the same ID
            if (
                token.lower() != token.upper()
                and len(token.upper()) == 1
                and token.upper() not in duplicate
            ):
                f.write(f"{token.upper()} {idx}\n")

    net_g = SynthesizerTrn(
        len(symbols),
        hps.data.filter_length // 2 + 1,
        hps.train.segment_size // hps.data.hop_length,
        **hps.model,
    )
    net_g.cpu()
    _ = net_g.eval()

    _ = utils.load_checkpoint("G_100000.pth", net_g, None)

    model = OnnxModel(net_g)

    # x = torch.randint(low=1, high=10, size=(50,), dtype=torch.int64)
    # x = torch.randint(low=1, high=10, size=(2,), dtype=torch.int64)
    x = torch.randint(low=1, high=10, size=(10,), dtype=torch.int64)
    # x = torch.randint(low=1, high=10, size=(5,), dtype=torch.int64)
    x = x.unsqueeze(0)

    x_length = torch.tensor([x.shape[1]], dtype=torch.int64)
    
    print("x: \n", x)
    print("x: ", x.shape)
    print("x_length shape: ", x_length.shape)
    print("x_length: ", x_length)
    
    
    noise_scale = torch.tensor([1], dtype=torch.float32)
    length_scale = torch.tensor([1], dtype=torch.float32)
    noise_scale_w = torch.tensor([1], dtype=torch.float32)

    opset_version = 13

    filename = "model.onnx"

    torch.onnx.export(
        model,
        (x, x_length, noise_scale, length_scale, noise_scale_w),
        filename,
        opset_version=opset_version,
        input_names=[
            "x",
            "x_length",
            "noise_scale",
            "length_scale",
            "noise_scale_w",
        ],
        output_names=["y"],
        dynamic_axes={
            "x": {0: "N", 1: "L"},  # n_audio is also known as batch_size
            "x_length": {0: "N"},
            "y": {0: "N", 2: "L"},
        },
    )
    meta_data = {
        "model_type": "vits",
        "comment": "mms",
        "url": "https://huggingface.co/facebook/mms-tts/tree/main",
        "add_blank": int(hps.data.add_blank),
        "language": os.environ.get("language", "unknown"),
        "frontend": "characters",
        "n_speakers": int(hps.data.n_speakers),
        "sample_rate": hps.data.sampling_rate,
    }
    print("meta_data", meta_data)
    add_meta_data(filename=filename, meta_data=meta_data)


In [26]:
main()

generate tokens.txt


INFO:root:Loaded checkpoint 'G_100000.pth' (iteration 6251)
x: 
 tensor([[8, 2, 4, 9, 7, 8, 3, 8, 6, 4]])
x:  torch.Size([1, 10])
x_length shape:  torch.Size([1])
x_length:  tensor([10])


  assert t_s == t_t, "Relative attention is only available for self-attention."
  pad_length = max(length - (self.window_size + 1), 0)
  slice_start_position = max((self.window_size + 1) - length, 0)
  if pad_length > 0:
  if torch.min(inputs) < left or torch.max(inputs) > right:
  if min_bin_width * num_bins > 1.0:
  if min_bin_height * num_bins > 1.0:
  assert (discriminant >= 0).all()
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_node_shape_type_inference(node, params_dict, opset_version)
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(
  _C._jit_pass_onnx_graph_shape_type_inference(


meta_data {'model_type': 'vits', 'comment': 'mms', 'url': 'https://huggingface.co/facebook/mms-tts/tree/main', 'add_blank': 1, 'language': 'unknown', 'frontend': 'characters', 'n_speakers': 0, 'sample_rate': 16000}
