<a href="https://colab.research.google.com/github/forMwish/MyDeepLearn/blob/master/test_lstm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# 1. 环境配置



In [None]:
# Mount Google Drive and switch the working directory to MyDrive/test_lstm.
from google.colab import drive
import os

gdrive_path = '/gdrive'
drive.mount(gdrive_path, force_remount=True)

os.chdir("%s/MyDrive"%gdrive_path)
try:
    os.mkdir("./test_lstm")
except FileExistsError:
    # The directory is left over from an earlier run: reuse it, but delete the
    # files it contains first. Only "already exists" is handled here; any other
    # failure (permissions, missing Drive mount, ...) now propagates instead of
    # silently falling through to the `rm`.
    os.chdir("./test_lstm")
    os.system("rm ./*")
else:
    os.chdir("./test_lstm")

# Work around matplotlib-related problems by installing version 3.1.3.
# NOTE(review): `pip uninstall` is called without -y, so it presumably cannot get
# its confirmation through os.system; the pinned install below is what actually
# switches the version -- confirm before relying on the uninstall step.
os.system("pip uninstall matplotlib")
os.system("pip install matplotlib==3.1.3")

os.system("pip install onnx")
os.system("pip install onnxruntime")

Mounted at /gdrive


0

In [None]:
# 其它配置
import torch
import numpy as np
import matplotlib.pyplot as plt
import pandas as pd
import random

# Notebook setting for tab completion: turn off jedi in the IPython completer
%config Completer.use_jedi = False

# Prefer the GPU when one is available, otherwise fall back to the CPU
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
print("use device:", device)

# pyplot style: the default style is active; the dark theme is the
# commented-out alternative on the next line
plt.style.use("default")
# plt.style.use("dark_background")

# Fix the PyTorch random seed (CPU, and every CUDA device when present)
torch.manual_seed(0)
if torch.cuda.is_available():
    torch.cuda.manual_seed_all(0)

# Fix the numpy random seed
np.random.seed(0)

# Fix the Python `random` module seed
random.seed(0)

use device: cpu


  if __name__ == '__main__':


# 2. pytorch 部分
从 ocr_0_for_test.onnx 模型中提取相关参数，反量化后加载到 pytorch lstm 中 <br>
从 rnn_in.npy 加载输入 x <br>
从 fc_in.npy 加载对比的输出 y <br>



In [None]:
# pytorch 构建测试demo，并导出 onnx
from torch import nn

class TestLSTM(nn.Module):
    """Minimal test model: a single ``nn.LSTM`` with bias and batch-first inputs.

    Args:
        input_size: number of features per time step.
        hidden_size: number of hidden units per direction.
        bidirectional: whether the LSTM also runs in the reverse direction.
    """

    def __init__(self, input_size, hidden_size, bidirectional):
        super().__init__()
        lstm_options = dict(
            input_size=input_size,
            hidden_size=hidden_size,
            bias=True,
            batch_first=True,
            bidirectional=bidirectional,
        )
        self.lstm = nn.LSTM(**lstm_options)

    def forward(self, X):
        """Run the LSTM on ``X`` and return its result unchanged.

        The return value is exactly what ``nn.LSTM`` returns: the output
        sequence together with the final ``(h_n, c_n)`` state tuple.
        """
        return self.lstm(X)


In [None]:
# Load data: x is the LSTM input (rnn_in.npy) and y is the reference output
# it is compared against (fc_in.npy); both are converted to float32 tensors.
x = torch.tensor(np.load("../ocr_lstm/rnn_in.npy"), dtype=torch.float32)
y = torch.tensor(np.load("../ocr_lstm/fc_in.npy"), dtype=torch.float32)

import onnx

# Quantized ONNX model whose initializers provide the LSTM parameters.
quant_model = onnx.load("../ocr_lstm/ocr_0_for_test.onnx")
# Quick inspection: print every initializer whose name contains "bias_hh_l0".
for tensor in quant_model.graph.initializer:
    if "bias_hh_l0" in tensor.name:
        print(tensor.name)

