# ONNX 模型导出与优化

**SOTA 教育标准** | 包含 ONNX 导出（含常量折叠优化）、验证、ONNX Runtime 推理与性能对比

---

## 1. ONNX 概述

**ONNX**（Open Neural Network Exchange，开放神经网络交换格式）：一种开放的模型表示格式，支持在不同深度学习框架与推理引擎之间转换和部署模型。

| 优势 | 说明 |
|:-----|:-----|
| **跨框架** | PyTorch/TF → ONNX |
| **优化推理** | 可用 ONNX Runtime 等推理引擎加速（实际收益需实测，见第 6 节） |

In [None]:
from __future__ import annotations
from dataclasses import dataclass
from typing import Dict, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

# Pick the accelerator once so later cells can refer to it.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Device: {device}")

# ONNX support is optional: every later ONNX cell is gated on this flag.
ONNX_AVAILABLE = False
try:
    import onnx
    import onnxruntime as ort
except ImportError:
    print("ONNX not installed. Run: pip install onnx onnxruntime")
else:
    print(f"ONNX: {onnx.__version__}, Runtime: {ort.__version__}")
    ONNX_AVAILABLE = True

---

## 2. 模型定义与导出配置

In [None]:
@dataclass
class ONNXExportConfig:
    """Settings consumed by ``export_to_onnx`` when calling ``torch.onnx.export``."""

    # ONNX operator-set version passed as ``opset_version``.
    opset_version: int = 14
    # When True, dimension 0 of "input" and "output" is exported as the symbolic axis "batch".
    dynamic_axes: bool = True
    # Passed straight through as ``do_constant_folding``.
    do_constant_folding: bool = True


