## 导入包

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.onnx

## model

In [2]:
# Load ResNet-18 from torchvision. (The earlier note here said "AlexNet", but
# the constructor actually called is resnet18.)
# NOTE(review): no weights argument is passed, so the parameters are presumably
# randomly initialised rather than pretrained -- confirm against the installed
# torchvision version.
model = models.resnet18()
# Print the module tree so the layer structure can be inspected.
print(model)

ResNet(
  (conv1): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
  (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
  (relu): ReLU(inplace=True)
  (maxpool): MaxPool2d(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False)
  (layer1): Sequential(
    (0): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
      (conv2): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn2): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
    )
    (1): BasicBlock(
      (conv1): Conv2d(64, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
      (bn1): BatchNorm2d(64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True)
      (relu): ReLU(inplace=True)
  

## onnx

In [3]:
# Put the model in evaluation mode before exporting.
model.eval()

# Example input handed to the exporter: one random tensor of shape (1, 3, 224, 224).
dummy_input = torch.randn(1, 3, 224, 224)

# Destination file for the exported graph; later cells reuse this name.
output_file = "resnet18.onnx"

# Keyword options for the exporter, collected in one place.
export_options = dict(
    export_params=True,        # store the model parameters inside the file
    opset_version=10,          # ONNX opset version to target
    do_constant_folding=True,  # apply the constant-folding optimisation
    input_names=['input'],     # name given to the graph input
    output_names=['output'],   # name given to the graph output
    # Mark axis 0 of both tensors as a dynamic batch dimension.
    dynamic_axes={
        tensor_name: {0: 'batch_size'} for tensor_name in ('input', 'output')
    },
)
torch.onnx.export(model, dummy_input, output_file, **export_options)

print(f"ONNX model exported to {output_file}")

ONNX model exported to resnet18.onnx


## onnxruntime

In [4]:
import numpy as np
import onnxruntime

# Open the exported file with ONNX Runtime.
ort_session = onnxruntime.InferenceSession(output_file)

# Look up the name of the graph's first input and build a random float32
# array with the same (1, 3, 224, 224) shape that was used at export time.
first_input = ort_session.get_inputs()[0]
input_name = first_input.name
input_shape = (1, 3, 224, 224)
input_data = np.random.randn(*input_shape).astype(np.float32)

# Run inference; passing None as the output-name list asks for every output.
feed = {input_name: input_data}
ort_outputs = ort_session.run(None, feed)

# Header line followed by the raw list of output arrays.
print("ONNX Runtime output:", ort_outputs, sep="\n")

ONNX Runtime output:
[array([[-1.66793180e+00, -3.92353594e-01,  1.11649334e+00,
         1.66349399e+00, -9.18997824e-01,  1.78590298e-01,
        -3.07256508e+00,  6.82803094e-01, -2.50629354e+00,
        -2.68054819e+00,  3.26375782e-01,  1.73960972e+00,
         1.95561171e+00,  2.89385021e-01, -6.92295969e-01,
        -1.38209492e-01, -1.47281170e-01, -1.68793511e+00,
        -2.57358122e+00,  2.80871201e+00, -1.89154863e-01,
         1.78896725e+00,  4.19847310e-01,  3.33107162e+00,
         4.97488976e-01, -3.41463089e+00,  1.56996310e-01,
        -1.06604755e-01, -2.86579132e-04, -1.03322411e+00,
         2.38066626e+00, -3.47520638e+00,  1.31227756e+00,
        -1.29002976e+00,  1.41867256e+00,  3.82627249e-01,
         1.74236298e-02, -7.89529502e-01,  8.61190677e-01,
        -1.04914808e+00, -3.46375036e+00, -1.28393745e+00,
        -1.92072320e+00,  2.05425501e+00, -1.95799232e+00,
         4.14558351e-01,  2.29270411e+00,  4.48581791e+00,
         1.25039089e+00, -1.740506

## netron可视化

In [5]:
import netron


# Open the exported ONNX file in Netron for visual inspection. In the recorded
# run this served the model at a localhost URL and returned a (host, port) tuple.
netron.start(output_file)

Serving 'resnet18.onnx' at http://localhost:23569


('localhost', 23569)