## 导入包

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.onnx

## model

In [2]:
# Build a VGG16 network (not AlexNet) from torchvision and print its layer structure.
# NOTE(review): no weights argument is passed, so the parameters are presumably
# randomly initialized rather than pretrained -- confirm this is intended.
model = models.vgg16()
print(model)

VGG(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (1): ReLU(inplace=True)
    (2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (3): ReLU(inplace=True)
    (4): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (5): Conv2d(64, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (6): ReLU(inplace=True)
    (7): Conv2d(128, 128, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (8): ReLU(inplace=True)
    (9): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1, ceil_mode=False)
    (10): Conv2d(128, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (13): ReLU(inplace=True)
    (14): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (15): ReLU(inplace=True)
    (16): MaxPool2d(kernel_size=2, stride=2, padding=0, dilation=1

## onnx

In [3]:
# Put the model in evaluation mode before exporting.
model.eval()

# Destination path for the exported graph, plus an example input tensor
# (batch of 1, 3 channels, 224x224) that is handed to the exporter.
output_file = "vgg16.onnx"
dummy_input = torch.randn(1, 3, 224, 224)

# Export the model to ONNX format.
torch.onnx.export(
    model,
    dummy_input,
    output_file,
    # Store the model parameters inside the exported file.
    export_params=True,
    # ONNX operator-set (opset) version to target.
    opset_version=10,
    # Apply the constant-folding optimization during export.
    do_constant_folding=True,
    # Names assigned to the graph's input and output.
    input_names=['input'],
    output_names=['output'],
    # Mark axis 0 of both tensors as a dynamic batch dimension.
    dynamic_axes={
        tensor_name: {0: 'batch_size'} for tensor_name in ('input', 'output')
    },
)

print(f"ONNX model exported to {output_file}")

ONNX model exported to vgg16.onnx


## onnxruntime

In [4]:
import onnxruntime
import numpy as np

# Create an ONNX Runtime inference session for the exported model.
# The execution provider is named explicitly: ONNX Runtime builds that bundle
# additional providers (e.g. GPU builds) have required an explicit `providers`
# list since version 1.9 and otherwise raise a ValueError. The CPU provider
# is available in every build, so this keeps the cell portable.
ort_session = onnxruntime.InferenceSession(
    output_file,
    providers=["CPUExecutionProvider"],
)

# Build the input feed: look up the graph's first input name and generate a
# random float32 array with the same (1, 3, 224, 224) shape used at export time.
input_name = ort_session.get_inputs()[0].name
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

# Run the model; passing None as the output-name list requests all outputs.
ort_outputs = ort_session.run(None, {input_name: input_data})

print("ONNX Runtime output:")
print(ort_outputs)

ONNX Runtime output:
[array([[ 3.55565995e-02,  2.57675238e-02, -1.01667389e-01,
        -1.74680233e-01,  1.41171873e-01, -5.32209650e-02,
         1.45709151e-02,  4.45636548e-03, -4.29428518e-02,
        -2.99522318e-02,  7.91202039e-02, -6.23151213e-02,
        -6.25880882e-02, -8.84922147e-02,  1.22570340e-02,
         1.18463952e-03, -9.60388780e-03, -8.03163797e-02,
         1.34690881e-01,  6.58621937e-02, -2.35044956e-02,
        -1.97143406e-02,  7.47760385e-02,  9.18803141e-02,
        -8.66966322e-03,  3.21950689e-02,  2.64817663e-02,
        -2.73056161e-02, -4.65572923e-02, -9.43469331e-02,
        -9.48418081e-02, -7.15439245e-02,  6.22797757e-04,
         2.49201655e-02, -7.23188892e-02, -1.25283841e-02,
         1.28529578e-01, -1.68312714e-02, -3.83148566e-02,
         1.99033711e-02, -5.45041412e-02, -6.44017532e-02,
         5.93201555e-02, -8.12561363e-02, -4.33836207e-02,
         1.52437966e-02, -3.09907291e-02, -5.04755694e-03,
         1.31976336e-01,  1.441867

## netron可视化

In [5]:
import netron


# Serve the exported ONNX model with netron for visual inspection.
# The call is the cell's last expression, so its return value is displayed as the cell output.
netron.start(output_file)

Serving 'vgg16.onnx' at http://localhost:23791


('localhost', 23791)