## 导入包

In [1]:
import torch
import torch.nn as nn
import torch.nn.functional as F
from torchvision import models
import torch.onnx

## model

In [2]:
# 加载 AlexNet
model = models.alexnet()
print(model)

AlexNet(
  (features): Sequential(
    (0): Conv2d(3, 64, kernel_size=(11, 11), stride=(4, 4), padding=(2, 2))
    (1): ReLU(inplace=True)
    (2): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (3): Conv2d(64, 192, kernel_size=(5, 5), stride=(1, 1), padding=(2, 2))
    (4): ReLU(inplace=True)
    (5): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
    (6): Conv2d(192, 384, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (7): ReLU(inplace=True)
    (8): Conv2d(384, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (9): ReLU(inplace=True)
    (10): Conv2d(256, 256, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1))
    (11): ReLU(inplace=True)
    (12): MaxPool2d(kernel_size=3, stride=2, padding=0, dilation=1, ceil_mode=False)
  )
  (avgpool): AdaptiveAvgPool2d(output_size=(6, 6))
  (classifier): Sequential(
    (0): Dropout(p=0.5, inplace=False)
    (1): Linear(in_features=9216, out_features=4096, bias=True)
 

## onnx

In [3]:
# 设置模型为评估模式
model.eval()

# 创建示例输入
dummy_input = torch.randn(1, 3, 224, 224)

# 导出模型为 ONNX 格式
output_file = "alexnet.onnx"
torch.onnx.export(model, dummy_input, output_file,
                  export_params=True,        # 存储训练过的参数
                  opset_version=10,         # ONNX 版本
                  do_constant_folding=True, # 是否执行常量折叠优化
                  input_names=['input'],    # 输入名称
                  output_names=['output'],  # 输出名称
                  dynamic_axes={'input': {0: 'batch_size'}, 'output': {0: 'batch_size'}} # 批次大小动态
                  )

print(f"ONNX model exported to {output_file}")

ONNX model exported to alexnet.onnx


## onnxruntime

In [4]:
import onnxruntime
import numpy as np

# 加载 ONNX 运行时
ort_session = onnxruntime.InferenceSession(output_file)

# 创建输入数据
input_name = ort_session.get_inputs()[0].name
input_data = np.random.randn(1, 3, 224, 224).astype(np.float32)

# 运行模型
ort_outputs = ort_session.run(None, {input_name: input_data})

print("ONNX Runtime output:")
print(ort_outputs)

ONNX Runtime output:
[array([[-1.06881629e-03, -7.13833096e-03, -1.46302106e-02,
        -1.48632266e-02, -2.81404471e-03,  3.25555890e-03,
         1.21596931e-02, -4.27395152e-03,  1.00991139e-02,
        -7.87718780e-03, -1.51479086e-02, -1.29680932e-02,
         7.32187321e-03, -1.83659866e-02, -1.00143384e-02,
         5.42495959e-03,  1.56334359e-02,  3.27218138e-03,
         1.94164459e-02, -5.10004023e-03,  6.08730176e-03,
         8.51447973e-03,  1.20760286e-02,  1.84825074e-03,
         1.18431775e-02, -1.31666902e-02,  1.20003102e-02,
         1.68855395e-02, -2.58168038e-02,  2.72929203e-03,
        -7.58403912e-03, -3.09484825e-03, -4.63490468e-03,
         7.04932446e-03, -1.33286892e-02, -1.33502856e-02,
        -1.04321260e-02, -9.03326645e-03, -7.35123223e-03,
         2.82454723e-03,  1.07265962e-02,  8.27054586e-03,
         5.58929937e-03, -1.26117337e-02, -1.56792793e-02,
         2.49235379e-03,  9.34864581e-03,  3.77762201e-03,
         7.70594086e-03,  4.797847

## netron可视化

In [5]:
import netron


# 使用 netron 查看 ONNX 模型
netron.start(output_file)

Serving 'alexnet.onnx' at http://localhost:8081


('localhost', 8081)