In [19]:
import tensorflow as tf
from tensorflow.keras import utils
from tensorflow.keras.datasets import mnist
from tensorflow.data import Dataset

In [2]:
batch_size = 100
split = 0.8

In [16]:
### keras.datasets.mnist  (NumPy Array)
(tra_im, tra_lb), (tes_im, tes_lb) = mnist.load_data()
# 正規化
tra_im_norm = tra_im / 255.0
tes_im_norm = tes_im / 255.0
# one-hot encoding
tra_lb_onehot = utils.to_categorical(tra_lb)
tes_lb_onehot = utils.to_categorical(tes_lb)

In [4]:
# 分離訓練資料  -->  [0.8, 0.2] = [train, valid]
split_idx = int(len(tra_im)*split)

# training data
tra_ds_im = Dataset.from_tensor_slices(tra_im_norm[:split_idx])    # 影像 Dataset
tra_ds_lb = Dataset.from_tensor_slices(tra_lb_onehot[:split_idx])  # 標記 Dataset
tra_ds = Dataset.zip((tra_ds_im, tra_ds_lb))  # 影像、標記整合成一個 Dataset
tra_ds = tra_ds.batch(batch_size)  # 設定 Dataset 批次大小
tra_ds = tra_ds.shuffle(split_idx) # 打亂 Dataset

# validation data
val_ds_im = Dataset.from_tensor_slices(tra_im_norm[split_idx:])    # 影像 Dataset
val_ds_lb = Dataset.from_tensor_slices(tra_lb_onehot[split_idx:])  # 標記 Dataset
val_ds = Dataset.zip((val_ds_im, val_ds_lb))  # 影像、標記整合成一個 Dataset
val_ds = val_ds.batch(batch_size)  # 設定 Dataset 批次大小
val_ds = val_ds.shuffle(len(tra_im)-split_idx) # 打亂 Dataset

# testing data
tes_ds_im = Dataset.from_tensor_slices(tes_im_norm)    # 影像 Dataset
tes_ds_lb = Dataset.from_tensor_slices(tes_lb_onehot)  # 標記 Dataset
tes_ds = Dataset.zip((tes_ds_im, tes_ds_lb))  # 影像、標記整合成一個 Dataset
tes_ds = tes_ds.batch(batch_size)  # 設定 Dataset 批次大小

---

In [5]:
def flip_h(x):
    x = tf.image.random_flip_left_right(x)
    return x
def flip_v(x):
    x = tf.image.random_flip_up_down(x)
    return x
def rotate(x):
    k = tf.random.uniform([], 1, 4, tf.int32)
    x = tf.image.rot90(x, k)
    return x
def hue(x, val=0.08):  # 色調
    x = tf.image.random_hue(x, val)
    return x
def brightness(x, val=0.05):  # 亮度
    x = tf.image.random_brightness(x, val)
    return x
def saturation(x, minval=0.6, maxval=1.6):  # 飽和度
    x = tf.image.random_saturation(x, minval, maxval)
    return x
def contrast(x, minval=0.7, maxval=1.3):  # 對比度
    x = tf.image.random_contrast(x, minval, maxval)
    return x
def zoom(x, scale_minval=0.5, scale_maxval=1.5):
    height, width, channel = x.shape
    scale = tf.random.uniform([], scale_minval, scale_maxval)
    new_size = (scale*height, scale*width)
    x = tf.image.resize(x, new_size)
    x = tf.image.resize_with_crop_or_pad(x, height, width)
    return x

