# ImageDataGenerator를 이용한 데이터 증강

In [None]:
import numpy as np
import matplotlib.pyplot as plt

import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import optimizers
from tensorflow.keras.layers import Dense, Flatten, Conv2D, MaxPooling2D, Input, Reshape

import time

In [None]:
# Load MNIST: Keras returns uint8 image arrays and integer class labels.
(raw_train_x, raw_train_y), (raw_test_x, raw_test_y) = tf.keras.datasets.mnist.load_data()

# Add a trailing channel axis for Conv2D and scale pixels to [0, 1].
# Cast to float32 *before* dividing: `uint8 / 255.` would yield float64
# (twice the memory), and ImageDataGenerator.flow() converts to float32 anyway.
train_x = raw_train_x.reshape(-1, 28, 28, 1).astype("float32") / 255.0
test_x = raw_test_x.reshape(-1, 28, 28, 1).astype("float32") / 255.0

# Labels stay as integer class ids (consumed by sparse_categorical_crossentropy).
train_y = raw_train_y
test_y = raw_test_y

In [None]:
# Small CNN classifier for 28x28 single-channel images with 10 output classes.
model = keras.Sequential()
model.add(Input((28, 28, 1)))
# ReLU on the convolutions: with no activation, each Conv2D is a purely
# linear filter bank, which limits the features the network can learn.
model.add(Conv2D(32, (3, 3), padding='same', activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Conv2D(64, (3, 3), padding='same', activation='relu'))
model.add(MaxPooling2D((2, 2)))
model.add(Flatten())
model.add(Dense(10, activation='relu'))
model.add(Dense(10, activation='relu'))
# Softmax output: one probability per class.
model.add(Dense(10, activation='softmax'))

# Labels are integer class ids (not one-hot), hence the "sparse" loss variant.
model.compile(optimizer="adam", loss="sparse_categorical_crossentropy", metrics=["accuracy"])
model.summary()




## 메모리에 있는 Numpy 데이터 활용

### Train 데이터 이미지 증강

In [None]:
from tensorflow.keras.preprocessing.image import ImageDataGenerator

# On-the-fly augmentation applied to the training subset only.
train_data_generator = ImageDataGenerator(
    # Keep rotations small: a digit rotated by up to 180 degrees can turn a
    # 6 into a 9 (and vice versa) while keeping its original label.
    rotation_range=15,
    # Other augmentations to experiment with:
    # width_shift_range=0.5,
    # height_shift_range=0.5,
    # vertical_flip=True,
    # horizontal_flip=True,
    # shear_range=0.5,
    # brightness_range=(0.1, 0.5),
    # zoom_range=0.9,
    # rescale = 1/255.   # leave off here: train_x was already divided by 255

    # When using a validation set carved out of train_x: hold out 20%.
    # Use the same validation_split value for the validation generator so
    # the 'training' and 'validation' subsets line up.
    validation_split=0.2,
).flow(train_x, train_y, batch_size=32, subset='training')

### Validation, Test 데이터 생성

In [None]:
# Option A: a separate validation set already exists (valid_x / valid_y).
# Kept commented out: valid_x / valid_y are not defined in this notebook,
# so executing this line would raise NameError.
# valid_data_generator = ImageDataGenerator().flow(valid_x, valid_y, batch_size=32)

# Option B: carve the validation set out of train_x. validation_split must
# match the value used by the training generator so the subsets line up.
valid_data_generator = ImageDataGenerator(
    validation_split=0.2,
).flow(train_x, train_y, batch_size=32, subset='validation')

# flow() shuffles by default; shuffle=False keeps batches in the original
# test order so model.predict(test_data_generator) lines up with test_y.
test_data_generator = ImageDataGenerator().flow(
    test_x, test_y, batch_size=32, shuffle=False
)

In [None]:
# Preview one augmented training batch as a single labelled grid, so you can
# check at a glance that augmented images still match their labels.
batch_x, batch_y = next(train_data_generator)

n_cols = 8
n_rows = max(1, -(-len(batch_x) // n_cols))  # ceiling division
fig, axes = plt.subplots(n_rows, n_cols, figsize=(n_cols * 1.5, n_rows * 1.5), squeeze=False)
for ax in axes.flat:
    ax.axis('off')  # also hides empty cells when the batch is not a full grid
for ax, image, label in zip(axes.flat, batch_x, batch_y):
    ax.imshow(image.squeeze(), cmap='binary')
    ax.set_title(str(int(label)))
plt.tight_layout()
plt.show()

### 모델 학습에 활용

In [None]:
# Train for 10 epochs on augmented batches, scoring the held-out split each epoch.
model.fit(
    x=train_data_generator,
    validation_data=valid_data_generator,
    epochs=10,
)

### Test 데이터 활용

In [None]:
# Model output (softmax scores) for every test batch, in the order the
# generator yields samples: aligned with test_y only if it does not shuffle.
y_ = model.predict(x=test_data_generator)

## 디렉토리의 이미지 파일을 Generator로 불러오기

In [None]:
# Stream images from a directory tree (one sub-folder per class) instead of
# holding them all in memory. 'some/path/train' is presumably a placeholder.
train_data_generator = ImageDataGenerator(
    # Files are decoded as raw 0-255 pixel values; rescale to [0, 1] to match
    # the /255 scaling applied to the in-memory NumPy data.
    rescale=1 / 255.0,
).flow_from_directory(
    'some/path/train',
    target_size=(224, 224),  # every image is resized to 224x224
    batch_size=32,
    class_mode='sparse',  # integer class labels, for a sparse_* loss
)