/
export-onnx.py
executable file
·132 lines (107 loc) · 3.52 KB
/
export-onnx.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
#!/usr/bin/env python3
# Copyright 2023 Xiaomi Corp. (authors: Fangjun Kuang)
# pip install git+https://github.com/wenet-e2e/wenet.git
# pip install onnxruntime onnx pyyaml
# cp -a ~/open-source/wenet/wenet/transducer/search .
# cp -a ~/open-source/wenet/wenet/e_branchformer .
# cp -a ~/open-source/wenet/wenet/ctl_model .
import os
from typing import Dict
import onnx
import torch
import yaml
from onnxruntime.quantization import QuantType, quantize_dynamic
from wenet.utils.init_model import init_model
class Foo:
    """Empty attribute container.

    ``main`` creates an instance, sets ``checkpoint`` on it, and passes it
    as the ``args`` argument of ``init_model``.
    """
def add_meta_data(filename: str, meta_data: Dict[str, str]):
    """Attach key-value meta data to an ONNX model file.

    The file is loaded, the entries are appended to its ``metadata_props``
    and the result is written back to the same path.

    Args:
      filename:
        Path of the ONNX model to update. The file is overwritten.
      meta_data:
        Entries to append. Every value is converted with ``str()`` before
        it is stored.
    """
    onnx_proto = onnx.load(filename)

    for name in meta_data:
        # Existing entries are kept; each pair is appended as a new property.
        onnx_proto.metadata_props.add(key=name, value=str(meta_data[name]))

    onnx.save(onnx_proto, filename)
class OnnxModel(torch.nn.Module):
    """Wrap an encoder and a CTC head so that they run as a single module."""

    def __init__(self, encoder: torch.nn.Module, ctc: torch.nn.Module):
        """
        Args:
          encoder:
            Module called as ``encoder(x, x_lens, decoding_chunk_size=...,
            num_decoding_left_chunks=...)``; it must return a pair
            ``(encoder_out, encoder_out_mask)``.
          ctc:
            Module providing a ``log_softmax`` method that is applied to the
            encoder output.
        """
        super().__init__()
        self.encoder = encoder
        self.ctc = ctc

    def forward(self, x, x_lens):
        """Run the encoder and turn its output into CTC log-probabilities.

        Args:
          x:
            A 3-D tensor of shape (N, T, C)
          x_lens:
            A 1-D tensor of shape (N,) containing valid lengths in x before
            padding. Its type is torch.int64
        Returns:
          A tuple of two tensors:
            - log_probs, the result of ``ctc.log_softmax`` on the encoder
              output
            - log_probs_lens, the per-utterance sum of the encoder output
              mask
        """
        # Both chunk arguments are fixed to -1.
        # NOTE(review): presumably this selects full-utterance
        # (non-streaming) processing -- confirm against the encoder.
        hidden, hidden_mask = self.encoder(
            x,
            x_lens,
            decoding_chunk_size=-1,
            num_decoding_left_chunks=-1,
        )

        log_probs = self.ctc.log_softmax(hidden)

        # Convert the mask to integers, drop dimension 1 and add up what is
        # left along the new dimension 1 to get one length per utterance.
        # NOTE(review): assumes the mask has shape (N, 1, T') -- confirm.
        mask_as_int = hidden_mask.int()
        mask_2d = mask_as_int.squeeze(1)
        log_probs_lens = mask_2d.sum(1)

        return log_probs, log_probs_lens
@torch.no_grad()
def main():
    """Export a WeNet CTC model to ONNX and create an int8 copy of it.

    Inputs, all relative to the current directory:

      - ``./train.yaml``: the model configuration
      - ``./final.pt``: the checkpoint passed to ``init_model``

    Outputs, written to the current directory:

      - ``model.onnx``: the exported model with meta data attached
      - ``model.int8.onnx``: a dynamically quantized copy in which only
        ``MatMul`` ops are quantized

    The environment variable ``WENET_URL``, if set, is stored in the meta
    data under the key ``url``; otherwise an empty string is stored.
    """
    args = Foo()
    args.checkpoint = "./final.pt"
    config_file = "./train.yaml"

    # Read the config as UTF-8 explicitly so that parsing does not depend on
    # the platform's default locale encoding.
    with open(config_file, "r", encoding="utf-8") as fin:
        # NOTE(review): FullLoader accepts more than plain YAML data types.
        # train.yaml is assumed to be a trusted local file; if it can come
        # from elsewhere, prefer yaml.safe_load.
        configs = yaml.load(fin, Loader=yaml.FullLoader)

    torch_model, configs = init_model(args, configs)
    torch_model.eval()

    onnx_model = OnnxModel(encoder=torch_model.encoder, ctc=torch_model.ctc)
    filename = "model.onnx"

    # Dummy input used only to drive the export; N and T are declared as
    # dynamic axes below.
    N = 1
    T = 1000
    # Take the feature dimension from the config when it is recorded there,
    # otherwise keep the previous fixed value of 80.
    # NOTE(review): presumably "input_dim" is the key wenet uses for the
    # feature dimension -- confirm against the train.yaml in use.
    C = int(configs.get("input_dim", 80))

    x = torch.rand(N, T, C, dtype=torch.float)
    x_lens = torch.full((N,), fill_value=T, dtype=torch.int64)

    opset_version = 13
    onnx_model = torch.jit.script(onnx_model)
    torch.onnx.export(
        onnx_model,
        (x, x_lens),
        filename,
        opset_version=opset_version,
        input_names=["x", "x_lens"],
        output_names=["log_probs", "log_probs_lens"],
        dynamic_axes={
            "x": {0: "N", 1: "T"},
            "x_lens": {0: "N"},
            "log_probs": {0: "N", 1: "T"},
            "log_probs_lens": {0: "N"},
        },
    )

    # Example value for WENET_URL:
    # https://wenet.org.cn/downloads?models=wenet&version=aishell_u2pp_conformer_exp.tar.gz
    url = os.environ.get("WENET_URL", "")

    # Non-string values are converted with str() inside add_meta_data.
    meta_data = {
        "model_type": "wenet_ctc",
        "version": "1",
        "model_author": "wenet",
        "comment": "non-streaming",
        "subsampling_factor": torch_model.encoder.embed.subsampling_rate,
        "vocab_size": torch_model.ctc.ctc_lo.weight.shape[0],
        "url": url,
    }
    add_meta_data(filename=filename, meta_data=meta_data)

    print("Generate int8 quantization models")

    filename_int8 = "model.int8.onnx"
    quantize_dynamic(
        model_input=filename,
        model_output=filename_int8,
        op_types_to_quantize=["MatMul"],
        weight_type=QuantType.QInt8,
    )
# Run the export only when this file is executed as a script, not when it is
# imported as a module.
if __name__ == "__main__":
    main()