## 导入包

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.onnx

## 加载模型（GoogLeNet）

In [2]:
# Load GoogLeNet from torchvision (not AlexNet, as the old comment claimed).
# NOTE(review): no weights are requested here, so the parameters are presumably
# randomly initialised rather than pretrained — confirm against the installed torchvision.
model = models.googlenet()
print(model)



GoogLeNet(
  (conv1): BasicConv2d(
    (conv): Conv2d(3, 64, kernel_size=(7, 7), stride=(2, 2), padding=(3, 3), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool1): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (conv2): BasicConv2d(
    (conv): Conv2d(64, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
    (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (conv3): BasicConv2d(
    (conv): Conv2d(64, 192, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
    (bn): BatchNorm2d(192, eps=0.001, momentum=0.1, affine=True, track_running_stats=True)
  )
  (maxpool2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=True)
  (inception3a): Inception(
    (branch1): BasicConv2d(
      (conv): Conv2d(192, 64, kernel_size=(1, 1), stride=(1, 1), bias=False)
      (bn): BatchNorm2d(64, eps=0.001, momentum=0.1, affine=True, track

## 导出为 ONNX

In [3]:
# Put the network into evaluation mode before it is traced for export.
model.eval()

# Example input that drives the export: a batch of one 3x224x224 tensor.
dummy_input = torch.randn(1, 3, 224, 224)

# Destination path of the exported ONNX file.
output_file = "googlenet.onnx"

# Keyword options for the exporter, gathered in one place.
export_options = dict(
    export_params=True,        # store the model parameters inside the file
    opset_version=10,          # ONNX opset version to target
    do_constant_folding=True,  # apply the constant-folding optimisation
    input_names=['input'],     # name of the graph input
    output_names=['output'],   # name of the graph output
    # Axis 0 of both tensors is declared dynamic (variable batch size).
    dynamic_axes={
        'input': {0: 'batch_size'},
        'output': {0: 'batch_size'},
    },
)

# Export the model to ONNX format.
torch.onnx.export(model, dummy_input, output_file, **export_options)

print(f"ONNX model exported to {output_file}")

ONNX model exported to googlenet.onnx


## onnxruntime

In [4]:
import onnxruntime
import numpy as np

# Create the ONNX Runtime session for the exported file.
# The execution provider is named explicitly: since ONNX Runtime 1.9, builds that
# have providers other than CPU enabled (e.g. GPU builds) raise a ValueError when
# `providers` is omitted. CPU is available in every build.
ort_session = onnxruntime.InferenceSession(
    output_file,
    providers=["CPUExecutionProvider"],
)

# Build the input: look up the graph's input name and draw a float32 tensor with
# the same shape used for export. A seeded generator keeps the result
# reproducible across "Restart Kernel -> Run All".
input_name = ort_session.get_inputs()[0].name
rng = np.random.default_rng(0)
input_data = rng.standard_normal((1, 3, 224, 224)).astype(np.float32)

# Run the model; passing None as the first argument requests all graph outputs.
ort_outputs = ort_session.run(None, {input_name: input_data})

print("ONNX Runtime output:")
print(ort_outputs)

ONNX Runtime output:
[array([[ 7.33307377e-03, -2.02086456e-02, -5.65131009e-03,
         1.30031072e-02,  2.57856064e-02, -1.13751739e-02,
         1.02943406e-02, -1.22817382e-02, -1.51564702e-02,
         7.09834695e-03, -2.74044089e-02,  9.06753168e-03,
        -1.12390891e-02,  2.15312578e-02,  2.73788720e-03,
         1.92185342e-02,  1.41014084e-02,  1.36813745e-02,
         1.64844505e-02, -4.30859625e-03,  1.36780366e-03,
         2.23519094e-02,  1.53798610e-04,  2.28636749e-02,
        -2.27894709e-02,  3.07538956e-02, -2.75636688e-02,
        -2.19397023e-02, -1.21418759e-03,  2.85535417e-02,
        -2.40852349e-02, -2.13438720e-02, -2.79413462e-02,
        -6.14691898e-03,  9.13581625e-03,  2.38153189e-02,
         1.17010400e-02,  2.08196081e-02,  3.78814340e-03,
         2.69729570e-02, -1.87085643e-02, -1.79577991e-02,
        -7.63315707e-04, -1.66812912e-03, -2.25032531e-02,
        -3.54301557e-03, -2.42265165e-02, -1.17831305e-02,
        -1.40793249e-02,  2.686384

## Netron 可视化

In [5]:
import netron


# Open the exported ONNX model in the Netron viewer.
# The recorded output shows it being served from a local HTTP address.
netron.start(output_file)

Serving 'googlenet.onnx' at http://localhost:21633


('localhost', 21633)