In [20]:
def parse_fn(dataset, **kwargs):
    if kwargs:
        print("Data Augmentation!!!")
        for k, v in kwargs.items():
            print("%15s:"%k, v)
    else:
        print("Not Data Augmentation!!!")
    print()

    # 分離 dataset
    x = dataset["image"]
    y = dataset["label"]

    # 對影像進行正規化，及增加影像通道
    # 因為 MNIST 是灰階影像，所以要自行加上通道，也就是第三軸
    # 從 sahpe 來看就是變成 (28, 28)  ==>  (28, 28, 1)
    x = tf.cast(x, tf.float32) / 255.0
    x = tf.expand_dims(x, axis=-1)

    # 從 kwargs 來判斷哪些擴增需要執行
    if kwargs.get("flip_h", None):
        x = tf.cond(tf.random.uniform((), 0, 1) > 0.5,
                    lambda: flip_h(x), lambda: x)
    if kwargs.get("flip_v", None):
        x = tf.cond(tf.random.uniform((), 0, 1) > 0.5,
                    lambda: flip_v(x), lambda: x)
    if kwargs.get("rotate", None):
        x = tf.cond(tf.random.uniform((), 0, 1) > 0.5,
                    lambda: rotate(x), lambda: x)
    if kwargs.get("hue", None):
        x = tf.cond(tf.random.uniform((), 0, 1) > 0.5,
                    lambda: hue(x), lambda: x)
    if kwargs.get("brightness", None):
        x = tf.cond(tf.random.uniform((), 0, 1) > 0.5,
                    lambda: brightness(x), lambda: x)
    if kwargs.get("saturation", None):
        x = tf.cond(tf.random.uniform((), 0, 1) > 0.5,
                    lambda: saturation(x), lambda: x)
    if kwargs.get("contrast", None):
        x = tf.cond(tf.random.uniform((), 0, 1) > 0.5,
                    lambda: contrast(x), lambda: x)
    if kwargs.get("zoom_scale", None):
        x = tf.cond(tf.random.uniform((), 0, 1) > 0.5,
                    lambda: zoom(x), lambda: x)

    # 對標記進行 one-hot encoding
    y = tf.one_hot(y, 10)
    return {"image": x}, {"label": y} # 回傳資料 (個人喜歡採用 dict 形式)

In [25]:
batch_size = 100
split = 0.8
augdict = {
    # "flip_h": True,
    # "flip_v": True,
    # "rotate": True,
    "hue": False,
    "saturation": False,
    "contrast": True,
    "brightness": True,
    "zoom_scale": False,
}

In [27]:
(X_train, Y_train), (X_test, Y_test) = mnist.load_data()
#-----------------------------------------------------------------------------#
split_idx = int(len(X_train) * split)
train_data = {"image": X_train[:split_idx], "label": Y_train[:split_idx]}
val_data = {"image": X_train[split_idx:], "label": Y_train[split_idx:]}
test_data = {"image": X_test, "label": Y_test}
#-----------------------------------------------------------------------------#
train_datasets = Dataset.from_tensor_slices(train_data)
val_datasets = Dataset.from_tensor_slices(val_data)
test_datasets = Dataset.from_tensor_slices(test_data)

In [23]:
autotune = tf.data.experimental.AUTOTUNE
#-----------------------------------------------------------------------------#
train_datasets = train_datasets.map(lambda ds: parse_fn(ds, **augdict), num_parallel_calls=autotune)
train_datasets = train_datasets.shuffle(1000).batch(batch_size)
#-----------------------------------------------------------------------------#
val_datasets = val_datasets.map(lambda ds: parse_fn(ds, **augdict), num_parallel_calls=autotune)
val_datasets = val_datasets.shuffle(1000).batch(batch_size)
#-----------------------------------------------------------------------------#
test_datasets = test_datasets.map(parse_fn, num_parallel_calls=autotune)
test_datasets = test_datasets.batch(batch_size)

dict_items([('hue', False), ('saturation', False), ('contrast', True), ('brightness', True), ('zoom_scale', False)])
Data Augmentation!!!
            hue: False
     saturation: False
       contrast: True
     brightness: True
     zoom_scale: False

dict_items([('hue', False), ('saturation', False), ('contrast', True), ('brightness', True), ('zoom_scale', False)])
Data Augmentation!!!
            hue: False
     saturation: False
       contrast: True
     brightness: True
     zoom_scale: False

dict_items([])
Not Data Augmentation!!!