In [None]:
# load para from quant model

def find_tensors_by_name_suffix(name, initializer):
    """Return all tensors in ``initializer`` whose ``.name`` ends with ``name``.

    Args:
        name: suffix to look for.
        initializer: iterable of objects exposing a string ``.name`` attribute.

    Returns:
        A list with every matching tensor, in iteration order (empty when
        nothing matches).

    Uses ``str.endswith`` rather than slicing with ``-len(name)``: for an empty
    suffix the slice ``[-0:]`` is the whole string, so the old comparison only
    matched tensors with an empty name; an empty suffix now matches every tensor.
    """
    return [tensor for tensor in initializer if tensor.name.endswith(name)]

def get_tensor_from_quant_model(name, quant_model):
    """Fetch a quantized parameter and its scale from ``quant_model``.

    The parameter is the initializer whose name ends with ``name``; the scale is
    the initializer whose name ends with ``f"{name}_quant_scale"``. Each lookup
    must match exactly one initializer (checked with ``assert``). Both values
    are read from the initializer's ``float_data`` and reshaped to its ``dims``.

    Returns:
        ``(tensor, scale)`` as float32 torch tensors.
    """
    def _load_unique(suffix):
        # Resolve the suffix to a single initializer and turn it into a tensor.
        matches = find_tensors_by_name_suffix(suffix, quant_model.graph.initializer)
        assert len(matches) == 1
        proto = matches[0]
        flat = torch.tensor(proto.float_data, dtype=torch.float32)
        return flat.reshape(list(proto.dims))

    tensor = _load_unique(name)
    scale = _load_unique(f"{name}_quant_scale")
    return tensor, scale

def get_state_dict_from_quant_model(keys, quant_model):
    """Build a de-quantized LSTM ``state_dict`` from ``quant_model``.

    For every name in ``keys`` the parameter and its scale are fetched and
    multiplied together:

    * weight names: the squeezed tensor is multiplied by the scale reshaped to
      a column (``(-1, 1)``), and stored under the same name.
    * ``bias_ih_l0`` / ``bias_ih_l0_reverse``: the squeezed, scaled tensor must
      be 1-D; its first half is stored under the given name and its second half
      under the corresponding ``bias_hh_l0`` name.

    Names outside these two groups are fetched but produce no entry.
    """
    weight_names = ("weight_ih_l0", "weight_hh_l0",
                    "weight_ih_l0_reverse", "weight_hh_l0_reverse")
    fused_bias_names = ("bias_ih_l0", "bias_ih_l0_reverse")

    state_dict = {}
    for name in keys:
        tensor, scale = get_tensor_from_quant_model(name, quant_model)
        if name in weight_names:
            state_dict[name] = tensor.squeeze() * scale.reshape((-1, 1))
        elif name in fused_bias_names:
            bias = tensor.squeeze() * scale
            assert len(bias.shape) == 1
            half = int(bias.shape[0] / 2)
            state_dict[name] = bias[:half]
            # The second half belongs to the hidden-to-hidden bias of the same direction.
            state_dict[name.replace("bias_ih_l0", "bias_hh_l0")] = bias[half:]
    return state_dict


# Bidirectional test LSTM: 256 input features, 128 hidden units per direction.
model = TestLSTM(input_size=256, hidden_size=128, bidirectional=True)

# keys = ["weight_ih_l0", "weight_hh_l0", "bias_ih_l0", "bias_hh_l0"]
# Only the bias_ih_* names are requested: get_state_dict_from_quant_model
# splits each of them and fills in the matching bias_hh_* entry itself.
keys = ["weight_ih_l0", "weight_hh_l0", "bias_ih_l0", 
        "weight_ih_l0_reverse", "weight_hh_l0_reverse", "bias_ih_l0_reverse"]
new_state_dict = get_state_dict_from_quant_model(keys, quant_model)

# strict=True: loading fails unless the keys match the LSTM's parameters exactly.
model.lstm.load_state_dict(new_state_dict, strict=True)


