# 1. 加载模型，导出ONNX

In [2]:
import yaml
import torch.onnx
# project 
import sys 
sys.path.append("..") 
import archs

model_name = 'wrist'

# Load the training config for this model.
# NOTE(review): yaml.FullLoader resolves more tags than yaml.safe_load; only
# use it on trusted, locally produced config files.
with open('../models/%s/config.yml' % model_name, 'r') as f:
    config = yaml.load(f, Loader=yaml.FullLoader)

print("=> creating model %s" % config['arch'])
model = archs.__dict__[config['arch']](config['num_classes'],
                                       config['input_channels'],
                                       config['deep_supervision'])

# map_location='cpu' lets a checkpoint saved from a GPU run load on a
# CPU-only machine; everything below (export + reference output) runs on CPU.
model.load_state_dict(torch.load('../models/%s/model.pth' % config['name'],
                                 map_location='cpu'))
model.eval()

# Dummy input of shape (N, C, H, W). The values don't matter for tracing the
# graph, but the shape does: taking C/H/W from the config keeps the exported
# graph in sync with the Resize(config['input_h'], config['input_w'])
# preprocessing used at inference time.
dummy_input = torch.randn(1, config['input_channels'], config['input_h'],
                          config['input_w'], requires_grad=True)

# Reference PyTorch output for the ONNX Runtime comparison; no autograd graph
# is needed for it.
with torch.no_grad():
    torch_out = model(dummy_input)

# Trace the model and write the ONNX graph next to the notebook.
torch.onnx.export(model, dummy_input, model_name + ".onnx")

=> creating model UNext


# 2. 检验 ONNX 模型

In [4]:
import onnx
# Validate the exported graph. The file name is derived from `model_name` so
# this always checks the file the export cell just wrote.
onnx_path = model_name + ".onnx"
try:
    # check_model raises ValidationError when the model is malformed.
    onnx.checker.check_model(onnx_path)
except onnx.checker.ValidationError as e:
    print("The model is invalid: %s" % e)
else:
    # No exception means the graph passed the checker.
    print("The model is valid!")

The model is valid!


# 3. 使用 ONNX Runtime 进行推理

In [6]:
import onnxruntime
import numpy as np

# Load the exported graph into ONNX Runtime. Providers are passed explicitly:
# GPU-enabled ORT builds (>= 1.9) raise if `providers` is omitted, and running
# on CPU keeps the comparison with the CPU PyTorch reference output like-for-like.
ort_session = onnxruntime.InferenceSession(model_name + ".onnx",
                                           providers=["CPUExecutionProvider"])

def to_numpy(tensor):
    """Convert a torch tensor into a NumPy ndarray living on the CPU.

    ``Tensor.numpy()`` refuses tensors that track gradients, so such tensors
    are detached from the autograd graph first.
    """
    if tensor.requires_grad:
        tensor = tensor.detach()
    return tensor.cpu().numpy()

# Feed the same random input through ONNX Runtime. The graph has a single
# input, addressed by name; passing None as the output list returns all outputs.
input_name = ort_session.get_inputs()[0].name
ort_inputs = {input_name: to_numpy(dummy_input)}
ort_outs = ort_session.run(None, ort_inputs)

# The ONNX graph must reproduce the PyTorch output up to float tolerance.
expected = to_numpy(torch_out)
np.testing.assert_allclose(expected, ort_outs[0], rtol=1e-03, atol=1e-05)

print("Exported model has been tested with ONNXRuntime, and the result looks good!")

Exported model has been tested with ONNXRuntime, and the result looks good!


# 4. 进行实际预测并可视化

In [19]:
import cv2
from PIL import Image
import albumentations as A
from albumentations.core.composition import Compose

# Read the test image (OpenCV loads channels in BGR order). cv2.imread returns
# None instead of raising when the file is missing or unreadable, so fail
# loudly here rather than with an obscure AttributeError on `.shape` below.
img_path = '000002.png'
img = cv2.imread(img_path)
if img is None:
    raise FileNotFoundError("Could not read image: %s" % img_path)

# Resize to the network's input resolution and normalize.
val_transform = Compose([
    A.Resize(config['input_h'], config['input_w']),
    A.Normalize(),
])

# Original image shape (H, W, C).
print(img.shape)
img_h, img_w, _ = img.shape

# Preprocess into a (1, C, H, W) float32 batch.
img = val_transform(image=img)['image']
# NOTE(review): A.Normalize() already divides by max_pixel_value (255 by
# default), so this extra /255 looks like double scaling. It is kept on the
# assumption that it mirrors the preprocessing the model was trained with —
# confirm against the training pipeline before changing it.
img = img.astype('float32') / 255
img = img.transpose(2, 0, 1)
img = torch.from_numpy(img).unsqueeze(0)

# Build the input dict (the value is converted to an ndarray) and run inference.
ort_inputs = {ort_session.get_inputs()[0].name: to_numpy(img)}
ort_outs = ort_session.run(None, ort_inputs)
print(ort_outs[0].shape)

# Turn channel 0 of batch item 0 into an 8-bit grayscale mask.
# Assumes the network returns raw logits (UNext ends in a plain conv layer) —
# verify against archs.py if the architecture changes. The sigmoid maps them
# to probabilities in [0, 1] before scaling to [0, 255]; torch.sigmoid is used
# because it is numerically stable for large-magnitude logits.
img_out = ort_outs[0]
probs = torch.sigmoid(torch.from_numpy(img_out[0, 0])).numpy()
# A 2-D uint8 array is always mapped to PIL mode 'L'; the explicit `mode`
# argument of fromarray is deprecated in recent Pillow releases.
img_out = Image.fromarray(np.uint8((probs * 255.0).clip(0, 255)))
# TODO(review): img_h/img_w are recorded above but unused; the mask is saved
# at network resolution, not at the original image size.
img_out.save('mask.jpg')

(720, 1280, 3)
(1, 1, 512, 512)
