# TensorRT 推理优化

**SOTA 教育标准** | 包含 TensorRT 基础、INT8 量化、性能优化

---

## 1. TensorRT 概述

**TensorRT**: NVIDIA 高性能深度学习推理优化器

| 优化 | 说明 | 加速 |
|:-----|:-----|:----:|
| **层融合** | Conv+BN+ReLU → 单 kernel | 1.5-2x |
| **精度校准** | FP32 → FP16/INT8 | 2-4x |

In [None]:
from __future__ import annotations
from dataclasses import dataclass
from typing import Optional, Tuple
import time
import numpy as np
import torch
import torch.nn as nn
import matplotlib.pyplot as plt

# Pick the compute device: use CUDA when PyTorch reports it, else the CPU.
device = "cpu"
if torch.cuda.is_available():
    device = "cuda"
print(f"Device: {device}")

# TensorRT is an optional dependency. TRT_AVAILABLE records whether the
# import succeeded so later code can branch on it.
TRT_AVAILABLE = False
try:
    import tensorrt as trt
except ImportError:
    print("TensorRT not installed")
else:
    print(f"TensorRT: {trt.__version__}")
    TRT_AVAILABLE = True

---

## 2. TensorRT 配置

In [None]:
@dataclass
class TensorRTConfig:
    """Options describing how a TensorRT engine should be built.

    Attributes:
        precision: Requested numeric precision. The values named by the
            author are "fp32", "fp16" and "int8"; the default is "fp16".
            No validation is performed here.
        max_batch_size: Largest batch size, 16 by default.
        max_workspace_size: Workspace limit, 2**30 by default (1 GB
            according to the original comment).
    """

    # One of "fp32", "fp16", "int8".
    precision: str = "fp16"
    max_batch_size: int = 16
    # 2**30 == 1 << 30; presumably a byte count (1 GiB) - confirm with the
    # code that consumes it.
    max_workspace_size: int = 2 ** 30


class TensorRTBuilder:
    """TensorRT engine builder.

    Core idea: convert an ONNX model into an optimized TensorRT engine.
    The conversion itself is not implemented yet; ``build_from_onnx``
    currently always returns ``None``.

    Attributes:
        config: The ``TensorRTConfig`` used by this builder.
        logger: ``trt.Logger`` at WARNING level, or ``None`` when TensorRT
            could not be imported.
        builder: ``trt.Builder`` bound to ``logger``, or ``None`` when
            TensorRT could not be imported.
    """

    def __init__(self, config: Optional[TensorRTConfig] = None):
        """Create a builder.

        Args:
            config: Build options. When omitted (or ``None``), a fresh
                ``TensorRTConfig()`` with default values is created.
        """
        # A ``TensorRTConfig()`` default in the signature would be evaluated
        # once and shared by every builder, so mutating one builder's config
        # would silently change all others. Create it per instance instead.
        self.config = TensorRTConfig() if config is None else config

        # Always define these attributes so code that inspects them does not
        # hit AttributeError on machines without TensorRT.
        self.logger = None
        self.builder = None
        if TRT_AVAILABLE:
            self.logger = trt.Logger(trt.Logger.WARNING)
            self.builder = trt.Builder(self.logger)

    def build_from_onnx(self, onnx_path: str) -> Optional[bytes]:
        """Build a TensorRT engine from an ONNX file.

        Args:
            onnx_path: Path of the ONNX model; only used in the log message
                for now.

        Returns:
            Currently always ``None``: either TensorRT is unavailable, or
            the real conversion (which needs the full TensorRT API) has not
            been implemented yet.
        """
        if not TRT_AVAILABLE:
            print("TensorRT not available")
            return None
        print(f"Building TensorRT engine from {onnx_path}")
        return None  # A real implementation needs the full TensorRT API.


# Smoke test: create a builder with the default configuration and print it.
builder = TensorRTBuilder()
print(f"配置: {builder.config}")

---

## 3. INT8 量化校准

In [None]:
class INT8Calibrator:
    """INT8 quantization calibrator.

    Core idea: use a calibration dataset to determine the best quantization
    parameters. This class only handles the data side: it hands out the
    calibration set in consecutive batches along the first axis.

    Attributes:
        data: The calibration samples.
        batch_size: Maximum number of samples per batch.
        current_idx: Index of the first sample of the next batch.
    """

    def __init__(self, data: np.ndarray, batch_size: int = 32):
        """Store the calibration data and start at the first sample.

        Args:
            data: Calibration samples, indexed along the first axis.
            batch_size: Number of samples per batch; must be positive.

        Raises:
            ValueError: If ``batch_size`` is not positive.
        """
        # With batch_size <= 0 the cursor never advances, so get_batch()
        # would never return None and a calibration loop would spin forever.
        if batch_size <= 0:
            raise ValueError(
                f"batch_size must be a positive integer, got {batch_size}"
            )
        self.data = data
        self.batch_size = batch_size
        self.current_idx = 0

    def get_batch(self) -> Optional[np.ndarray]:
        """Return the next calibration batch.

        Returns:
            A slice of ``data`` with at most ``batch_size`` samples (the
            last batch may be shorter), or ``None`` once every sample has
            been handed out.
        """
        total = len(self.data)
        if self.current_idx >= total:
            return None
        end = min(self.current_idx + self.batch_size, total)
        batch = self.data[self.current_idx:end]
        self.current_idx = end
        return batch

    def reset(self) -> None:
        """Rewind so the next ``get_batch`` call starts from the first sample."""
        self.current_idx = 0