# Debug aid: print the first named parameter of the LSTM.
# for i in model.lstm.named_parameters():
#     print(i)
#     break

<All keys matched successfully>

In [None]:
# Run the PyTorch model on x in evaluation mode, without gradient tracking.
# y_hat is the LSTM output; status is the second value returned by nn.LSTM
# (the final state tuple) and is not used by the cells shown here.
model.eval()
with torch.no_grad():
    y_hat, status = model(x)

In [None]:
# Error metric: quantize the reference and the PyTorch output to int8 using the
# output scale, then compare them element-wise.
o_scale = 128

# Recipe for both tensors: scale -> saturate to [-128, 127] -> round -> int8.
y_quant = (y.squeeze_().numpy() * o_scale).clip(-128, 127).round().astype(np.int8)
y_hat_quant = (y_hat.squeeze_().numpy() * o_scale).clip(-128, 127).round().astype(np.int8)

# Subtract in int32 so the difference of two int8 values cannot wrap around.
error = np.abs(y_hat_quant.astype(np.int32) - y_quant.astype(np.int32))
assert error.ndim == 2

# Row/column of the largest error, recovered from the flat argmax.
worst_row, worst_col = divmod(int(error.argmax()), error.shape[1])
print(f"error_quant:\n\tmax:{error.max()} index:({worst_row}, {worst_col})")
print(error.sum())

np.save("error_quant", error)
np.save("y_quant", y_quant)
np.save("y_hat_quant", y_hat_quant)

error_quant:
	max:3 index:(153, 168)
4315


# 3. onnx 部分

In [None]:
# Export the PyTorch model to ONNX (opset 12), using x as the example input.
# Only one output name is supplied; the model's forward also returns the LSTM
# state tuple -- NOTE(review): those extra outputs presumably get
# exporter-generated names; confirm if they are ever consumed.
torch.onnx.export(model, x, "test_lstm_tmp.onnx", opset_version=12, 
                  input_names=['input'], output_names=['output'])

  "or define the initial states (h0/c0) as inputs of the model. ")


In [None]:
# Load the exported ONNX model, annotate it with inferred shapes, validate it
# with the ONNX checker, and write the result back to the same file.
import onnx

onnx_model = onnx.load("test_lstm_tmp.onnx")
onnx_model = onnx.shape_inference.infer_shapes(onnx_model)
onnx.checker.check_model(onnx_model)
onnx.save(onnx_model, "test_lstm_tmp.onnx")

In [None]:
# Run the exported model with onnxruntime and score it against the reference y
# with the same int8 error metric used for the PyTorch output.
import onnxruntime as ort
import numpy as np

ort_sess = ort.InferenceSession("test_lstm_tmp.onnx")
outputs = ort_sess.run(None, {'input': x.numpy()})
# First returned output, i.e. the one named 'output' at export time.
y_hat_onnx = outputs[0]

# Recipe for both arrays: scale -> saturate to [-128, 127] -> round -> int8.
y_quant = (y.squeeze_().numpy() * o_scale).clip(-128, 127).round().astype(np.int8)
y_hat_onnx_quant = (y_hat_onnx * o_scale).squeeze().clip(-128, 127).round().astype(np.int8)

# Subtract in int32 so the difference of two int8 values cannot wrap around.
error_onnx = np.abs(y_hat_onnx_quant.astype(np.int32) - y_quant.astype(np.int32))

print("error_onnx.shape:", error_onnx.shape)

# Row/column of the largest error, recovered from the flat argmax.
onnx_worst_row, onnx_worst_col = divmod(int(error_onnx.argmax()), error_onnx.shape[1])
print(f"error_onnx_quant:\n\tmax:{error_onnx.max()} index:({onnx_worst_row}, {onnx_worst_col})")
print(error_onnx.sum())

np.save("error_onnx_quant", error_onnx)
np.save("y_quant", y_quant)
np.save("y_hat_onnx_quant", y_hat_onnx_quant)


error_onnx.shape: (512, 256)
error_onnx_quant:
	max:3 index:(153, 168)