class SimpleCNN(nn.Module):
    """Small two-stage CNN used as the export demo.

    Each stage is Conv3x3(padding=1) -> BatchNorm -> ReLU -> MaxPool(2), going
    from 3 to 32 and then 32 to 64 channels. Global average pooling, flatten
    and a linear layer then map the 64 channels to ``num_classes`` raw scores.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        stages = []
        for in_ch, out_ch in ((3, 32), (32, 64)):
            stages.extend((
                nn.Conv2d(in_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(),
                nn.MaxPool2d(2),
            ))
        self.features = nn.Sequential(*stages)
        head = [nn.AdaptiveAvgPool2d(1), nn.Flatten(), nn.Linear(64, num_classes)]
        self.classifier = nn.Sequential(*head)

    def forward(self, x: Tensor) -> Tensor:
        feats = self.features(x)
        return self.classifier(feats)


model = SimpleCNN(num_classes=10)
# Sum of element counts over every tensor yielded by model.parameters().
print(f"模型参数量: {sum(p.numel() for p in model.parameters()):,}")

---

## 3. ONNX 导出

In [None]:
def export_to_onnx(model: nn.Module, input_shape: Tuple[int, ...], output_path: str,
                   config: Optional[ONNXExportConfig] = None) -> str:
    """Export a PyTorch model to an ONNX file.

    The model is switched to eval mode (and left in it). It is then traced with
    a random dummy input of ``input_shape``, created on the same device as the
    model's parameters.

    Args:
        model: Module to export.
        input_shape: Shape of the dummy tracing input, e.g. ``(1, 3, 32, 32)``.
        output_path: Destination path of the ``.onnx`` file.
        config: Export options. A fresh default ``ONNXExportConfig()`` is used
            when omitted.

    Returns:
        ``output_path`` unchanged.
    """
    # Build the default per call. A dataclass instance used as a default
    # argument is created once at definition time and shared by every call.
    if config is None:
        config = ONNXExportConfig()

    model.eval()

    # Put the dummy input where the weights live: a CPU tensor fed to a CUDA
    # model fails during tracing. Parameter-free models fall back to CPU.
    first_param = next(model.parameters(), None)
    input_device = first_param.device if first_param is not None else torch.device("cpu")
    dummy_input = torch.randn(*input_shape, device=input_device)

    # Mark dim 0 of input and output as the symbolic axis "batch" when requested.
    dynamic_axes = {"input": {0: "batch"}, "output": {0: "batch"}} if config.dynamic_axes else None

    torch.onnx.export(
        model, dummy_input, output_path,
        export_params=True, opset_version=config.opset_version,
        do_constant_folding=config.do_constant_folding,
        input_names=["input"], output_names=["output"],
        dynamic_axes=dynamic_axes)

    print(f"Exported to: {output_path}")
    return output_path


if not ONNX_AVAILABLE:
    print("跳过导出")
else:
    # Export the demo model, tracing with a (1, 3, 32, 32) dummy input.
    onnx_path = export_to_onnx(model, (1, 3, 32, 32), "/tmp/model.onnx")

---

## 4. ONNX 验证

In [None]:
def verify_onnx_model(onnx_path: str) -> bool:
    """Load an ONNX file and run ``onnx.checker.check_model`` on it.

    Args:
        onnx_path: Path to the ``.onnx`` file.

    Returns:
        True if the model loads and passes the checker. False if ONNX is not
        installed, or if loading or validation fails; the failure reason is
        printed in that case.
    """
    if not ONNX_AVAILABLE:
        return False
    try:
        onnx_model = onnx.load(onnx_path)
        onnx.checker.check_model(onnx_model)
    except Exception as e:  # best-effort check: any load/parse/validation error -> False
        print(f"Validation failed: {e}")
        return False

    # opset_import can hold one entry per operator domain. Report the default
    # ONNX domain ("" or "ai.onnx") instead of whichever entry happens to be first.
    opset = next(
        (entry.version for entry in onnx_model.opset_import if entry.domain in ("", "ai.onnx")),
        None,
    )
    print(f"Model valid! IR: {onnx_model.ir_version}, Opset: {opset}")
    return True


if ONNX_AVAILABLE:
    # Check the file produced by the export cell; the bool result is only printed, not kept.
    verify_onnx_model(onnx_path="/tmp/model.onnx")

---

## 5. ONNX Runtime 推理

In [None]:
class ONNXEngine:
    """Minimal ONNX Runtime wrapper for a single-input, single-output graph.

    Args:
        onnx_path: Path to the ``.onnx`` file to load.
        providers: Execution providers in priority order, e.g.
            ``["CUDAExecutionProvider", "CPUExecutionProvider"]``. Defaults to
            ``["CPUExecutionProvider"]``.

    Raises:
        RuntimeError: If onnx/onnxruntime could not be imported.
    """

    def __init__(self, onnx_path: str, providers: Optional[list[str]] = None):
        if not ONNX_AVAILABLE:
            raise RuntimeError("ONNX Runtime not available")
        chosen = list(providers) if providers is not None else ["CPUExecutionProvider"]
        self.session = ort.InferenceSession(onnx_path, providers=chosen)
        # Only the first declared input and output are wired up.
        self.input_name = self.session.get_inputs()[0].name
        self.output_name = self.session.get_outputs()[0].name

    def __call__(self, inputs: np.ndarray) -> np.ndarray:
        """Run the session on ``inputs`` and return the first output array.

        No dtype conversion is done here. The array must already match what
        the graph expects (presumably float32 for the models in this notebook).
        """
        return self.session.run([self.output_name], {self.input_name: inputs})[0]


if ONNX_AVAILABLE:
    engine = ONNXEngine(onnx_path="/tmp/model.onnx")
    # np.random.randn yields float64; convert to float32 to match the float32 tracing input.
    test_input = np.asarray(np.random.randn(1, 3, 32, 32), dtype=np.float32)
    output = engine(test_input)
    print(f"ONNX 推理输出: {output.shape}")

---

## 6. 性能对比

In [None]:
import time

def _mean_latency_ms(fn, runs: int, warmup: int, synchronize=None) -> float:
    """Call ``fn()`` ``warmup`` times untimed, then return the mean latency of ``runs`` calls in ms.

    ``synchronize`` is an optional zero-argument callable. It is invoked before
    starting and after stopping the clock, so queued asynchronous device work
    is included in the timed window.
    """
    for _ in range(warmup):
        fn()
    if synchronize is not None:
        synchronize()
    start = time.perf_counter()
    for _ in range(runs):
        fn()
    if synchronize is not None:
        synchronize()
    return (time.perf_counter() - start) / runs * 1000


def benchmark(model: nn.Module, onnx_path: str, input_shape: Tuple, runs: int = 100,
              warmup: int = 10) -> Dict:
    """Compare mean inference latency of PyTorch and ONNX Runtime.

    Args:
        model: PyTorch module to time. It is switched to eval mode.
        onnx_path: ONNX file to load into an ``ONNXEngine``.
        input_shape: Shape of the random input used for both backends.
        runs: Number of timed calls per backend. Must be >= 1.
        warmup: Number of untimed calls before each timed loop.

    Returns:
        Dict with keys ``pytorch_ms``, ``onnx_ms`` and ``speedup``
        (``pytorch_ms / onnx_ms``). When ONNX is unavailable, ``onnx_ms`` is
        ``inf`` and ``speedup`` is ``0.0``.

    Raises:
        ValueError: If ``runs`` < 1 (the mean would divide by zero).
    """
    if runs < 1:
        raise ValueError(f"runs must be >= 1, got {runs}")

    # Create the input on the model's device so GPU-resident models also work.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")
    dummy = torch.randn(*input_shape, device=model_device)
    model.eval()

    # CUDA kernels run asynchronously; sync so the clock covers the real work.
    sync = (lambda: torch.cuda.synchronize(model_device)) if model_device.type == "cuda" else None

    with torch.no_grad():
        pt_time = _mean_latency_ms(lambda: model(dummy), runs, warmup, sync)

    if ONNX_AVAILABLE:
        engine = ONNXEngine(onnx_path)
        # ONNX Runtime consumes host NumPy arrays.
        np_input = dummy.detach().cpu().numpy()
        ort_time = _mean_latency_ms(lambda: engine(np_input), runs, warmup)
    else:
        ort_time = float('inf')

    return {"pytorch_ms": pt_time, "onnx_ms": ort_time, "speedup": pt_time / ort_time}


if ONNX_AVAILABLE:
    # Time both backends on the file written by the export cell (100 runs by default).
    results = benchmark(model, onnx_path="/tmp/model.onnx", input_shape=(1, 3, 32, 32))
    print(f"PyTorch: {results['pytorch_ms']:.2f}ms, ONNX: {results['onnx_ms']:.2f}ms, Speedup: {results['speedup']:.2f}x")

---

## 7. 总结

| 步骤 | 函数 | 说明 |
|:-----|:-----|:-----|
| **导出** | `torch.onnx.export()` | PyTorch → ONNX |
| **验证** | `onnx.checker.check_model()` | 检查有效性 |
| **推理** | `ort.InferenceSession()` | ONNX Runtime |