In [3]:
import argparse
import copy
import os

import onnx
import onnxsim
import torch

from nanodet.model.arch import build_model
from nanodet.util import Logger, cfg, load_config, load_model_weight

In [4]:
def generate_ouput_names(head_cfg):
    """Return per-stride output tensor names for a detection head.

    All classification names ("cls_pred_stride_<s>") come first, followed by
    all distribution names ("dis_pred_stride_<s>"), each group in the order of
    ``head_cfg.strides``.
    """
    # Materialize once so a one-shot iterable still yields both groups.
    stride_list = list(head_cfg.strides)
    cls_part = ["cls_pred_stride_{}".format(s) for s in stride_list]
    dis_part = ["dis_pred_stride_{}".format(s) for s in stride_list]
    return cls_part + dis_part

In [5]:
def main(config, model_path, output_path, input_shape=(320, 320), verbose=True):
    """Export a NanoDet checkpoint to ONNX and simplify the result with onnxsim.

    Args:
        config: NanoDet config node (e.g. the global ``cfg`` after ``load_config``).
            The RepVGG branch builds its deploy model from a deep copy, so that
            branch does not modify the caller's ``config.model``.
        model_path: Checkpoint file passed to ``torch.load``; tensors are
            loaded onto the CPU.
        output_path: Destination ``.onnx`` path. It is rewritten with the
            simplified graph when ``onnxsim.simplify`` reports success.
        input_shape: Two ints used as the trailing dims of the dummy input,
            whose shape is ``(1, 3, input_shape[0], input_shape[1])``.
        verbose: Forwarded to ``torch.onnx.export``; ``True`` prints the traced
            graph, which is very large in notebook output.

    Returns:
        None. Progress is reported through the NanoDet ``Logger``.
    """
    # NOTE(review): the duplicated log lines in this notebook's output suggest
    # each Logger(...) call attaches another handler across re-runs — confirm.
    logger = Logger(-1, config.save_dir, False)
    model = build_model(config.model)
    # Same effect as the identity map_location lambda: keep tensors on CPU.
    checkpoint = torch.load(model_path, map_location="cpu")
    load_model_weight(model, checkpoint, logger)
    if config.model.arch.backbone.name == "RepVGG":
        # Deep-copy so setting deploy=True does not leak back into the
        # caller's config object (a plain alias would mutate it in place).
        deploy_config = copy.deepcopy(config.model)
        deploy_config.arch.backbone.update({"deploy": True})
        deploy_model = build_model(deploy_config)
        from nanodet.model.backbone.repvgg import repvgg_det_model_convert

        model = repvgg_det_model_convert(model, deploy_model)
    # Trace in inference mode so BatchNorm/Dropout export their eval-time
    # behavior regardless of the mode the module was left in.
    model.eval()
    # torch.autograd.Variable is deprecated; export accepts a plain tensor.
    # NOTE(review): if cfg.data.train.input_size is ordered [w, h], a
    # non-square shape would be traced as (w, h) here — confirm the ordering.
    dummy_input = torch.randn(1, 3, input_shape[0], input_shape[1])

    torch.onnx.export(
        model,
        dummy_input,
        output_path,
        verbose=verbose,
        keep_initializers_as_inputs=True,
        opset_version=11,
        input_names=["data"],
        output_names=["output"],
    )
    logger.log("finished exporting onnx ")

    logger.log("start simplifying onnx ")
    # Keyed by the exported input name "data"; presumably used by onnxsim to
    # compare the simplified graph against the original — confirm.
    input_data = {"data": dummy_input.detach().cpu().numpy()}
    model_sim, flag = onnxsim.simplify(output_path, input_data=input_data)
    if flag:
        onnx.save(model_sim, output_path)
        logger.log("simplify onnx successfully")
    else:
        # The unsimplified export is left at output_path.
        logger.log("simplify onnx failed")

In [9]:

# Export settings: edit these before running the cell.
cfg_path = "nanodet-plus-m_320.yml"
# Set to None to fall back to <cfg.save_dir>/model_best/model_best.ckpt.
model_path = "model_best.ckpt"
out_path = 'out.onnx'
# None -> use cfg.data.train.input_size; otherwise either an "a,b" string
# or a 2-item sequence of ints such as (320, 320).
input_shape = None

load_config(cfg, cfg_path)
if input_shape is None:
    input_shape = cfg.data.train.input_size
else:
    # Accept the CLI-style "a,b" string as well as a tuple/list of numbers.
    if isinstance(input_shape, str):
        input_shape = input_shape.split(",")
    input_shape = tuple(int(dim) for dim in input_shape)
    assert len(input_shape) == 2, "input_shape needs exactly two values"
if model_path is None:
    model_path = os.path.join(cfg.save_dir, "model_best/model_best.ckpt")
main(cfg, model_path, out_path, input_shape)
print("Model saved to:", out_path)

model size is  1.0x
init weights...
=> loading pretrained model https://download.pytorch.org/models/shufflenetv2_x1-5666bf0f80.pth
Finish initialize NanoDet-Plus Head.


  channels_per_group = num_channels // groups
[1m[35m[root][0m[34m[04-25 08:31:03][0m[32mINFO:[0m[37mfinished exporting onnx [0m
[1m[35m[root][0m[34m[04-25 08:31:03][0m[32mINFO:[0m[37mfinished exporting onnx [0m
[1m[35m[root][0m[34m[04-25 08:31:03][0m[32mINFO:[0m[37mstart simplifying onnx [0m
[1m[35m[root][0m[34m[04-25 08:31:03][0m[32mINFO:[0m[37mstart simplifying onnx [0m


graph(%data : Float(1, 3, 320, 320, strides=[307200, 102400, 320, 1], requires_grad=0, device=cpu),
      %head.gfl_cls.0.weight : Float(34, 96, 1, 1, strides=[96, 1, 1, 1], requires_grad=1, device=cpu),
      %head.gfl_cls.0.bias : Float(34, strides=[1], requires_grad=1, device=cpu),
      %head.gfl_cls.1.weight : Float(34, 96, 1, 1, strides=[96, 1, 1, 1], requires_grad=1, device=cpu),
      %head.gfl_cls.1.bias : Float(34, strides=[1], requires_grad=1, device=cpu),
      %head.gfl_cls.2.weight : Float(34, 96, 1, 1, strides=[96, 1, 1, 1], requires_grad=1, device=cpu),
      %head.gfl_cls.2.bias : Float(34, strides=[1], requires_grad=1, device=cpu),
      %head.gfl_cls.3.weight : Float(34, 96, 1, 1, strides=[96, 1, 1, 1], requires_grad=1, device=cpu),
      %head.gfl_cls.3.bias : Float(34, strides=[1], requires_grad=1, device=cpu),
      %1532 : Float(24, 3, 3, 3, strides=[27, 9, 3, 1], requires_grad=0, device=cpu),
      %1533 : Float(24, strides=[1], requires_grad=0, device=cpu),
   

[1m[35m[root][0m[34m[04-25 08:31:04][0m[32mINFO:[0m[37msimplify onnx successfully[0m
[1m[35m[root][0m[34m[04-25 08:31:04][0m[32mINFO:[0m[37msimplify onnx successfully[0m


Model saved to: out.onnx