4314


# 4. 比较 pytorch 和 onnx 的参数分布
验证参数排布是否与猜想一致：pytorch 为 ifco 排布（即 PyTorch 文档中的 i、f、g、o 门顺序），而 onnx 为 iofc 排布

In [None]:
# onnx_W
# onnx_R
# onnx_B

import struct

def convert_bytes_to_list(input:bytes, format:str):
    """Unpack ``input`` into a list of records using the struct ``format``.

    Each list element is the tuple ``struct`` produces for one record, so a
    single-field format such as ``"<f"`` yields a list of 1-tuples.
    """
    return list(struct.iter_unpack(format, input))

def convert_initializer_to_tensor(name:str, initializer):
    """Find the initializer called ``name`` and return it as a ``torch.Tensor``.

    Args:
        name: exact initializer name to look for.
        initializer: iterable of ONNX tensor protos (objects exposing ``name``,
            ``dims``, ``data_type``, ``raw_data`` and ``float_data``).

    Returns:
        A float32 tensor reshaped to the initializer's ``dims``, or ``None``
        when no initializer has that name.

    Raises:
        NotImplementedError: if the matching initializer is not a float tensor.

    The name of every scanned initializer, and the shape and data type of the
    match, are printed as diagnostics.
    """
    float_data_type = 1  # TensorProto.FLOAT in the ONNX data-type enum

    for tensor in initializer:
        print(tensor.name)
        if tensor.name != name:
            continue

        shape = list(tensor.dims)
        print(shape)
        print(tensor.data_type)

        if tensor.data_type != float_data_type:
            # An explicit raise, unlike `assert False`, also fires under `python -O`.
            raise NotImplementedError(
                f"initializer {name!r} has unsupported data_type {tensor.data_type}")

        if len(tensor.raw_data) > 0:
            # raw_data holds little-endian float32 values; astype() yields a
            # native-endian, writable copy.
            values = np.frombuffer(tensor.raw_data, dtype="<f4").astype(np.float32)
        else:
            # Float tensors may instead keep their values in the typed field.
            values = np.asarray(tensor.float_data, dtype=np.float32)
        return torch.tensor(values, dtype=torch.float32).reshape(shape)

    return None

def _to_onnx_gate_order(param):
    """Reorder the four gate blocks of a PyTorch LSTM parameter for ONNX.

    ``param`` is cut along dim 0 into four blocks of ``param.shape[0] // 4``
    rows (blocks 0, 1, 2, 3) which are concatenated again as (0, 3, 1, 2).
    This maps the gate order being verified for PyTorch (i, f, c, o) onto the
    one being verified for ONNX (i, o, f, c). Works for 2-D weights and 1-D biases.
    """
    h_size = param.shape[0] // 4
    return torch.cat((param[:h_size],
                      param[3 * h_size:],
                      param[h_size:2 * h_size],
                      param[2 * h_size:3 * h_size]), dim=0)


