# Auto Augmentation

[![Run Online](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_modelarts.png)](https://authoring-modelarts-cnnorth4.huaweicloud.com/console/lab?share-url-b64=aHR0cHM6Ly9taW5kc3BvcmUtd2Vic2l0ZS5vYnMuY24tbm9ydGgtNC5teWh1YXdlaWNsb3VkLmNvbS9ub3RlYm9vay9tb2RlbGFydHMvcHJvZ3JhbW1pbmdfZ3VpZGUvbWluZHNwb3JlX2F1dG9fYXVnbWVudGF0aW9uLmlweW5i&imageid=65f636a0-56cf-49df-b941-7d2a07ba8c8c)&emsp;[![Download Notebook](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_notebook.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/master/programming_guide/zh_cn/mindspore_auto_augmentation.ipynb)&emsp;[![Download Sample Code](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_download_code.png)](https://obs.dualstack.cn-north-4.myhuaweicloud.com/mindspore-website/notebook/master/programming_guide/zh_cn/mindspore_auto_augmentation.py)&emsp;[![View Source](https://mindspore-website.obs.cn-north-4.myhuaweicloud.com/website-images/master/resource/_static/logo_source.png)](https://gitee.com/mindspore/docs/blob/master/tutorials/experts/source_zh_cn/dataset/auto_augmentation.ipynb)

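In addition to user-defined data augmentation, MindSpore provides auto augmentation, which automatically applies data augmentation to images according to a specified policy.

Two kinds of auto augmentation are introduced below: **probability-based** and **callback parameter-based**.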

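## Probability-Based Data Augmentation

MindSpore provides a series of probability-based auto augmentation APIs. Users can randomly select and combine augmentation operations, which makes data augmentation more flexible.

> See the [API documentation](https://www.mindspore.cn/docs/zh-CN/master/api_python/mindspore.dataset.transforms.html) for details.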

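### RandomApply

The API takes a list of augmentation operations, `transforms`. With a given probability (0.5 by default) it executes all operations in the list in order; otherwise it executes none of them.

The following example calls `RandomApply` to execute the `RandomCrop` and `RandomColorAdjust` operations in order with a probability of 0.5.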

In [1]:
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore.dataset.transforms.c_transforms import RandomApply

transforms_list = [c_vision.RandomCrop(512), c_vision.RandomColorAdjust()]
rand_apply = RandomApply(transforms_list)
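
The all-or-nothing behavior described above can be sketched in plain Python. This is an illustrative model only, not MindSpore's implementation: `random_apply`, `add_one`, and `double` are hypothetical names, and integers stand in for images.

```python
import random

def random_apply(transforms, value, prob=0.5, rng=random):
    """With probability `prob`, apply every transform in order; otherwise apply none."""
    if rng.random() < prob:
        for transform in transforms:
            value = transform(value)
    return value

# Stand-in "augmentations" on integers, purely for illustration.
def add_one(x):
    return x + 1

def double(x):
    return x * 2

always = random_apply([add_one, double], 3, prob=1.0)  # both run in order: (3 + 1) * 2
never = random_apply([add_one, double], 3, prob=0.0)   # neither runs: input unchanged
print(always, never)
```

With `prob=1.0` the whole list always runs, and with `prob=0.0` the input passes through untouched; intermediate values interpolate between these two cases.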

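### RandomChoice

The API takes a list of augmentation operations, `transforms`, and randomly selects one of them to execute.

The following example selects either `CenterCrop` or `RandomCrop` with equal probability and executes it.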

In [2]:
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore.dataset.transforms.c_transforms import RandomChoice

rand_choice = RandomChoice([c_vision.CenterCrop(512), c_vision.RandomCrop(512)])
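
As a rough model of this selection rule, the sketch below picks exactly one callable per sample. It is not MindSpore code: `random_choice`, `center_crop`, and `first_crop` are hypothetical helpers, and a short list stands in for an image.

```python
import random

def random_choice(transforms, value, rng=random):
    """Pick exactly one transform uniformly at random and apply it."""
    transform = rng.choice(transforms)
    return transform(value)

# Stand-in crops on a list, purely for illustration.
def center_crop(seq, size=2):
    start = (len(seq) - size) // 2
    return seq[start:start + size]

def first_crop(seq, size=2):
    return seq[:size]

rng = random.Random(0)  # fixed seed so the run is repeatable
results = [tuple(random_choice([center_crop, first_crop], [0, 1, 2, 3], rng))
           for _ in range(200)]
seen = set(results)
print(seen)
```

Every result is the output of one crop or the other, never of both applied together.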

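### RandomSelectSubpolicy

The API takes a preset policy list that contains a series of subpolicies. Each subpolicy consists of several augmentation operations executed in order, each paired with its own execution probability.

For each image, one subpolicy is first selected at random with equal probability, and then each operation in that subpolicy is executed in order according to its probability.

The following example presets two subpolicies. Subpolicy 1 contains the three operations `RandomRotation`, `RandomVerticalFlip`, and `RandomColorAdjust`, with probabilities 0.5, 1.0, and 0.8 respectively. Subpolicy 2 contains the two operations `RandomRotation` and `RandomColorAdjust`, with probabilities 1.0 and 0.2 respectively.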

In [3]:
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore.dataset.vision.c_transforms import RandomSelectSubpolicy

policy_list = [
    [(c_vision.RandomRotation((45, 45)), 0.5), (c_vision.RandomVerticalFlip(), 1.0), (c_vision.RandomColorAdjust(), 0.8)],
    [(c_vision.RandomRotation((90, 90)), 1.0), (c_vision.RandomColorAdjust(), 0.2)]
]
policy = RandomSelectSubpolicy(policy_list)
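
The two-stage sampling described above (pick a subpolicy, then decide each operation independently) can be sketched as follows. This is an illustrative model under simplifying assumptions, not the library's implementation: `random_select_subpolicy`, `add_one`, and `double` are hypothetical, and integers stand in for images.

```python
import random

def random_select_subpolicy(policy, value, rng=random):
    """Choose one subpolicy uniformly, then apply each of its ops with its own probability."""
    subpolicy = rng.choice(policy)
    for transform, prob in subpolicy:
        if rng.random() < prob:
            value = transform(value)
    return value

def add_one(x):
    return x + 1

def double(x):
    return x * 2

toy_policy = [
    [(add_one, 1.0), (double, 0.0)],  # always add one, never double: 3 -> 4
    [(double, 1.0), (add_one, 1.0)],  # always double, then add one:  3 -> 7
]

rng = random.Random(0)
outputs = [random_select_subpolicy(toy_policy, 3, rng) for _ in range(100)]
print(set(outputs))
```

Probabilities of 1.0 and 0.0 are used here only so that each subpolicy has a single possible result; with fractional probabilities each operation is kept or skipped independently.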

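## Callback Parameter-Based Data Augmentation

MindSpore's `sync_wait` interface allows the data augmentation policy to be adjusted dynamically during training, at batch or epoch granularity. Users can set a blocking condition to trigger specific augmentation operations.

`sync_wait` blocks the entire data processing pipeline until `sync_update` triggers the user-defined `callback` function. The two must be used together, as described below: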

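- sync_wait(condition_name, num_batch=1, callback=None)

    This API adds a blocking condition named `condition_name` to the dataset. The specified `callback` function is executed when `sync_update` is called.

- sync_update(condition_name, num_batch=None, data=None)

    This API releases the block associated with `condition_name` and triggers the specified `callback` function with `data`.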
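
The block-then-release handshake between the two calls can be illustrated with standard-library threading. The sketch below is a conceptual model only and does not reflect how MindSpore implements these APIs: `BlockingSource` and its `update` method are hypothetical, and a semaphore initialized to `num_batch` plays the role of the blocking condition.

```python
import queue
import threading

class BlockingSource:
    """Toy producer that emits one row per permit; update() runs a callback and adds a permit."""

    def __init__(self, rows, callback, num_batch=1):
        self._rows = rows
        self._callback = callback
        self._permits = threading.Semaphore(num_batch)  # rows allowed before blocking
        self._out = queue.Queue()
        self._worker = threading.Thread(target=self._produce, daemon=True)

    def _produce(self):
        for row in self._rows:
            self._permits.acquire()  # blocks here until update() releases a permit
            self._out.put(row)
        self._out.put(None)          # sentinel: no more rows

    def __iter__(self):
        self._worker.start()
        while True:
            row = self._out.get()
            if row is None:
                return
            yield row

    def update(self, data):
        self._callback(data)         # let the user adjust the policy first
        self._permits.release()      # then unblock the producer for one more row

received, updates = [], []
source = BlockingSource([10, 20, 30], callback=updates.append)
for step, row in enumerate(source):
    received.append(row)
    source.update({"step_num": step})
print(received, updates)
```

If the consumer never called `update`, the producer would stay blocked after the first row, which is why the two calls have to be paired.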

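The following demonstrates how to use callback parameter-based auto augmentation.

1. Define an `Augment` class in advance, where `preprocess` is the user-defined augmentation function and `update` is the callback function that updates the augmentation policy.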

In [4]:
import mindspore.dataset as ds
import numpy as np

class Augment:
    def __init__(self):
        self.ep_num = 0
        self.step_num = 0

    def preprocess(self, input_):
        return np.array((input_ + self.step_num ** self.ep_num - 1),)

    def update(self, data):
        self.ep_num = data['ep_num']
        self.step_num = data['step_num']

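2. The data processing pipeline first calls back the user-defined policy update function `update`, and then the `map` operation executes the augmentation defined in `preprocess` according to the updated policy.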

In [5]:
arr = list(range(1, 4))
dataset = ds.NumpySlicesDataset(arr, shuffle=False)
aug = Augment()
dataset = dataset.sync_wait(condition_name="policy", callback=aug.update)
dataset = dataset.map(operations=[aug.preprocess])

3. Call `sync_update` in each step to update the data augmentation policy.

In [6]:
epochs = 5
itr = dataset.create_tuple_iterator(num_epochs=epochs)
step_num = 0
for ep_num in range(epochs):
    for data in itr:
        print("epoch: {}, step:{}, data :{}".format(ep_num, step_num, data))
        step_num += 1
        dataset.sync_update(condition_name="policy", data={'ep_num': ep_num, 'step_num': step_num})

epoch: 0, step:0, data :[Tensor(shape=[], dtype=Int64, value= 1)]
epoch: 0, step:1, data :[Tensor(shape=[], dtype=Int64, value= 2)]
epoch: 0, step:2, data :[Tensor(shape=[], dtype=Int64, value= 3)]
epoch: 1, step:3, data :[Tensor(shape=[], dtype=Int64, value= 1)]
epoch: 1, step:4, data :[Tensor(shape=[], dtype=Int64, value= 5)]
epoch: 1, step:5, data :[Tensor(shape=[], dtype=Int64, value= 7)]
epoch: 2, step:6, data :[Tensor(shape=[], dtype=Int64, value= 6)]
epoch: 2, step:7, data :[Tensor(shape=[], dtype=Int64, value= 50)]
epoch: 2, step:8, data :[Tensor(shape=[], dtype=Int64, value= 66)]
epoch: 3, step:9, data :[Tensor(shape=[], dtype=Int64, value= 81)]
epoch: 3, step:10, data :[Tensor(shape=[], dtype=Int64, value= 1001)]
epoch: 3, step:11, data :[Tensor(shape=[], dtype=Int64, value= 1333)]
epoch: 4, step:12, data :[Tensor(shape=[], dtype=Int64, value= 1728)]
epoch: 4, step:13, data :[Tensor(shape=[], dtype=Int64, value= 28562)]
epoch: 4, step:14, data :[Tensor(shape=[], dtype=Int64, 
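
To see where the printed values come from, the arithmetic in `preprocess` can be replayed without MindSpore. The sketch below is a plain-Python simulation that assumes exactly one row is processed between consecutive `sync_update` calls, so each row sees the `ep_num` and `step_num` delivered by the previous update; `simulate` is a hypothetical helper.

```python
def simulate(inputs, epochs):
    """Replay input + step_num ** ep_num - 1 with the state updated after every step."""
    ep_num, step_num = 0, 0  # initial state of Augment before any update
    step = 0
    values = []
    for epoch in range(epochs):
        for x in inputs:
            values.append(x + step_num ** ep_num - 1)  # what preprocess computes
            step += 1
            ep_num, step_num = epoch, step             # what update receives
    return values

values = simulate([1, 2, 3], epochs=5)
print(values[:14])
```

Under that assumption the first fourteen simulated values agree with the complete lines of the recorded output. For example, the first row of the epoch printed as 1 is still computed with the previous epoch's `ep_num` of 0, which gives `1 + 3 ** 0 - 1 = 1`.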

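## ImageNet Auto Augmentation

This tutorial uses an implementation of AutoAugment on the ImageNet dataset as an example.

The augmentation policy for the ImageNet dataset contains 25 subpolicies, each consisting of two transforms. For every image in a batch, one subpolicy is chosen at random, and each transform in it is applied or skipped according to its preset probability.

AutoAugment can be implemented with the `RandomSelectSubpolicy` interface in MindSpore's `c_transforms` module. The standard data augmentation pipeline for ImageNet classification training consists of the following steps:

- `RandomCropDecodeResize`: randomly crops the image, then decodes and resizes it.
- `RandomHorizontalFlip`: randomly flips the image horizontally.
- `Normalize`: normalizes the image.
- `HWC2CHW`: converts the image layout from HWC to CHW.

The AutoAugment transform is inserted after `RandomCropDecodeResize`, as shown below: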

1. Import the MindSpore data augmentation modules.

In [None]:
import matplotlib.pyplot as plt

import mindspore.dataset as ds
import mindspore.dataset.transforms.c_transforms as c_transforms
import mindspore.dataset.vision.c_transforms as c_vision
from mindspore import dtype as mstype

2. Define the mapping from MindSpore operators to AutoAugment operators:

In [None]:
# define Auto Augmentation operators
PARAMETER_MAX = 10

def float_parameter(level, maxval):
    return float(level) * maxval /  PARAMETER_MAX

def int_parameter(level, maxval):
    return int(level * maxval / PARAMETER_MAX)

def shear_x(level):
    v = float_parameter(level, 0.3)
    return c_transforms.RandomChoice([c_vision.RandomAffine(degrees=0, shear=(-v, -v)), c_vision.RandomAffine(degrees=0, shear=(v, v))])

def shear_y(level):
    v = float_parameter(level, 0.3)
    return c_transforms.RandomChoice([c_vision.RandomAffine(degrees=0, shear=(0, 0, -v, -v)), c_vision.RandomAffine(degrees=0, shear=(0, 0, v, v))])

def translate_x(level):
    v = float_parameter(level, 150 / 331)
    return c_transforms.RandomChoice([c_vision.RandomAffine(degrees=0, translate=(-v, -v)), c_vision.RandomAffine(degrees=0, translate=(v, v))])

def translate_y(level):
    v = float_parameter(level, 150 / 331)
    return c_transforms.RandomChoice([c_vision.RandomAffine(degrees=0, translate=(0, 0, -v, -v)), c_vision.RandomAffine(degrees=0, translate=(0, 0, v, v))])

def color_impl(level):
    v = float_parameter(level, 1.8) + 0.1
    return c_vision.RandomColor(degrees=(v, v))

def rotate_impl(level):
    v = int_parameter(level, 30)
    return c_transforms.RandomChoice([c_vision.RandomRotation(degrees=(-v, -v)), c_vision.RandomRotation(degrees=(v, v))])

def solarize_impl(level):
    level = int_parameter(level, 256)
    v = 256 - level
    return c_vision.RandomSolarize(threshold=(0, v))

def posterize_impl(level):
    level = int_parameter(level, 4)
    v = 4 - level
    return c_vision.RandomPosterize(bits=(v, v))

def contrast_impl(level):
    v = float_parameter(level, 1.8) + 0.1
    return c_vision.RandomColorAdjust(contrast=(v, v))

def autocontrast_impl(level):
    return c_vision.AutoContrast()

def sharpness_impl(level):
    v = float_parameter(level, 1.8) + 0.1
    return c_vision.RandomSharpness(degrees=(v, v))

def brightness_impl(level):
    v = float_parameter(level, 1.8) + 0.1
    return c_vision.RandomColorAdjust(brightness=(v, v))
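
The helpers above turn an integer level in the range 0 to `PARAMETER_MAX` into a concrete magnitude. As a worked example that needs no MindSpore install, the sketch below repeats the two scaling functions from the cell above and evaluates them for a few levels that appear in the ImageNet policy; the variable names are illustrative only.

```python
PARAMETER_MAX = 10

def float_parameter(level, maxval):
    return float(level) * maxval / PARAMETER_MAX

def int_parameter(level, maxval):
    return int(level * maxval / PARAMETER_MAX)

# Magnitudes produced for some of the levels used in the policy.
shear = float_parameter(5, 0.3)                   # shear_x(5): about 0.15
rotation = int_parameter(9, 30)                   # rotate_impl(9): 27 degrees
solarize_threshold = 256 - int_parameter(5, 256)  # solarize_impl(5): upper threshold 128
posterize_bits = 4 - int_parameter(8, 4)          # posterize_impl(8): 1 bit kept
color_degree = float_parameter(0, 1.8) + 0.1      # color_impl(0): degree 0.1

print(shear, rotation, solarize_threshold, posterize_bits, color_degree)
```

Note that `int_parameter` truncates toward zero, so posterize levels 5, 6, and 7 all map to the same value of 2 bits.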

3. Define the AutoAugment policy for the ImageNet dataset:

In [None]:
# define the Auto Augmentation policy
imagenet_policy = [
    [(posterize_impl(8), 0.4), (rotate_impl(9), 0.6)],
    [(solarize_impl(5), 0.6), (autocontrast_impl(5), 0.6)],
    [(c_vision.Equalize(), 0.8), (c_vision.Equalize(), 0.6)],
    [(posterize_impl(7), 0.6), (posterize_impl(6), 0.6)],

    [(c_vision.Equalize(), 0.4), (solarize_impl(4), 0.2)],
    [(c_vision.Equalize(), 0.4), (rotate_impl(8), 0.8)],
    [(solarize_impl(3), 0.6), (c_vision.Equalize(), 0.6)],
    [(posterize_impl(5), 0.8), (c_vision.Equalize(), 1.0)],
    [(rotate_impl(3), 0.2), (solarize_impl(8), 0.6)],
    [(c_vision.Equalize(), 0.6), (posterize_impl(6), 0.4)],

    [(rotate_impl(8), 0.8), (color_impl(0), 0.4)],
    [(rotate_impl(9), 0.4), (c_vision.Equalize(), 0.6)],
    [(c_vision.Equalize(), 0.0), (c_vision.Equalize(), 0.8)],
    [(c_vision.Invert(), 0.6), (c_vision.Equalize(), 1.0)],
    [(color_impl(4), 0.6), (contrast_impl(8), 1.0)],

    [(rotate_impl(8), 0.8), (color_impl(2), 1.0)],
    [(color_impl(8), 0.8), (solarize_impl(7), 0.8)],
    [(sharpness_impl(7), 0.4), (c_vision.Invert(), 0.6)],
    [(shear_x(5), 0.6), (c_vision.Equalize(), 1.0)],
    [(color_impl(0), 0.4), (c_vision.Equalize(), 0.6)],

    [(c_vision.Equalize(), 0.4), (solarize_impl(4), 0.2)],
    [(solarize_impl(5), 0.6), (autocontrast_impl(5), 0.6)],
    [(c_vision.Invert(), 0.6), (c_vision.Equalize(), 1.0)],
    [(color_impl(4), 0.6), (contrast_impl(8), 1.0)],
    [(c_vision.Equalize(), 0.8), (c_vision.Equalize(), 0.6)],
]
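
A quick structural check of the policy can be done on a plain-data mirror of the list, without constructing any MindSpore operators. In the sketch below, `POLICY` is a hand-transcribed stand-in for `imagenet_policy` using `(name, level, probability)` tuples; the tuple layout and the operator names are illustrative assumptions, with `None` as the level for operators that take no magnitude.

```python
POLICY = [
    (("posterize", 8, 0.4), ("rotate", 9, 0.6)),
    (("solarize", 5, 0.6), ("autocontrast", 5, 0.6)),
    (("equalize", None, 0.8), ("equalize", None, 0.6)),
    (("posterize", 7, 0.6), ("posterize", 6, 0.6)),
    (("equalize", None, 0.4), ("solarize", 4, 0.2)),
    (("equalize", None, 0.4), ("rotate", 8, 0.8)),
    (("solarize", 3, 0.6), ("equalize", None, 0.6)),
    (("posterize", 5, 0.8), ("equalize", None, 1.0)),
    (("rotate", 3, 0.2), ("solarize", 8, 0.6)),
    (("equalize", None, 0.6), ("posterize", 6, 0.4)),
    (("rotate", 8, 0.8), ("color", 0, 0.4)),
    (("rotate", 9, 0.4), ("equalize", None, 0.6)),
    (("equalize", None, 0.0), ("equalize", None, 0.8)),
    (("invert", None, 0.6), ("equalize", None, 1.0)),
    (("color", 4, 0.6), ("contrast", 8, 1.0)),
    (("rotate", 8, 0.8), ("color", 2, 1.0)),
    (("color", 8, 0.8), ("solarize", 7, 0.8)),
    (("sharpness", 7, 0.4), ("invert", None, 0.6)),
    (("shear_x", 5, 0.6), ("equalize", None, 1.0)),
    (("color", 0, 0.4), ("equalize", None, 0.6)),
    (("equalize", None, 0.4), ("solarize", 4, 0.2)),
    (("solarize", 5, 0.6), ("autocontrast", 5, 0.6)),
    (("invert", None, 0.6), ("equalize", None, 1.0)),
    (("color", 4, 0.6), ("contrast", 8, 1.0)),
    (("equalize", None, 0.8), ("equalize", None, 0.6)),
]

num_subpolicies = len(POLICY)
all_pairs = all(len(sub) == 2 for sub in POLICY)
probs_valid = all(0.0 <= prob <= 1.0 for sub in POLICY for _, _, prob in sub)
num_distinct = len(set(POLICY))  # repeated subpolicies are sampled more often

print(num_subpolicies, all_pairs, probs_valid, num_distinct)
```

Because subpolicies are selected uniformly from the list, an entry that appears twice is twice as likely to be chosen as one that appears once.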

4. Insert the AutoAugment transform after the `RandomCropDecodeResize` operation.

In [None]:
def create_dataset(dataset_path, do_train, repeat_num=1, batch_size=32, shuffle=True, num_samples=5, target="Ascend"):
    # create a train or eval imagenet2012 dataset for ResNet-50
    dataset = ds.ImageFolderDataset(dataset_path, num_parallel_workers=8,
                                    shuffle=shuffle, num_samples=num_samples)

    image_size = 224
    mean = [0.485 * 255, 0.456 * 255, 0.406 * 255]
    std = [0.229 * 255, 0.224 * 255, 0.225 * 255]

    # define map operations
    if do_train:
        trans = [
            c_vision.RandomCropDecodeResize(image_size, scale=(0.08, 1.0), ratio=(0.75, 1.333)),
        ]

        post_trans = [
            c_vision.RandomHorizontalFlip(prob=0.5),
        ]
    else:
        trans = [
            c_vision.Decode(),
            c_vision.Resize(256),
            c_vision.CenterCrop(image_size),
            c_vision.Normalize(mean=mean, std=std),
            c_vision.HWC2CHW()
        ]
    dataset = dataset.map(operations=trans, input_columns="image")
    if do_train:
        dataset = dataset.map(operations=c_vision.RandomSelectSubpolicy(imagenet_policy), input_columns=["image"])
        dataset = dataset.map(operations=post_trans, input_columns="image")
    type_cast_op = c_transforms.TypeCast(mstype.int32)
    dataset = dataset.map(operations=type_cast_op, input_columns="label")
    # apply the batch operation
    dataset = dataset.batch(batch_size, drop_remainder=True)
    # apply the repeat operation
    dataset = dataset.repeat(repeat_num)

    return dataset

5. Verify the effect of auto augmentation:

In [None]:
# Define the path to image folder directory. This directory needs to contain sub-directories which contain the images
DATA_DIR = "/path/to/image_folder_directory"
dataset = create_dataset(dataset_path=DATA_DIR, do_train=True, batch_size=5, shuffle=False, num_samples=5)

epochs = 5
itr = dataset.create_dict_iterator()
columns = 5
rows = 5

step_num = 0
fig = plt.figure(figsize=(8, 8))
for ep_num in range(epochs):
    for data in itr:
        step_num += 1
        for index in range(rows):
            fig.add_subplot(rows, columns, ep_num * rows + index + 1)
            plt.imshow(data['image'].asnumpy()[index])
plt.show()

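> To make the effect easier to see, only five images are loaded and they are read without `shuffle`. The `Normalize` and `HWC2CHW` operations are also omitted during auto augmentation.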

![augment](./images/auto_augmentation.png)

The result shows the augmentation effect on each image in a batch: each row shows the five images of one batch, and the five rows correspond to five batches.
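
The grid layout follows from the subplot index used in the plotting loop, `ep_num * rows + index + 1`. Matplotlib numbers subplot positions from 1 in row-major order, so the placement can be checked with plain arithmetic; the sketch below only reproduces that index calculation and draws nothing.

```python
rows, columns, epochs = 5, 5, 5

# Subplot position used in the plotting loop for each (epoch, image index) pair.
positions = [[ep_num * rows + index + 1 for index in range(rows)]
             for ep_num in range(epochs)]

# Row-major decoding of a 1-based subplot position into (grid row, grid column).
def grid_cell(position):
    return (position - 1) // columns, (position - 1) % columns

cells = [[grid_cell(p) for p in row] for row in positions]
print(positions[0], positions[4])
```

Each pass over the dataset fills one grid row and each image index fills one column. The formula lines up with row-major placement here because `rows` equals `columns`; for a non-square grid the multiplier would need to be `columns`.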

## References

[1] [AutoAugment: Learning Augmentation Policies from Data](https://arxiv.org/abs/1805.09501).
