# 结构化剪枝 (Structured Pruning)

**SOTA 教育标准** | 包含通道剪枝、BN Scale 剪枝、硬件友好稀疏

---

## 1. 结构化剪枝原理

**非结构化剪枝的问题**: 稀疏矩阵存储开销、硬件加速困难。

**结构化剪枝优势**: 直接减少张量维度，无需特殊硬件支持。

**BN Scale 剪枝**: $y = \gamma \cdot \frac{x - \mu}{\sqrt{\sigma^2 + \epsilon}} + \beta$。当 $|\gamma| \approx 0$ 时，该通道的输出近似为常数 $\beta$，几乎不携带输入信息，因此可以移除该通道（其常数贡献可并入下一层的偏置）。

In [None]:
from __future__ import annotations
from typing import Tuple
import matplotlib.pyplot as plt
import torch
import torch.nn as nn
from torch import Tensor

if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(f"Device: {device}")

---

## 2. 通道重要性计算

In [None]:
class ChannelImportance:
    """Channel-importance utilities for structured pruning.

    Core idea: score each output channel by the L1 norm of its filter weights;
    the lowest-scoring channels are the pruning candidates.
    """

    @staticmethod
    def compute(conv: nn.Conv2d) -> Tensor:
        """Return the L1 norm of each output channel's filter.

        Args:
            conv: Layer whose ``weight`` has the output-channel axis first
                (true for ``nn.Conv2d``).

        Returns:
            1-D tensor with one score per output channel.
        """
        weight = conv.weight.detach()
        # Reduce over every axis except the output-channel axis (dim 0); for a
        # Conv2d weight this is dims (1, 2, 3), and the same code also handles
        # other weight ranks (e.g. Conv1d/Conv3d/Linear).
        return weight.abs().sum(dim=tuple(range(1, weight.dim())))

    @staticmethod
    def get_mask(importance: Tensor, sparsity: float) -> Tensor:
        """Return a boolean keep-mask that prunes the least important channels.

        Exactly ``int(num_channels * sparsity)`` channels are pruned (``False``)
        and the remaining highest-scoring channels are kept (``True``). Channels
        are selected with ``torch.topk``, so the kept count is exact even when
        several channels share the same score (ties are broken arbitrarily).

        Args:
            importance: Per-channel scores, e.g. from :meth:`compute`.
            sparsity: Fraction of channels to prune, in ``[0, 1]``.

        Returns:
            Boolean tensor with the same shape as ``importance``.

        Raises:
            ValueError: If ``sparsity`` is outside ``[0, 1]``.
        """
        if not 0.0 <= sparsity <= 1.0:
            raise ValueError(f"sparsity must be in [0, 1], got {sparsity}")

        scores = importance.reshape(-1)
        num_channels = scores.numel()
        num_keep = num_channels - int(num_channels * sparsity)

        mask = torch.zeros(num_channels, dtype=torch.bool, device=scores.device)
        if num_keep > 0:
            # Keep the num_keep largest scores; a threshold comparison would
            # drop every channel tied with the threshold value instead.
            mask[torch.topk(scores, num_keep).indices] = True
        return mask.reshape(importance.shape)


# Demo: score a randomly initialised 64->128 conv and keep half its channels.
conv = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
importance = ChannelImportance.compute(conv)
mask = ChannelImportance.get_mask(importance, sparsity=0.5)
print(f"通道数: {len(importance)}, 保留: {mask.sum().item()}")

---

## 3. 通道剪枝器

In [None]:
class ChannelPruner:
    """Output-channel pruner for convolution layers.

    Summary: remove the least important output channels of a conv layer and
    return a physically smaller replacement layer.
    """

    def __init__(self, sparsity: float = 0.5):
        # Fraction of output channels to remove from each pruned layer.
        self.sparsity = sparsity

    def prune_conv(self, conv: nn.Conv2d) -> Tuple[nn.Conv2d, Tensor]:
        """Prune the output channels of a single convolution layer.

        Args:
            conv: Layer to prune. Only ``groups == 1`` is supported: removing
                output channels of a grouped conv would change which group each
                remaining filter belongs to.

        Returns:
            ``(new_conv, keep_indices)``. ``new_conv`` copies every
            hyper-parameter of ``conv`` (stride, padding, dilation,
            padding_mode, device, dtype) except ``out_channels``.
            ``keep_indices`` holds the retained output-channel indices. To keep
            a network consistent, the following layer's input channels need
            to be sliced with the same indices (not done here).

        Raises:
            NotImplementedError: If ``conv.groups != 1``.
            ValueError: If the mask would remove every output channel.
        """
        if conv.groups != 1:
            raise NotImplementedError(
                f"prune_conv only supports groups=1, got groups={conv.groups}"
            )

        importance = ChannelImportance.compute(conv)
        mask = ChannelImportance.get_mask(importance, self.sparsity)
        keep_indices = torch.where(mask)[0]
        if keep_indices.numel() == 0:
            raise ValueError("pruning would remove every output channel")

        weight = conv.weight
        # Forward every hyper-parameter so the pruned layer computes exactly
        # the kept slice of the original layer's output.
        new_conv = nn.Conv2d(
            conv.in_channels,
            keep_indices.numel(),
            conv.kernel_size,
            stride=conv.stride,
            padding=conv.padding,
            dilation=conv.dilation,
            groups=conv.groups,
            bias=conv.bias is not None,
            padding_mode=conv.padding_mode,
            device=weight.device,
            dtype=weight.dtype,
        )
        with torch.no_grad():
            new_conv.weight.copy_(weight[keep_indices])
            if conv.bias is not None:
                new_conv.bias.copy_(conv.bias[keep_indices])

        return new_conv, keep_indices


