# TorchScript 部署

**SOTA 教育标准** | 内容：Tracing 与 Scripting 两种转换方式，以及 JIT 编译后模型与原生 PyTorch 的推理性能对比

---

## 1. TorchScript 概述

| 方式 | 原理 | 适用场景 |
|:-----|:-----|:---------|
| **Tracing** | 用示例输入运行一次模型，记录实际执行到的算子序列 | 不含依赖输入数据的控制流的模型 |
| **Scripting** | 直接解析 Python 源码并编译，保留分支与循环 | 含条件分支、循环等控制流的模型 |

In [None]:
from __future__ import annotations
import time
from typing import Dict, Optional
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch.nn as nn
from torch import Tensor

# Prefer the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}, PyTorch: {torch.__version__}")

---

## 2. 示例模型

In [None]:
class SimpleCNN(nn.Module):
    """Small convolutional classifier used for the TorchScript demos.

    Two ``Conv2d -> ReLU -> MaxPool2d(2)`` stages (3 -> 32 -> 64 channels)
    feed a head that adaptively average-pools to 1x1, flattens, and applies
    a single linear layer.

    Args:
        num_classes: Number of output features of the final linear layer.
    """

    def __init__(self, num_classes: int = 10):
        super().__init__()
        # Layers keep their original order and positions inside each
        # ``nn.Sequential`` so parameter names such as ``features.3.weight``
        # (and the order in which weights are initialised) do not change.
        conv_stack = [
            nn.Conv2d(3, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(kernel_size=2),
        ]
        head = [
            nn.AdaptiveAvgPool2d(output_size=1),
            nn.Flatten(),
            nn.Linear(64, num_classes),
        ]
        self.features = nn.Sequential(*conv_stack)
        self.classifier = nn.Sequential(*head)

    def forward(self, x: Tensor) -> Tensor:
        """Run the convolutional stages, then the classification head.

        Args:
            x: Input with 3 channels; the demos use ``(batch, 3, 32, 32)``.

        Returns:
            The output of the final linear layer (``num_classes`` features).
        """
        feature_maps = self.features(x)
        scores = self.classifier(feature_maps)
        return scores


class ConditionalModel(nn.Module):
    """Model whose forward pass branches on a flag (requires scripting).

    Holds two independent ``Linear(10, 20)`` layers; the ``use_branch_a``
    argument of ``forward`` selects which one is applied.
    """

    def __init__(self):
        super().__init__()
        self.linear1 = nn.Linear(in_features=10, out_features=20)
        self.linear2 = nn.Linear(in_features=10, out_features=20)

    def forward(self, x: Tensor, use_branch_a: bool = True) -> Tensor:
        """Apply ``linear1`` when ``use_branch_a`` is true, else ``linear2``.

        Args:
            x: Input whose last dimension has 10 features.
            use_branch_a: Selects ``linear1`` (True) or ``linear2`` (False).

        Returns:
            The output of the selected linear layer.
        """
        if use_branch_a:
            out = self.linear1(x)
        else:
            out = self.linear2(x)
        return out


# Build the demo CNN and put it in inference mode before any conversion.
model = SimpleCNN(num_classes=10)
model.eval()
print(f"模型参数量: {sum(map(torch.numel, model.parameters())):,}")

---

## 3. Tracing 方式

In [None]:
def trace_model(model: nn.Module, example_input: Tensor, path: Optional[str] = None):
    """Convert ``model`` to TorchScript by tracing it with ``example_input``.

    The model is switched to eval mode in place before tracing.

    Args:
        model: Module to trace.
        example_input: Input handed to ``torch.jit.trace`` to run the model.
        path: Optional destination file. When truthy, the traced module is
            saved there and the path is printed.

    Returns:
        The traced module produced by ``torch.jit.trace``.
    """
    model.eval()
    traced_module = torch.jit.trace(model, example_input)
    # Nothing to persist when no (or an empty) path was given.
    if not path:
        return traced_module
    traced_module.save(path)
    print(f"Saved to: {path}")
    return traced_module


# Trace the CNN with one random 32x32 RGB sample and save the result.
# NOTE(review): "/tmp" is a POSIX path and may not exist on other platforms.
example_input = torch.randn(1, 3, 32, 32)
traced_model = trace_model(model, example_input, path="/tmp/traced.pt")

# Check that the traced module matches the eager model on the same input.
with torch.no_grad():
    diff = torch.max(torch.abs(model(example_input) - traced_model(example_input)))
print(f"输出差异: {diff.item():.6f}")

---

## 4. Scripting 方式

In [None]:
def script_model(model: nn.Module, path: Optional[str] = None):
    """Convert ``model`` to TorchScript by compiling it with ``torch.jit.script``.

    The model is switched to eval mode in place before scripting.

    Args:
        model: Module to script.
        path: Optional destination file. When truthy, the scripted module is
            saved there and the path is printed.

    Returns:
        The scripted module produced by ``torch.jit.script``.
    """
    model.eval()
    scripted_module = torch.jit.script(model)
    # Nothing to persist when no (or an empty) path was given.
    if not path:
        return scripted_module
    scripted_module.save(path)
    print(f"Saved to: {path}")
    return scripted_module


# Script the branching model and save it.
# NOTE(review): "/tmp" is a POSIX path and may not exist on other platforms.
cond_model = ConditionalModel()
cond_model.eval()
scripted_cond = script_model(cond_model, path="/tmp/scripted.pt")

test_input = torch.randn(1, 10)
with torch.no_grad():
    # Run the same input through each branch by toggling the flag.
    out_a, out_b = (scripted_cond(test_input, flag) for flag in (True, False))
print(f"分支 A: {out_a.shape}, 分支 B: {out_b.shape}")

---

## 5. 性能对比

In [None]:
def benchmark(
    models: Dict[str, nn.Module],
    input_t: Tensor,
    runs: int = 100,
    warmup: int = 10,
) -> Dict[str, float]:
    """Measure the mean forward-pass latency of each model on ``input_t``.

    Each model is switched to eval mode in place, run ``warmup`` times
    untimed, then run ``runs`` times under ``torch.no_grad()`` while timed
    with ``time.perf_counter``.

    Args:
        models: Mapping from display name to the module to time.
        input_t: Input passed unchanged to every model.
        runs: Number of timed forward passes; must be positive.
        warmup: Number of untimed forward passes before timing starts.
            Values <= 0 skip the warmup.

    Returns:
        Mapping from model name to mean latency per forward pass in
        milliseconds, in the iteration order of ``models``.

    Raises:
        ValueError: If ``runs`` is not positive (the mean would be undefined).
    """
    if runs <= 0:
        raise ValueError(f"runs must be a positive integer, got {runs}")

    # CUDA kernels are launched asynchronously, so without an explicit
    # synchronize the wall clock would only capture launch overhead.
    on_cuda = input_t.is_cuda

    results: Dict[str, float] = {}
    for name, m in models.items():
        m.eval()
        with torch.no_grad():
            for _ in range(warmup):
                m(input_t)
            if on_cuda:
                # Make sure warmup work is finished before the clock starts.
                torch.cuda.synchronize(input_t.device)
            start = time.perf_counter()
            for _ in range(runs):
                m(input_t)
            if on_cuda:
                # Wait for all timed kernels to finish before stopping the clock.
                torch.cuda.synchronize(input_t.device)
            elapsed = time.perf_counter() - start
        results[name] = elapsed / runs * 1000
    return results


# Script the CNN too, so the eager, traced and scripted variants can be
# timed side by side.
scripted_model = script_model(model)
models_test = {
    "PyTorch": model,
    "Traced": traced_model,
    "Scripted": scripted_model,
}

for bs in (1, 16):
    results = benchmark(models_test, torch.randn(bs, 3, 32, 32))
    print(
        f"Batch {bs}: "
        + ", ".join(f"{k}={v:.2f}ms" for k, v in results.items())
    )

---

## 6. 可视化

In [None]:
def visualize_performance():
    """Plot grouped latency bars for each model across several batch sizes.

    Times the module-level ``models_test`` with ``benchmark`` on random
    ``(bs, 3, 32, 32)`` inputs for batch sizes 1, 4, 16 and 32, then draws
    one group of three bars per batch size and shows the figure.
    """
    batch_sizes = [1, 4, 16, 32]
    all_results = {}
    for bs in batch_sizes:
        all_results[bs] = benchmark(models_test, torch.randn(bs, 3, 32, 32))

    bar_width = 0.25
    # (results key, bar colour) for each series, in left-to-right order.
    series = (("PyTorch", "blue"), ("Traced", "green"), ("Scripted", "orange"))
    group_pos = np.arange(len(batch_sizes))

    fig, ax = plt.subplots(figsize=(10, 5))
    for idx, (name, color) in enumerate(series):
        latencies = [all_results[bs][name] for bs in batch_sizes]
        ax.bar(
            group_pos + idx * bar_width,
            latencies,
            bar_width,
            label=name,
            color=color,
            alpha=0.7,
        )

    ax.set_xlabel('Batch Size')
    ax.set_ylabel('Latency (ms)')
    ax.set_title('TorchScript Performance')
    # Offset by one bar width so the tick sits under the middle of three bars.
    ax.set_xticks(group_pos + bar_width)
    ax.set_xticklabels(batch_sizes)
    ax.legend()
    ax.grid(True, alpha=0.3)
    plt.tight_layout()
    plt.show()


# Run the batch-size sweep and display the latency chart.
visualize_performance()

---

## 7. 总结

| 方式 | 函数 | 适用场景 |
|:-----|:-----|:---------|
| **Tracing** | `torch.jit.trace()` | 无依赖输入数据的控制流（需提供示例输入） |
| **Scripting** | `torch.jit.script()` | 含条件分支、循环等控制流 |