<a href="https://colab.research.google.com/github/marsggbo/AutoMLDemos/blob/main/ch7/searchable_data_aug.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Searchable Data Augmentation

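This tutorial shows how to implement searchable data augmentation with PyTorch and Kornia. It covers:
1. Basic 3D data augmentation operators
2. A searchable data augmentation policy built with `OperationSpace`
3. A searchable module that combines the augmentation policy with a model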

## 1. Install Dependencies

In [9]:
# Install the required packages
!pip install -q torch torchvision kornia einops
!pip install -q git+https://github.com/marsggbo/hyperbox.git

## 2. Import the Required Libraries

In [10]:
import torch
import torch.nn as nn
import numpy as np
import kornia
from kornia.augmentation import *

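## 3. Basic 3D Data Augmentation Operators

We start with the basic 3D augmentation operators. Each one wraps a 2D Kornia operation and applies it to 3D data, such as medical imaging volumes, slice by slice.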

In [11]:
class Base2dTo3d(nn.Module):
    """Base class that applies a 2D augmentation to 3D data, slice by slice.

    Subclasses must set `self.aug` to a 2D augmentation module.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        assert x.dim() == 5, f"Expected a 5D input (B, C, D, H, W), got shape {x.shape}"
        B, C, D, H, W = x.shape
        # Fold the depth axis into the batch axis: (B, C, D, H, W) -> (B*D, C, H, W)
        x = x.permute(0, 2, 1, 3, 4).contiguous()  # (B, D, C, H, W)
        x = x.view(B * D, C, H, W)  # (B*D, C, H, W)
        # Apply the 2D augmentation
        x = self.aug(x)
        # Unfold back to 3D; re-read C, H, W in case the augmentation changed them
        _, C, H, W = x.shape
        x = x.view(B, D, C, H, W)  # (B, D, C, H, W)
        x = x.permute(0, 2, 1, 3, 4).contiguous()  # (B, C, D, H, W)
        return x


class RandomInvert3d(Base2dTo3d):
    """Randomly invert the intensities of a 3D volume"""
    def __init__(self, max_val=1.0, p=0.5):
        super().__init__()
        self.aug = RandomInvert(max_val, p=p)


class RandomGaussianNoise3d(Base2dTo3d):
    """Randomly add Gaussian noise"""
    def __init__(self, mean=0.0, std=1.0, p=0.5):
        super().__init__()
        self.aug = RandomGaussianNoise(mean, std, p=p)


class RandomBoxBlur3d(Base2dTo3d):
    """Randomly apply a box blur"""
    def __init__(self, kernel_size=(3, 3), p=0.5):
        super().__init__()
        self.aug = RandomBoxBlur(kernel_size, p=p)


class RandomErasing3d(Base2dTo3d):
    """Randomly erase a rectangular region (Random Erasing)"""
    def __init__(self, scale=(0.02, 0.33), ratio=(0.3, 3.3), p=0.5):
        super().__init__()
        self.aug = RandomErasing(scale, ratio, p=p)


class RandomSharpness3d(Base2dTo3d):
    """Randomly sharpen"""
    def __init__(self, sharpness=0.5, p=0.5):
        super().__init__()
        self.aug = RandomSharpness(sharpness, p=p)


class BrightContrast3d(nn.Module):
    """Apply a fixed brightness and contrast adjustment with probability p"""
    def __init__(self, brightness=0.4, contrast=0.4, p=0.5):
        super().__init__()
        self.brightness = brightness
        self.contrast = contrast
        self.p = p
        self.bright_op = kornia.enhance.adjust_brightness
        self.contr_op = kornia.enhance.adjust_contrast

    def forward(self, x):
        # One draw per call: the whole batch is either adjusted or left unchanged
        prob = torch.rand(1).item()
        if prob < self.p:
            x = self.bright_op(x, self.brightness)
            x = self.contr_op(x, self.contrast)
        return x
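
The fold-and-unfold step in `Base2dTo3d` can be checked on its own, without Kornia. The following is a minimal sketch, assuming a hypothetical `FoldDepth` wrapper that takes the 2D operation as a constructor argument, and a toy `FlipW` module standing in for a real augmentation. Neither name comes from the tutorial or from Kornia.

```python
import torch
import torch.nn as nn


class FoldDepth(nn.Module):
    """Hypothetical stand-in for Base2dTo3d that takes the 2D op as an argument."""

    def __init__(self, op2d: nn.Module):
        super().__init__()
        self.op2d = op2d

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, C, D, H, W = x.shape
        # (B, C, D, H, W) -> (B, D, C, H, W) -> (B*D, C, H, W)
        x = x.permute(0, 2, 1, 3, 4).reshape(B * D, C, H, W)
        x = self.op2d(x)
        # Re-read C, H, W in case the 2D op changed them.
        _, C, H, W = x.shape
        # (B*D, C, H, W) -> (B, D, C, H, W) -> (B, C, D, H, W)
        return x.reshape(B, D, C, H, W).permute(0, 2, 1, 3, 4).contiguous()


class FlipW(nn.Module):
    """Toy deterministic 2D op: flip the width axis."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x.flip(-1)


x = torch.arange(2 * 3 * 4 * 5 * 6, dtype=torch.float32).reshape(2, 3, 4, 5, 6)

# With an identity op, folding and unfolding returns the input unchanged.
round_trip = FoldDepth(nn.Identity())(x)

# With a deterministic 2D op, the result matches applying it to every slice.
flipped = FoldDepth(FlipW())(x)

print(torch.equal(round_trip, x))
print(torch.equal(flipped, x.flip(-1)))
```

If either comparison were `False`, the permute order on the way back would not mirror the one on the way in, which is the usual mistake with this pattern.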

### 3.1 Test the Basic Augmentation Operators

In [12]:
# Create test data: (batch_size, channels, depth, height, width)
x = torch.rand(2, 3, 8, 32, 32)
print(f"Input shape: {x.shape}")

# Run each augmentation once and check the output shape
ops = {
    'RandomInvert3d': RandomInvert3d(),
    'RandomGaussianNoise3d': RandomGaussianNoise3d(),
    'RandomBoxBlur3d': RandomBoxBlur3d(),
    'RandomErasing3d': RandomErasing3d(),
    'RandomSharpness3d': RandomSharpness3d(),
    'BrightContrast3d': BrightContrast3d(),
}

for name, op in ops.items():
    y = op(x)
    assert y.shape == x.shape  # none of these operators should change the shape
    print(f"{name}: output shape {y.shape} ✓")

Input shape: torch.Size([2, 3, 8, 32, 32])
RandomInvert3d: output shape torch.Size([2, 3, 8, 32, 32]) ✓
RandomGaussianNoise3d: output shape torch.Size([2, 3, 8, 32, 32]) ✓
RandomBoxBlur3d: output shape torch.Size([2, 3, 8, 32, 32]) ✓
RandomErasing3d: output shape torch.Size([2, 3, 8, 32, 32]) ✓
RandomSharpness3d: output shape torch.Size([2, 3, 8, 32, 32]) ✓
BrightContrast3d: output shape torch.Size([2, 3, 8, 32, 32]) ✓
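
One property of the slice-folding approach is worth keeping in mind. After the depth axis is folded into the batch axis, every slice looks like a separate sample to the 2D augmentation, so an augmentation that samples per batch element will likely treat the slices of one volume differently. For noise this is harmless, but for geometric operations it may break the alignment between slices. The sketch below shows one possible alternative that draws a single decision per volume. `VolumeFlip` is a hypothetical class written for illustration and is not part of Kornia or of the tutorial code.

```python
import torch
import torch.nn as nn


class VolumeFlip(nn.Module):
    """Hypothetical flip that draws one decision per volume, not per slice."""

    def __init__(self, dim: int = -1, p: float = 0.5):
        super().__init__()
        self.dim = dim
        self.p = p

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # One Bernoulli draw per sample in the batch: shape (B,)
        flip = torch.rand(x.shape[0], device=x.device) < self.p
        out = x.clone()
        # All depth slices of a selected volume are flipped together.
        out[flip] = x[flip].flip(self.dim)
        return out


x = torch.rand(2, 1, 4, 8, 8)
always = VolumeFlip(p=1.0)(x)  # every volume flipped along W
never = VolumeFlip(p=0.0)(x)   # nothing flipped
```

With `p=1.0` the result equals `x.flip(-1)`, and with `p=0.0` it equals the input, which makes the two extremes easy to verify.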


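## 4. Searchable Data Augmentation Policy

Next, we use `OperationSpace` to build a searchable data augmentation policy, so that a NAS algorithm can search for the best combination of augmentations automatically.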
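
The idea behind a searchable operation can be sketched without hyperbox: keep a list of candidate modules and a one-hot mask, and run only the candidate the mask selects. The class below is a hypothetical, simplified stand-in written for illustration. It is not hyperbox's `OperationSpace`, and its `sample` method only approximates what a random mutator might do.

```python
import torch
import torch.nn as nn


class OneHotChoice(nn.Module):
    """Hypothetical stand-in for a searchable choice block (not OperationSpace)."""

    def __init__(self, candidates):
        super().__init__()
        self.candidates = nn.ModuleList(candidates)
        # Start with the first candidate selected.
        mask = torch.zeros(len(candidates), dtype=torch.bool)
        mask[0] = True
        self.register_buffer("mask", mask)

    def select(self, index: int) -> None:
        """Set the one-hot mask so that only `index` is active."""
        self.mask.zero_()
        self.mask[index] = True

    def sample(self) -> int:
        """Pick one candidate uniformly at random and return its index."""
        index = int(torch.randint(len(self.candidates), (1,)).item())
        self.select(index)
        return index

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Run only the candidate marked by the mask.
        index = int(self.mask.int().argmax())
        return self.candidates[index](x)


choice = OneHotChoice([nn.Identity(), nn.ReLU(), nn.Tanh()])
x = torch.tensor([-1.0, 2.0])

choice.select(1)           # choose ReLU
y = choice(x)              # negative entry is clamped to zero

sampled = choice.sample()  # choose a random candidate
```

Chaining several such blocks, one per augmentation type, gives a policy in which each block contributes one decision to the search space.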

In [13]:
from hyperbox.mutables.spaces import OperationSpace
from hyperbox.networks.base_nas_network import BaseNASNetwork

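# Note: this code needs 3D flip and crop operations, either from kornia or custom ones;
# simplified custom versions are defined below.

def prob_list_gen(func, num_probs=4, probs=None, *args, **kwargs):
    """Build one instance of `func` per probability.

    Uses `probs` if given; otherwise uses the first `num_probs` values of 0, 0.25, 0.5, ...
    """
    if probs is None:
        probs = [i * 0.25 for i in range(num_probs)]
    return [func(*args, p=p, **kwargs) for p in probs]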


# Simplified 3D flips (built on Base2dTo3d)
class RandomHorizontalFlip3d(Base2dTo3d):
    def __init__(self, p=0.5):
        super().__init__()
        self.aug = RandomHorizontalFlip(p=p)


class RandomVerticalFlip3d(Base2dTo3d):
    def __init__(self, p=0.5):
        super().__init__()
        self.aug = RandomVerticalFlip(p=p)


# Simplified 3D crop
class RandomCrop3d(nn.Module):
    """3D random crop (simplified): one crop window shared by the whole batch"""
    def __init__(self, size, p=1.0):
        super().__init__()
        self.size = size
        self.p = p
    
    def forward(self, x):
        if torch.rand(1).item() < self.p:
            B, C, D, H, W = x.shape
            d, h, w = self.size
            # Random start position of the crop window
            start_d = torch.randint(0, max(1, D - d + 1), (1,)).item()
            start_h = torch.randint(0, max(1, H - h + 1), (1,)).item()
            start_w = torch.randint(0, max(1, W - w + 1), (1,)).item()
            x = x[:, :, start_d:start_d+d, start_h:start_h+h, start_w:start_w+w]
        return x


class DataAugmentation(BaseNASNetwork):
    """Searchable data augmentation module"""
    
    def __init__(
        self,
        crop_size=[(16, 64, 64)],
        mask=None
    ):
        super().__init__(mask)
        
        # Candidate operations for each searchable step
        self.ops = {}
        
        # 1. Flips
        self.ops['hflip'] = prob_list_gen(
            RandomHorizontalFlip3d, probs=[0, 0.5, 1]
        )
        self.ops['vflip'] = prob_list_gen(
            RandomVerticalFlip3d, probs=[0, 0.5, 1]
        )
        
        # 2. Random crop
        self.ops['rcrop'] = []
        for size in crop_size:
            self.ops['rcrop'].append(RandomCrop3d(size=size, p=1))
        
        # 3. Box blur (with Identity as a candidate)
        boxblur = [nn.Identity()]
        for ks in [(3, 3), (5, 5)]:
            boxblur += prob_list_gen(RandomBoxBlur3d, probs=[0.5, 1], kernel_size=ks)
        self.ops['boxblur'] = boxblur
        
        # 4. Intensity inversion (with Identity as a candidate)
        invert = [nn.Identity()]
        for val in [0.25, 0.5, 0.75, 1]:
            invert += prob_list_gen(RandomInvert3d, probs=[0.5, 1], max_val=val)
        self.ops['invert'] = invert
        
        # 5. Gaussian noise (with Identity as a candidate)
        gauNoise = [nn.Identity()]
        gauNoise += prob_list_gen(RandomGaussianNoise3d, probs=[0, 0.5, 1])
        self.ops['gauNoise'] = gauNoise
        
        # 6. Random erasing (with Identity as a candidate)
        erase = [nn.Identity()]
        for scale in [(0.02, 0.1), (0.1, 0.33)]:
            erase += prob_list_gen(RandomErasing3d, probs=[0.5, 1], scale=scale)
        self.ops['erase'] = erase
        
        # Wrap each candidate list in an OperationSpace
        transforms = []
        for key, value in self.ops.items():
            if value:
                transforms.append(
                    OperationSpace(
                        candidates=value, 
                        key=key, 
                        mask=self.mask, 
                        reduction='mean'
                    )
                )
        self.transforms = nn.Sequential(*transforms)
    
    def forward(self, x: torch.Tensor, aug=True):
        """Apply the selected augmentations in order; skip them all if `aug` is False"""
        if aug:
            for trans in self.transforms:
                x = trans(x)  # BxCxDxHxW
        return x
    
    @property
    def arch(self):
        """Return the currently selected candidate of each step, one per line"""
        _arch = []
        for op in self.transforms:
            mask = op.mask
            if 'bool' in str(mask.dtype):
                index = mask.int().argmax()
            else:
                index = mask.float().argmax()
            _arch.append(f"{op.candidates[index]}")
        return '\n'.join(_arch)
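
To get a feel for how large this search space is, the candidate counts can be multiplied together. The sketch below recomputes them by hand from the lists in `DataAugmentation.__init__`; the helper name `count_candidates` is made up for this example. The count includes candidates that behave identically, for instance `gauNoise` contains both `nn.Identity()` and a noise operator with `p=0`.

```python
from math import prod


def count_candidates(num_crop_sizes: int = 1) -> dict:
    """Number of candidates per searchable step (hypothetical helper)."""
    return {
        "hflip": 3,               # p in {0, 0.5, 1}
        "vflip": 3,               # p in {0, 0.5, 1}
        "rcrop": num_crop_sizes,  # one candidate per crop size
        "boxblur": 1 + 2 * 2,     # Identity + 2 kernel sizes x 2 probabilities
        "invert": 1 + 4 * 2,      # Identity + 4 max_val values x 2 probabilities
        "gauNoise": 1 + 3,        # Identity + 3 probabilities
        "erase": 1 + 2 * 2,       # Identity + 2 scale ranges x 2 probabilities
    }


sizes = count_candidates(num_crop_sizes=2)
total = prod(sizes.values())
print(sizes)
print(total)  # 16200
```

With the default single crop size the total is 8100 policies. A space of this size is small enough that plain random search is a reasonable baseline.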

### 4.1 Test the Searchable Data Augmentation Module

In [21]:
from hyperbox.mutator import OnehotMutator, RandomMutator

# Create the searchable data augmentation module
data_aug = DataAugmentation(crop_size=[(8, 32, 32), (16, 64, 64)])

# Create a mutator to manage the search space
mutator = RandomMutator(data_aug)

# Sample a few architectures and test each one
print("Testing different data augmentation architectures:\n")
for i in range(3):
    mutator.reset()
    print(f"Architecture {i+1}:")
    print(data_aug.arch)
    
    # Test the forward pass
    x = torch.rand(2, 1, 8, 64, 64)
    y = data_aug(x)
    print(f"Input shape: {x.shape}, output shape: {y.shape}\n")

Testing different data augmentation architectures:

Architecture 1:
RandomHorizontalFlip3d(
  (aug): RandomHorizontalFlip(p=0.5, p_batch=1.0, same_on_batch=False)
)
RandomVerticalFlip3d(
  (aug): RandomVerticalFlip(p=0.5, p_batch=1.0, same_on_batch=False)
)
RandomCrop3d()
RandomBoxBlur3d(
  (aug): RandomBoxBlur(p=0.5, p_batch=1.0, same_on_batch=False, kernel_size=(3, 3), border_type=reflect, normalized=True)
)
RandomInvert3d(
  (aug): RandomInvert(p=1, p_batch=1.0, same_on_batch=False, max_val=0.25)
)
RandomGaussianNoise3d(
  (aug): RandomGaussianNoise(p=1, p_batch=1.0, same_on_batch=False, mean=0.0, std=1.0)
)
RandomErasing3d(
  (aug): RandomErasing(scale=(0.1, 0.33), resize_to=(0.3, 3.3), value=0.0, p=1, p_batch=1.0, same_on_batch=False)
)
Input shape: torch.Size([2, 1, 8, 64, 64]), output shape: torch.Size([2, 1, 8, 32, 32])

Architecture 2:
RandomHorizontalFlip3d(
  (aug): RandomHorizontalFlip(p=0.5, p_batch=1.0, same_on_batch=False)
)
RandomVerticalFlip3d(
  (aug): RandomVerticalFlip(p=0, p_batch=1.0, same_on_batch=False)
)
RandomCrop3d()
RandomBoxBlur

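## 5. Combining Searchable Augmentation with the Model

Finally, we show how to combine searchable data augmentation with model architecture search to build an end-to-end searchable model.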

In [19]:
# Simplified example: combine data augmentation with a simple classifier
class SimpleClassifier(nn.Module):
    """A simple 3D CNN classifier"""
    def __init__(self, in_channels=1, num_classes=10):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv3d(in_channels, 32, kernel_size=3, padding=1),
            nn.BatchNorm3d(32),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool3d(1),
            nn.Flatten(),
            nn.Linear(32, num_classes)
        )
    
    def forward(self, x):
        return self.features(x)


class SearchableModelWithAug(BaseNASNetwork):
    """End-to-end model that combines searchable data augmentation with a classifier"""
    def __init__(
        self,
        in_channels=1,
        num_classes=10,
        mask=None
    ):
        super().__init__(mask)
        
        # Searchable data augmentation module
        self.data_aug = DataAugmentation(
            crop_size=[(8, 32, 32), (16, 64, 64)],
            mask=self.mask
        )
        
        # Classifier (fixed here for simplicity; a searchable architecture could be used instead)
        self.classifier = SimpleClassifier(in_channels, num_classes)
    
    def forward(self, x, aug=True):
        # Apply data augmentation
        x = self.data_aug(x, aug=aug)
        # Classify
        logits = self.classifier(x)
        return logits
    
    @property
    def arch(self):
        """Return the current architecture, including the augmentation policy"""
        return f"Data Augmentation:\n{self.data_aug.arch}\n\nClassifier: SimpleClassifier"
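
Because `forward` takes an `aug` flag, augmentation can be switched off for validation by calling `model(x, aug=False)`. One possible variation, sketched below, ties the flag to the module's training mode so that `model.eval()` disables augmentation automatically. `AugThenClassify` and `AddOne` are hypothetical names used only for this illustration, and the toy modules stand in for the real augmentation and classifier.

```python
import torch
import torch.nn as nn


class AugThenClassify(nn.Module):
    """Hypothetical wrapper: augment only in training mode unless overridden."""

    def __init__(self, augment: nn.Module, classifier: nn.Module):
        super().__init__()
        self.augment = augment
        self.classifier = classifier

    def forward(self, x: torch.Tensor, aug=None) -> torch.Tensor:
        # aug=None follows self.training; True or False overrides it.
        use_aug = self.training if aug is None else aug
        if use_aug:
            x = self.augment(x)
        return self.classifier(x)


class AddOne(nn.Module):
    """Toy 'augmentation' whose effect is easy to see."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x + 1


model = AugThenClassify(AddOne(), nn.Identity())
x = torch.zeros(2, 1, 4, 8, 8)

model.train()
train_out = model(x)           # augmentation applied
model.eval()
eval_out = model(x)            # augmentation skipped
forced = model(x, aug=True)    # explicit override while in eval mode
```

The explicit override is kept so that test-time augmentation remains possible when it is wanted.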

### 5.1 Test the End-to-End Searchable Model

In [23]:
# Create the end-to-end searchable model
model = SearchableModelWithAug(in_channels=1, num_classes=10)

# Create the mutator
mutator = RandomMutator(model)

# Test the model
print("Testing the end-to-end searchable model:\n")
for i in range(2):
    mutator.reset()
    print(f"Model architecture {i+1}:")
    print(model.arch)
    print("\n" + "="*50 + "\n")
    
    # Test the forward pass
    x = torch.rand(4, 1, 8, 64, 64)
    logits = model(x, aug=True)
    print(f"Input shape: {x.shape}")
    print(f"Output logits shape: {logits.shape}")
    # Convert the tensor to a list before printing so the output looks the same across PyTorch versions
    pred_classes = logits.argmax(dim=1).detach().cpu().tolist()
    print(f"Predicted classes: {pred_classes}\n")

Testing the end-to-end searchable model:

Model architecture 1:
Data Augmentation:
RandomHorizontalFlip3d(
  (aug): RandomHorizontalFlip(p=0.5, p_batch=1.0, same_on_batch=False)
)
RandomVerticalFlip3d(
  (aug): RandomVerticalFlip(p=0, p_batch=1.0, same_on_batch=False)
)
RandomCrop3d()
RandomBoxBlur3d(
  (aug): RandomBoxBlur(p=1, p_batch=1.0, same_on_batch=False, kernel_size=(5, 5), border_type=reflect, normalized=True)
)
RandomInvert3d(
  (aug): RandomInvert(p=0.5, p_batch=1.0, same_on_batch=False, max_val=1)
)
RandomGaussianNoise3d(
  (aug): RandomGaussianNoise(p=0, p_batch=1.0, same_on_batch=False, mean=0.0, std=1.0)
)
RandomErasing3d(
  (aug): RandomErasing(scale=(0.1, 0.33), resize_to=(0.3, 3.3), value=0.0, p=1, p_batch=1.0, same_on_batch=False)
)

Classifier: SimpleClassifier

==================================================

Input shape: torch.Size([4, 1, 8, 64, 64])
Output logits shape: torch.Size([4, 10])
Predicted classes: [9, 6, 9, 9]

Model architecture 2:
Data Augmentation:
RandomHorizontalFlip3d(
  (aug): RandomHorizontalFlip(p=1, p_batch=1.0, same_on_batch=False)
)
RandomVerticalFlip3d(
  (aug): Random

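## 6. Summary

This tutorial covered:

1. **Basic augmentation operators**: 3D data augmentation operations built on Kornia
2. **Searchable augmentation policy**: a searchable data augmentation module built with `OperationSpace`
3. **End-to-end searchable model**: data augmentation combined with model architecture search

### Key points:

- **Base2dTo3d**: base class that extends 2D operations to 3D data
- **OperationSpace**: builds a searchable space of candidate operations
- **Mutator**: manages the search space and supports strategies such as random sampling and one-hot encoding

### Next steps:

- Search for the best augmentation policy with a NAS algorithm such as DARTS or random search
- Combine the augmentation policy with more complex architectures such as MobileNet or ResNet
- Evaluate different augmentation policies on real datasets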
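
As a starting point for the first item, random search over augmentation policies can be sketched in a few lines of plain Python. Everything here is an assumption made for illustration: the policy is represented as one candidate index per step, the candidate counts match the two-crop-size configuration used earlier, and `evaluate` is a placeholder objective. In a real run it would train or fine-tune the model with the given policy and return a validation metric.

```python
import random

# Number of candidates per step (assumed: the configuration with two crop sizes)
SEARCH_SPACE = {
    "hflip": 3, "vflip": 3, "rcrop": 2, "boxblur": 5,
    "invert": 9, "gauNoise": 4, "erase": 5,
}


def sample_policy(rng: random.Random) -> dict:
    """Draw one candidate index per step, uniformly at random."""
    return {key: rng.randrange(n) for key, n in SEARCH_SPACE.items()}


def evaluate(policy: dict) -> float:
    """Placeholder score so the loop has something to maximize.

    Replace with: apply the policy, train, and return validation accuracy.
    """
    return -float(sum(policy.values()))


def random_search(num_trials: int, seed: int = 0):
    """Return the best policy and score found in `num_trials` random draws."""
    rng = random.Random(seed)
    best_policy, best_score = None, float("-inf")
    for _ in range(num_trials):
        policy = sample_policy(rng)
        score = evaluate(policy)
        if score > best_score:
            best_policy, best_score = policy, score
    return best_policy, best_score


best_policy, best_score = random_search(num_trials=20)
print(best_policy, best_score)
```

Fixing the seed makes a search run repeatable, which helps when comparing it against a gradient-based method such as DARTS.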