□ [nn-morse](https://github.com/pd0wm/nn-morse)で学習したモデルをonnx形式で保存、onnxruntimeで推論してみる

■ GoogleDriveに接続

In [None]:
# Mount Google Drive inside the Colab runtime at /content/drive/ so the
# files copied in the next cell are reachable.
from google.colab import drive
drive.mount('/content/drive/')

Mounted at /content/drive/


■ 環境構築

In [None]:
# Copy the project sources from Drive under the module names imported later
# (main8.py -> main.py, my_morse.py -> morse.py), plus the 002200.pt
# checkpoint that is loaded into the model below.
# NOTE(review): the relative `drive/...` paths presumably rely on the working
# directory being /content (the parent of the mount point) - confirm.
!cp drive/MyDrive/MORSE/新models-8/main8.py main.py
!cp drive/MyDrive/MORSE/新models-8/my_morse.py morse.py
!cp "drive/MyDrive/MORSE/新models-8/002200.pt" 002200.pt

# WAV recording that is decoded in the following cells.
!cp drive/MyDrive/MORSE/hsu-sat1.wav .

■ pytorchで作ったモデルをonnx形式で保存

1. pytorchモデルの読込

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import scipy.io.wavfile
import torch
from scipy import signal

from main import Net, num_tags, prediction_to_str
from morse import ALPHABET, SAMPLE_FREQ, get_spectrogram

# Read the recording: `rate` is its sample rate, `data` the samples.
# NOTE(review): the steps below presumably expect a mono file - confirm.
rate, data = scipy.io.wavfile.read("hsu-sat1.wav")

# Resample to SAMPLE_FREQ (from morse.py) and rescale to a peak of 1.
length = len(data) / rate  # duration in seconds
new_length = int(length * SAMPLE_FREQ)

data = signal.resample(data, new_length)
data = data.astype(np.float32)
# Skip the normalisation for an all-zero (silent) recording: dividing by a
# zero peak would turn every sample into NaN.
peak = np.max(np.abs(data))
if peak > 0:
    data /= peak

# Create spectrogram
spec = get_spectrogram(data)
spec_orig = spec.copy()  # NumPy copy taken before the tensor conversion
spectrogram_size = spec.shape[0]  # axis-0 length, handed to Net below

# Convert to a tensor, swap the two axes and add a leading axis of size 1,
# producing the 3-D tensor that is later used as the model input.
spec = torch.from_numpy(spec)
spec = spec.permute(1, 0)
spec = spec.unsqueeze(0)

# Load model
device = torch.device("cpu")
model = Net(num_tags, spectrogram_size)
# NOTE(review): torch.load unpickles the checkpoint; if the file can come
# from an untrusted source, consider passing weights_only=True.
model.load_state_dict(torch.load("002200.pt", map_location=device))
# Switch to evaluation mode. Kept as the final expression so the notebook
# displays the module summary as the cell output.
model.eval()


Net(
  (dense1): Linear(in_features=41, out_features=256, bias=True)
  (dense2): Linear(in_features=256, out_features=256, bias=True)
  (dense3): Linear(in_features=256, out_features=256, bias=True)
  (dense4): Linear(in_features=256, out_features=256, bias=True)
  (lstm1): LSTM(256, 256, batch_first=True)
  (dense5): Linear(in_features=256, out_features=44, bias=True)
)

2. ONNX形式で保存 : バージョン指定 & 入力変数のサイズは可変にしておく

In [None]:
# Install the `onnx` package, imported further down to load, check and print
# the exported model.
!pip install onnx

Collecting onnx
  Downloading onnx-1.15.0-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (15.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m15.7/15.7 MB[0m [31m44.4 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: onnx
Successfully installed onnx-1.15.0


In [None]:
import torch.onnx

# Export the PyTorch model to "nn-morse.onnx", using the real spectrogram
# tensor `spec` as the example input.
torch.onnx.export(
    model,
    spec,
    "nn-morse.onnx",
    export_params=True,  # store the trained weights in the file
    opset_version=10,  # pin the ONNX operator-set version
    input_names=["input"],
    output_names=["output"],
    # Batch size and sequence length may vary at inference time, and the
    # output's sequence axis follows the input's, so both tensors declare
    # axes 0 and 1 as dynamic. The last input axis is deliberately left
    # static: the first layer's weight matrix fixes the feature size.
    dynamic_axes={
        "input": {0: "batch_size", 1: "spec_len"},
        "output": {0: "batch_size", 1: "spec_len"},
    },
    verbose=True,
)



■ onnxruntimeから推論 part1
- 推論結果はTensorではなくnumpyのarrayで返ってくるので、型変換する
- 入力はTensor型のまま（推論時に `.numpy()` で変換して渡す）

1. onnx形式のモデルを読み込む

In [None]:
# Install `onnxruntime`, which is used below to run the exported model.
!pip install onnxruntime

Collecting onnxruntime
  Downloading onnxruntime-1.17.1-cp310-cp310-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m21.0 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting coloredlogs (from onnxruntime)
  Downloading coloredlogs-15.0.1-py2.py3-none-any.whl (46 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m46.0/46.0 kB[0m [31m2.5 MB/s[0m eta [36m0:00:00[0m
Collecting humanfriendly>=9.1 (from coloredlogs->onnxruntime)
  Downloading humanfriendly-10.0-py2.py3-none-any.whl (86 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m86.8/86.8 kB[0m [31m8.8 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: humanfriendly, coloredlogs, onnxruntime
Successfully installed coloredlogs-15.0.1 humanfriendly-10.0 onnxruntime-1.17.1


In [None]:
import onnx
import onnxruntime

# Load the exported file and run ONNX's model checker on it.
# NOTE: this rebinds `model`, which referred to the PyTorch Net until this
# cell; from here on `model` is the loaded ONNX model.
model = onnx.load("nn-morse.onnx")
onnx.checker.check_model(model)

2. モデルを表示（したが、内容はよく分からない）

In [None]:
# Print a text dump of the ONNX graph (inputs, initializers and nodes).
print(onnx.helper.printable_graph(model.graph))

graph main_graph (
  %input[FLOAT, batch_sizexspec_lenxspec_size]
) initializers (
  %dense1.bias[FLOAT, 256]
  %dense2.bias[FLOAT, 256]
  %dense3.bias[FLOAT, 256]
  %dense4.bias[FLOAT, 256]
  %dense5.bias[FLOAT, 44]
  %onnx::MatMul_108[FLOAT, 41x256]
  %onnx::MatMul_109[FLOAT, 256x256]
  %onnx::MatMul_110[FLOAT, 256x256]
  %onnx::MatMul_111[FLOAT, 256x256]
  %onnx::LSTM_131[FLOAT, 1x1024x256]
  %onnx::LSTM_132[FLOAT, 1x1024x256]
  %onnx::LSTM_133[FLOAT, 1x2048]
  %onnx::MatMul_134[FLOAT, 256x44]
) {
  %/dense1/MatMul_output_0 = MatMul(%input, %onnx::MatMul_108)
  %/dense1/Add_output_0 = Add(%dense1.bias, %/dense1/MatMul_output_0)
  %/Relu_output_0 = Relu(%/dense1/Add_output_0)
  %/dense2/MatMul_output_0 = MatMul(%/Relu_output_0, %onnx::MatMul_109)
  %/dense2/Add_output_0 = Add(%dense2.bias, %/dense2/MatMul_output_0)
  %/Relu_1_output_0 = Relu(%/dense2/Add_output_0)
  %/dense3/MatMul_output_0 = MatMul(%/Relu_1_output_0, %onnx::MatMul_110)
  %/dense3/Add_output_0 = Add(%dense3.bias, %/d

3. WAVファイルの読み込み

In [None]:
import matplotlib.pyplot as plt
import numpy as np
import scipy.io.wavfile
import torch  # needed for torch.from_numpy below; keeps this cell self-contained
from scipy import signal

from main import Net, num_tags, prediction_to_str
from morse import ALPHABET, SAMPLE_FREQ, get_spectrogram

# Read the recording: `rate` is its sample rate, `data` the samples.
# NOTE(review): the steps below presumably expect a mono file - confirm.
rate, data = scipy.io.wavfile.read("hsu-sat1.wav")

# Resample to SAMPLE_FREQ (from morse.py) and rescale to a peak of 1.
length = len(data) / rate  # duration in seconds
new_length = int(length * SAMPLE_FREQ)

data = signal.resample(data, new_length)
data = data.astype(np.float32)
# Skip the normalisation for an all-zero (silent) recording: dividing by a
# zero peak would turn every sample into NaN.
peak = np.max(np.abs(data))
if peak > 0:
    data /= peak

# Create spectrogram
spec = get_spectrogram(data)
spec_orig = spec.copy()  # NumPy copy taken before the tensor conversion
spectrogram_size = spec.shape[0]  # axis-0 length of the spectrogram

# Convert to a tensor, swap the two axes and add a leading axis of size 1.
# The input stays a torch.Tensor here and is converted to NumPy only when it
# is handed to ONNX Runtime.
spec = torch.from_numpy(spec)
spec = spec.permute(1, 0)
spec = spec.unsqueeze(0)

4. 推論

In [None]:
import onnxruntime as ort

# Name the execution provider explicitly. ONNX Runtime builds that ship more
# than the CPU provider refuse to create a session without `providers`
# (ORT >= 1.9), and the PyTorch reference run in this notebook is on CPU too.
sess = ort.InferenceSession("nn-morse.onnx", providers=["CPUExecutionProvider"])

# ONNX Runtime consumes NumPy arrays, so the tensor is converted on the way
# in. "input" / "output" are the tensor names chosen at export time.
# `output` is a list holding one array per requested output name.
output = sess.run(["output"], {"input": spec.numpy()})


In [None]:
print("◇output")
print(output)

# First (and only) requested output, first batch element -> 2-D array.
# asarray avoids copying data that is only read from here on.
output1 = np.asarray(output[0][0])
print("◆y_pred⇐output1")
print(output1)

# Exponentiate the scores (presumably log-probabilities - TODO confirm).
# Widening to float64 directly gives the same values as the former
# list round-trip without building nested Python lists.
y_pred_l = np.exp(output1.astype(np.float64))
print("◆y_pred_l")
print(y_pred_l)


◇output
[array([[[ 0.0000000e+00, -2.1061573e+01, -2.0310665e+01, ...,
         -2.4417107e+01, -2.4160883e+01, -2.4291817e+01],
        [ 0.0000000e+00, -3.8936195e+01, -3.3078651e+01, ...,
         -4.3582199e+01, -4.2318993e+01, -3.7926727e+01],
        [ 0.0000000e+00, -3.8175545e+01, -3.1148098e+01, ...,
         -3.9882446e+01, -3.8202049e+01, -3.4655579e+01],
        ...,
        [-1.0794534e-03, -1.0878096e+01, -1.1868061e+01, ...,
         -1.2584034e+01, -1.1283299e+01, -1.3148109e+01],
        [-9.0045907e-04, -1.2186127e+01, -1.1783067e+01, ...,
         -1.2425255e+01, -1.2001100e+01, -1.2931832e+01],
        [-6.7107310e-03, -9.6853209e+00, -9.5085764e+00, ...,
         -1.0000663e+01, -9.1748743e+00, -1.0216950e+01]]], dtype=float32)]
◆y_pred⇐output1
[[ 0.0000000e+00 -2.1061573e+01 -2.0310665e+01 ... -2.4417107e+01
  -2.4160883e+01 -2.4291817e+01]
 [ 0.0000000e+00 -3.8936195e+01 -3.3078651e+01 ... -4.3582199e+01
  -4.2318993e+01 -3.7926727e+01]
 [ 0.0000000e+00 -3.817554

5. 文字列に変換

In [None]:
# Greedy decoding: for every row of `output1` take the index of the largest
# score, then let prediction_to_str (from main.py) turn the index sequence
# into text.
# TODO: proper beam search
m = output1.argmax(axis=1)
print("◆m")
print(m)

decoded = prediction_to_str(m)
print("変換")
print(decoded)


◆m
[0 0 0 ... 0 0 0]
変換
H4.18V0.08A8.14DEEEEETTTTEE  0JS1YHSHSUSAT104.19V0.10A8.46DEEEEETETTTE 9I   U  J4 V1 O: EDI       ISW               1:   I   


6. pytorchの結果と比較。主要部分（`JS1YHSHSUSAT104.19V0.10A8.46DEEEEETETTTE`）は同じ結果になった！ ただし、その前後の部分は完全には一致していない。

```
E8V.08A8.14D   ETTE                     OJS1YHSHSUSAT104.19V0.10A8.46DEEEEETETTTE                                             IA            M  
```