In [1]:
import os
os.chdir('../')  # 更改notebook的工作路径到上一级目录

In [2]:
import cv2
import numpy as np

import torch
import torch.nn.functional as F
from torchvision import transforms

from model.fcos import FCOSDetector
from model.config import DefaultConfig

## 模型转换之Pytorch --> ONNX

In [3]:
# 首先建立一个新的FCOSDetectorDeploy，目的是在forward里增加一个deploy模式
class FCOSDetectorDeploy(FCOSDetector):
    def __init__(self,mode="training",config=None):
        super().__init__(mode=mode, config=config)

    def forward(self,inputs):
        '''
        inputs
        [training] list  batch_imgs,batch_boxes,batch_classes
        [inference] img
        '''

        if self.mode=="training":
            batch_imgs,batch_boxes,batch_classes=inputs
            out=self.fcos_body(batch_imgs)
            targets=self.target_layer([out,batch_boxes,batch_classes])
            losses=self.loss_layer([out,targets])
            return losses
        elif self.mode=="inference":
            # raise NotImplementedError("no implement inference model")
            '''
            for inference mode, img should preprocessed before feeding in net
            '''
            batch_imgs=inputs
            out=self.fcos_body(batch_imgs)
            scores,classes,boxes=self.detection_head(out)
            boxes=self.clip_boxes(batch_imgs,boxes)
            return scores,classes,boxes

        # 相对于原来的FCOSDetector，增加了deploy模式，相比之前的inference，
        # deploy模式的输出是三个concat后的tensor。同样，为了将计算尽可能保留
        # 在模型内，这里的forward提前调用的sigmoid()函数
        elif self.mode=="deploy":
            
            out = self.fcos_body(inputs)
            # Step 1. Concat Cls Output, cls_out shape [1,20,#anchor points]
            cls_out = F.sigmoid(torch.cat([out[0][i].view(1,20,-1) for i in range(len(out[0]))], -1))
            # Step 2. Concat Cnt Output, cnt_out shape [1,1,#anchor points]
            cnt_out = F.sigmoid(torch.cat([out[1][i].view(1,1,-1) for i in range(len(out[1]))], -1))
            # Step 3. Concat Reg Output, reg_out shape [1,4,#anchor points]
            reg_out = torch.cat([out[2][i].view(1,4,-1) for i in range(len(out[2]))], -1)
            
            return cls_out,cnt_out,reg_out

In [25]:
# 以deploy模式建立一个FCOS检测器
model = FCOSDetectorDeploy(mode="deploy", config=DefaultConfig)
model = torch.nn.DataParallel(model)

# 加载训练好的模型，并开启eval()模式
model.load_state_dict(torch.load("./training_dir_fp16_nogn/model_24.pth",map_location=torch.device('cpu')))
model = model.eval()

# 建立dummy_input，固定推理阶段输入图片的长和宽
inference_h = 320
inference_w = 320
dummy_input = torch.randn(1, 3, inference_h, inference_w)

# 确定模型的输入名称和输出节点名称
input_names = ['img']
output_names = ['cls', 'cnt', 'reg']

# 转换模型
name_converted = 'fcos.onnx'
torch.onnx.export(model.module, dummy_input, name_converted, verbose=True, input_names=input_names, output_names=output_names)

INFO: using darnet19 backbone