# Demo: prune half of the output channels of a 64->128 conv layer.
conv = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3, padding=1)
pruner = ChannelPruner(sparsity=0.5)
new_conv, kept = pruner.prune_conv(conv)

print(
    f"原始通道数: {conv.out_channels}",
    f"剪枝后通道数: {new_conv.out_channels}",
    f"压缩率: {1 - new_conv.out_channels/conv.out_channels:.1%}",
    sep="\n",
)

---

## 4. BN Scale 剪枝

In [None]:
class BNScalePruner:
    """Pruner that ranks channels by their BatchNorm scale.

    Core idea: use the magnitude of each BatchNorm ``gamma`` (its ``weight``)
    as the importance of the corresponding channel, with one global threshold
    pooled across all BN layers of the model.
    """

    def __init__(self, model: nn.Module, sparsity: float = 0.5):
        self.model = model
        # Fraction of all BN channels (pooled across layers) to prune.
        self.sparsity = sparsity

    def collect_bn_scales(self) -> Tensor:
        """Return ``|gamma|`` of every affine ``BatchNorm2d`` layer, concatenated.

        Layers built with ``affine=False`` have ``weight is None`` (no gamma)
        and are skipped. Returns an empty tensor when no such layer exists.
        """
        scales = [
            m.weight.detach().abs().flatten()
            for m in self.model.modules()
            if isinstance(m, nn.BatchNorm2d) and m.weight is not None
        ]
        return torch.cat(scales) if scales else torch.tensor([])

    def compute_threshold(self) -> float:
        """Return the global pruning threshold on ``|gamma|``.

        The threshold is the ``k``-th smallest scale with
        ``k = int(num_scales * sparsity)``, so channels with
        ``|gamma| <= threshold`` are the ``k`` pruning candidates (ties with
        the threshold value add more). Returns ``-inf`` when ``k == 0``
        (nothing to prune) and ``0.0`` when the model has no affine BN layer.

        Raises:
            ValueError: If ``sparsity`` is outside ``[0, 1]``.
        """
        if not 0.0 <= self.sparsity <= 1.0:
            raise ValueError(f"sparsity must be in [0, 1], got {self.sparsity}")

        all_scales = self.collect_bn_scales()
        if all_scales.numel() == 0:
            return 0.0

        num_prune = int(all_scales.numel() * self.sparsity)
        if num_prune == 0:
            # kthvalue requires k >= 1; clamping k to 1 would wrongly select a
            # channel at sparsity 0, so return a threshold no scale can reach.
            return float("-inf")
        return torch.kthvalue(all_scales, num_prune).values.item()


# Demo: a two-stage Conv-BN-ReLU stack, global BN-scale threshold at 30%.
model = nn.Sequential(
    nn.Conv2d(3, 64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),
    nn.ReLU(),
    nn.Conv2d(64, 128, kernel_size=3, padding=1),
    nn.BatchNorm2d(128),
    nn.ReLU(),
)
pruner = BNScalePruner(model=model, sparsity=0.3)
threshold = pruner.compute_threshold()
print(f"BN Scale 剪枝阈值: {threshold:.4f}")

---

## 5. 可视化分析

In [None]:
def visualize_channel_importance(conv: nn.Conv2d, sparsity: float = 0.5) -> None:
    """Plot the L1-norm importance of every output channel of ``conv``.

    Left panel: importance per channel in index order. Right panel: the same
    scores sorted ascending, with a dashed red line at the score of the first
    channel that survives pruning at the given ``sparsity``.

    Args:
        conv: Convolution layer to inspect.
        sparsity: Pruning fraction used to place the threshold line. The
            default 0.5 reproduces the original "50% threshold" line.
    """
    # Tensor.numpy() only works on CPU tensors, so move the scores off any
    # accelerator before handing them to matplotlib.
    importance = ChannelImportance.compute(conv).cpu()
    num_channels = importance.numel()

    _, (ax_raw, ax_sorted) = plt.subplots(1, 2, figsize=(12, 4))

    ax_raw.bar(range(num_channels), importance.numpy())
    ax_raw.set_xlabel('Channel Index')
    ax_raw.set_ylabel('Importance (L1-norm)')
    ax_raw.set_title('Channel Importance Distribution')
    ax_raw.grid(True, alpha=0.3)

    sorted_imp = torch.sort(importance).values
    ax_sorted.plot(sorted_imp.numpy(), 'b-')
    # In sorted order, the int(num_channels * sparsity) lowest-scoring
    # channels lie strictly left of ``cut``; clamp so sparsity=1 stays in
    # range. For sparsity=0.5 this equals num_channels // 2.
    cut = min(int(num_channels * sparsity), num_channels - 1)
    ax_sorted.axhline(sorted_imp[cut].item(), color='r', linestyle='--',
                      label=f'{sparsity:.0%} threshold')
    ax_sorted.set_xlabel('Rank')
    ax_sorted.set_ylabel('Importance')
    ax_sorted.set_title('Sorted Channel Importance')
    ax_sorted.legend()
    ax_sorted.grid(True, alpha=0.3)

    plt.tight_layout()
    plt.show()


conv = nn.Conv2d(in_channels=64, out_channels=128, kernel_size=3)
visualize_channel_importance(conv)

---

## 6. 总结

| 方法 | 重要性指标 | 优势 |
|:-----|:---------|:-----|
| **L1-norm** | $\sum \lvert W \rvert$ | 简单直接 |
| **BN Scale** | $\lvert \gamma \rvert$ | 可学习（可配合 L1 正则稀疏化） |
| **Taylor** | $\lvert W \cdot \nabla_W \mathcal{L} \rvert$ | 考虑梯度 |