# Compare the LSTM initializers of the exported ONNX model with the PyTorch
# parameters rearranged into the expected ONNX layout; an error of 0 confirms
# the layout hypothesis.
# NOTE(review): the "onnx::LSTM_19x" names are generated by the exporter and
# presumably change with the torch version -- confirm against the exported graph.
for tensor in onnx_model.graph.initializer:
    if tensor.name == "onnx::LSTM_194":
        # W: input-to-hidden weights, stacked as [forward, reverse].
        onnx_W = convert_initializer_to_tensor(tensor.name, onnx_model.graph.initializer)
        print(f"onnx_W:\n  shape:{list(onnx_W.shape)}")
        print(f"  onnx_W[:10]: {onnx_W.reshape(-1)[:10]}")

        torch_w = torch.stack((
            _to_onnx_gate_order(new_state_dict["weight_ih_l0"]),
            _to_onnx_gate_order(new_state_dict["weight_ih_l0_reverse"]),
        ), dim=0)
        print(f"torch_w:\n  shape:{torch_w.shape}")
        print(f"  torch_w[:10]: {torch_w.reshape(-1)[:10]}")

        error = np.abs((onnx_W - torch_w).numpy()).sum()
        print(f"onnx_W error:{error}\n")

    elif tensor.name == "onnx::LSTM_195":
        # R: hidden-to-hidden (recurrent) weights, stacked as [forward, reverse].
        onnx_R = convert_initializer_to_tensor(tensor.name, onnx_model.graph.initializer)
        print(f"onnx_R:\n  shape:{list(onnx_R.shape)}")
        print(f"  onnx_R[:10]: {onnx_R.reshape(-1)[:10]}")

        torch_r = torch.stack((
            _to_onnx_gate_order(new_state_dict["weight_hh_l0"]),
            _to_onnx_gate_order(new_state_dict["weight_hh_l0_reverse"]),
        ), dim=0)
        print(f"torch_r:\n  shape:{torch_r.shape}")
        # Fixed: this preview used to print onnx_R a second time.
        print(f"  torch_r[:10]: {torch_r.reshape(-1)[:10]}")

        error = np.abs((onnx_R - torch_r).numpy()).sum()
        print(f"onnx_R error:{error}\n")

    elif tensor.name == "onnx::LSTM_193":
        # B: per direction, the input bias followed by the recurrent bias.
        onnx_B = convert_initializer_to_tensor(tensor.name, onnx_model.graph.initializer)
        print("onnx_B:\n  shape:", list(onnx_B.shape))
        print(f"  onnx_B[:10]: {onnx_B.reshape(-1)[:10]}")

        torch_b = torch.stack((
            torch.cat((_to_onnx_gate_order(new_state_dict["bias_ih_l0"]),
                       _to_onnx_gate_order(new_state_dict["bias_hh_l0"])), dim=0),
            torch.cat((_to_onnx_gate_order(new_state_dict["bias_ih_l0_reverse"]),
                       _to_onnx_gate_order(new_state_dict["bias_hh_l0_reverse"])), dim=0),
        ), dim=0)
        print("torch_b:\n  shape:", list(torch_b.shape))
        print(f"  torch_b[:10]: {torch_b.reshape(-1)[:10]}")

        error = np.abs((onnx_B - torch_b).numpy()).sum()
        # Fixed: the bias error used to be reported under the "onnx_R" label.
        print(f"onnx_B error:{error}\n")

onnx::LSTM_193
[2, 1024]
1
onnx_B:
  shape: [2, 1024]
  onnx_B[:10]: tensor([ 0.0063,  0.0000,  0.0293,  0.0142,  0.0195, -0.0005,  0.0210,  0.0112,
         0.0137, -0.0181])
torch_b:
  shape: [2, 1024]
  torch_b[:10]: tensor([ 0.0063,  0.0000,  0.0293,  0.0142,  0.0195, -0.0005,  0.0210,  0.0112,
         0.0137, -0.0181])
onnx_R error:0.0

onnx::LSTM_193
onnx::LSTM_194
[2, 512, 256]
1
onnx_W:
  shape:[2, 512, 256]
  onnx_W[:10]: tensor([ 0.0664,  0.0273,  0.0078,  0.0000,  0.0234, -0.0039, -0.0234,  0.0039,
        -0.0586,  0.0273])
torch_w:
  shape:torch.Size([2, 512, 256])
  torch_w[:10]: tensor([ 0.0664,  0.0273,  0.0078,  0.0000,  0.0234, -0.0039, -0.0234,  0.0039,
        -0.0586,  0.0273])
onnx_W error:0.0

onnx::LSTM_193
onnx::LSTM_194
onnx::LSTM_195
[2, 512, 128]
1
onnx_R:
  shape:[2, 512, 128]
  onnx_R[:10]: tensor([ 0.0234,  0.0156, -0.0156,  0.0078,  0.0156,  0.0117, -0.0117,  0.0117,
        -0.0391,  0.0352])
torch_r:
  shape:torch.Size([2, 512, 128])
  onnx_R[:10]: te