graph(%img : Float(1, 3, 320, 320),
      %fcos_body.backbone.block1.0.weight : Float(32, 3, 3, 3),
      %fcos_body.backbone.block1.1.weight : Float(32),
      %fcos_body.backbone.block1.1.bias : Float(32),
      %fcos_body.backbone.block1.1.running_mean : Float(32),
      %fcos_body.backbone.block1.1.running_var : Float(32),
      %fcos_body.backbone.block1.1.num_batches_tracked : Long(),
      %fcos_body.backbone.block1.4.weight : Float(64, 32, 3, 3),
      %fcos_body.backbone.block1.5.weight : Float(64),
      %fcos_body.backbone.block1.5.bias : Float(64),
      %fcos_body.backbone.block1.5.running_mean : Float(64),
      %fcos_body.backbone.block1.5.running_var : Float(64),
      %fcos_body.backbone.block1.5.num_batches_tracked : Long(),
      %fcos_body.backbone.block1.8.weight : Float(128, 64, 3, 3),
      %fcos_body.backbone.block1.9.weight : Float(128),
      %fcos_body.backbone.block1.9.bias : Float(128),
      %fcos_body.backbone.block1.9.running_mean : Float(128),
      %fc

## 验证模型转换精度1--用原始的Pytorch推理图片

In [22]:
# 读取图片
img_name = 'test_images/000004.jpg'
img = cv2.imread(img_name)

# 需要将模型转换到float32
img = cv2.resize(img, (320,320)).astype('float32')

# 将模型从[320,320,3]变成[1,3,320,320]的推理形状
img = np.transpose(img, (2,0,1))[None,...]
img = torch.tensor(img)

# 以deploy模式建立FCOS检测器
model = FCOSDetectorDeploy(mode="deploy", config=DefaultConfig)
model = torch.nn.DataParallel(model)

# 加载预训练模型
model.load_state_dict(torch.load("./training_dir_fp16_nogn/model_24.pth",map_location=torch.device('cpu')))
model = model.eval()

# 拿到分类，中心度和回归的输出，并打印回归的输出方便与onnx的结果做对比
with torch.no_grad():
    cls1, cnt1, reg1 = model(img)
    print(reg1)

INFO: using darnet19 backbone
tensor([[[3.5129e+02, 1.0395e+04, 8.7255e+03,  ..., 1.8458e+02,
          7.5282e+02, 9.8308e+02],
         [1.3531e+02, 3.2338e+02, 8.0767e+02,  ..., 6.3049e+05,
          2.0588e+05, 5.4898e+02],
         [1.4481e+03, 3.8759e+04, 7.5060e+05,  ..., 3.0634e+03,
          1.7337e+02, 4.6023e+00],
         [2.1421e+01, 3.4048e+00, 1.0136e+00,  ..., 1.4374e+02,
          1.3487e+01, 4.3831e+00]]])




## 验证模型转换精度2--用ONNX推理图片

In [10]:
# 导入所需的onnxruntime包
# 如果没有这个包，则安装: pip install onnxruntime
import onnxruntime

In [23]:
# 读取图片
img_name = 'test_images/000004.jpg'
img = cv2.imread(img_name)

# 需要将模型转换到float32，否则无法推理
img = cv2.resize(img, (320,320)).astype('float32')

# 将模型从[320,320,3]变成[1,3,320,320]的推理格式
img = np.transpose(img, (2,0,1))[None,...]

# 使用onnxruntime加载转换好的fcos检测器
sess = onnxruntime.InferenceSession(name_converted, None)

# 使用onnxruntime进行推理，并打印回归输出
(cls2, cnt2, reg2) = sess.run(None, {'img': img})
print(reg2)

[[[3.51289093e+02 1.03950938e+04 8.72563379e+03 ... 1.84585175e+02
   7.52820618e+02 9.83078430e+02]
  [1.35305420e+02 3.23380798e+02 8.07670959e+02 ... 6.30485938e+05
   2.05873422e+05 5.48978455e+02]
  [1.44809155e+03 3.87590586e+04 7.50608750e+05 ... 3.06335059e+03
   1.73369644e+02 4.60233307e+00]
  [2.14211693e+01 3.40483952e+00 1.01360846e+00 ... 1.43739243e+02
   1.34865055e+01 4.38306761e+00]]]


In [21]:
# 观察Pytorch到onnx在回归输出上的转换误差
reg1.numpy() - reg2

array([[[ 4.8828125e-04, -4.8828125e-02, -9.9609375e-02, ...,
         -2.5939941e-04,  3.2348633e-03, -2.3803711e-03],
        [ 7.7819824e-04, -6.4086914e-04, -4.6386719e-03, ...,
          3.0000000e+00,  2.1718750e+00, -4.8828125e-04],
        [ 5.4931641e-03, -1.8750000e-01, -7.8750000e+00, ...,
          8.7890625e-03,  5.7983398e-04,  2.3841858e-06],
        [ 9.1552734e-05,  9.5367432e-06,  4.0531158e-06, ...,
          6.2561035e-04, -3.8146973e-06, -2.3841858e-06]]], dtype=float32)