# Smoke test: 100 random float32 samples of shape (3, 32, 32); fetch the
# first batch (default batch size 32) and print its shape.
cal_data = np.random.randn(100, 3, 32, 32).astype(np.float32)
calibrator = INT8Calibrator(cal_data)
batch = calibrator.get_batch()
print(f"校准批次形状: {batch.shape}")

---

## 4. INT8 量化模拟

In [None]:
def simulate_int8(tensor: np.ndarray, method: str = "minmax") -> Tuple[np.ndarray, float]:
    """Simulate INT8 quantization and measure the round-trip error.

    The tensor is divided by a scale, rounded, clipped to [-128, 127],
    stored as int8 and multiplied by the scale again.

    Args:
        tensor: Values to quantize. Array-likes are converted with
            ``np.asarray``.
        method: How the scale is chosen. ``"minmax"`` uses the largest
            absolute value divided by 127. Any other value selects the
            percentile variant, which uses the 99.9th percentile of the
            absolute values divided by 127, so larger outliers get clipped.

    Returns:
        A tuple ``(dequantized, error)``: the quantize/dequantize round trip
        of ``tensor`` and the mean absolute difference to the input as a
        Python float. An empty input yields an empty float32 array and
        ``0.0``.
    """
    tensor = np.asarray(tensor)
    if tensor.size == 0:
        # Reductions such as max() raise on empty arrays; there is nothing
        # to quantize, so the round trip is trivially exact.
        return tensor.astype(np.float32), 0.0

    magnitudes = np.abs(tensor)
    if method == "minmax":
        clip_value = magnitudes.max()
    else:  # percentile
        clip_value = np.percentile(magnitudes, 99.9)

    scale = clip_value / 127
    if scale == 0:
        # All (or nearly all) values are zero. Dividing by a zero scale
        # would produce NaN/inf and an undefined NaN -> int8 cast, so fall
        # back to a unit scale.
        scale = 1.0

    quantized = np.clip(np.round(tensor / scale), -128, 127).astype(np.int8)
    dequantized = quantized.astype(np.float32) * scale
    # Convert the NumPy scalar so the result matches the annotated type.
    error = float(np.abs(tensor - dequantized).mean())
    return dequantized, error


# Smoke test: quantize 1000 random float32 activations with both scale
# selection methods and print the mean absolute round-trip error of each.
activations = np.random.randn(1000).astype(np.float32)
for method in ["minmax", "percentile"]:
    _, error = simulate_int8(activations, method)
    print(f"{method}: Mean Error = {error:.6f}")

---

## 5. 性能可视化

In [None]:
def visualize_tensorrt_speedup() -> None:
    """Plot and print a latency comparison between PyTorch and TensorRT.

    Draws a grouped bar chart (PyTorch, TensorRT FP16, TensorRT INT8) for
    four models, then prints the PyTorch-to-INT8 speedup per model. The
    latency figures are constants hard-coded below, not measurements taken
    by this function.
    """
    model_names = ["ResNet-50", "VGG-16", "BERT", "YOLOv5"]
    latency_pytorch = [15.2, 22.5, 45.0, 12.0]
    latency_fp16 = [4.2, 6.5, 15.0, 4.0]
    latency_int8 = [2.8, 4.2, 10.0, 2.5]

    bar_width = 0.25
    positions = np.arange(len(model_names))
    # (legend label, latencies, colour, horizontal shift within the group)
    series = (
        ('PyTorch', latency_pytorch, 'blue', -bar_width),
        ('TRT FP16', latency_fp16, 'orange', 0),
        ('TRT INT8', latency_int8, 'red', bar_width),
    )

    fig, ax = plt.subplots(figsize=(10, 5))
    for label, latencies, colour, shift in series:
        ax.bar(positions + shift, latencies, bar_width, label=label, color=colour)

    ax.set_xlabel('Model')
    ax.set_ylabel('Latency (ms)')
    ax.set_title('TensorRT Speedup')
    ax.set_xticks(positions)
    ax.set_xticklabels(model_names)
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()

    # Text summary: baseline latency, INT8 latency and their ratio.
    for name, baseline, quantized in zip(model_names, latency_pytorch, latency_int8):
        print(f"{name}: PyTorch {baseline}ms → INT8 {quantized}ms ({baseline/quantized:.1f}x)")


# Render the latency comparison chart and print the per-model speedups.
visualize_tensorrt_speedup()

---

## 6. 总结

| 数值精度 | 加速 | 准确率损失 | 适用场景 |
|:-----|:----:|:--------:|:---------|
| **FP16** | 2-3x | <0.1% | 通用推理 |
| **INT8** | 3-5x | <1% | 高吞吐量 |