# 6차시: 텐서플로우 전이학습

## 2023. 08. 02. 14:10 ~ 16:00 (50분×2)
1. 전이학습 기초
1. CIFAR-10 데이터 셋
1. 텐서플로우 실습

### 참고자료
- [파이썬 3 표준 문서](https://docs.python.org/3/index.html)
- [텐서플로우 이미지 분류](https://www.tensorflow.org/tutorials/keras/classification)
- [텐서플로우 전이학습](https://www.tensorflow.org/tutorials/images/transfer_learning_with_hub)

### 1. 도구 불러오기 및 버전 확인

In [None]:
# 도구 준비
import os
import random

import tensorflow as tf # 텐서플로우
import tensorflow_hub as hub
import matplotlib.pyplot as plt # 시각화 도구
%matplotlib inline
import numpy as np
import PIL.Image as Image

In [None]:
# Report the TensorFlow version in use so the environment can be checked
# before running the rest of the notebook.
version_message = f'Tensorflow 버전을 확인합니다: {tf.__version__}'
print(version_message)

### 2. 학습 데이터 불러오기

In [None]:
# Prepare the dataset: images are read from the directory tree at dataset_root.
dataset_root = os.path.abspath(os.path.expanduser('/tmp/dataset'))
print(f'Dataset root: {dataset_root}')

# Size every image is resized to on load; adjust this to match your own dataset.
IMAGE_SHAPE = (128, 128)

# One generator serves both subsets: it rescales pixel values by 1/255 and
# reserves a 0.2 fraction of the data for the 'validation' subset.
image_generator = tf.keras.preprocessing.image.ImageDataGenerator(
    rescale=1/255,
    validation_split=0.2,
)

# Arguments common to the training and validation iterators.
flow_kwargs = dict(directory=dataset_root, target_size=IMAGE_SHAPE)
train_data = image_generator.flow_from_directory(subset='training', **flow_kwargs)
validation_data = image_generator.flow_from_directory(subset='validation', **flow_kwargs)

# Take just the first validation batch to inspect its shapes. image_batch and
# label_batch remain bound after the loop and are reused further down.
for image_batch, label_batch in validation_data:
    print(f'Image batch shape: {image_batch.shape}')
    print(f'Label batch shape: {label_batch.shape}')
    break

### 3. 학습 데이터 살펴보기

In [None]:
# TF Hub handle of the Inception v3 ImageNet classification model.
classifier_url = 'https://tfhub.dev/google/imagenet/inception_v3/classification/4'

# Wrap the hub module in a one-layer Keras model so it can be used via predict().
classifier = tf.keras.Sequential([
    hub.KerasLayer(classifier_url, input_shape=IMAGE_SHAPE+(3,))  # 3 channels (RGB)
])

# Fetch the text file holding the ImageNet class names, one name per line.
labels_path = tf.keras.utils.get_file('ImageNetLabels.txt',
                                      'https://storage.googleapis.com/download.tensorflow.org/data/ImageNetLabels.txt')
# Read through a context manager so the file handle is closed right away
# instead of being left open until garbage collection.
with open(labels_path) as labels_file:
    imagenet_labels = np.array(labels_file.read().splitlines())

In [None]:
# Classify the sampled batch with the original ImageNet classifier.
result_batch = classifier.predict(image_batch)
print(f'Batch result shape: {result_batch.shape}')

# Pick the highest-scoring class index along the last axis and map it to its label.
predicted_class_names = imagenet_labels[np.argmax(result_batch, axis=-1)]
print(f'Batch predicted class names: {predicted_class_names}')

# The 6x5 grid has room for 30 images, but the sampled batch may hold fewer
# (e.g. a small validation subset), so never index past the end of the batch.
num_shown = min(30, len(image_batch))

fig = plt.figure(figsize=(10, 10.5))
for n in range(num_shown):
    ax = fig.add_subplot(6, 5, n+1)
    ax.imshow(image_batch[n])
    ax.set_title(predicted_class_names[n])
    ax.axis('off')
_ = fig.suptitle('ImageNet predictions')

### 4. Convolutional Neural Network

In [None]:
## Logging callback
### Reference: https://www.tensorflow.org/api_docs/python/tf/keras/callbacks/Callback
class CollectBatchStats(tf.keras.callbacks.Callback):
    """Collect the loss/accuracy values Keras reports at the end of each epoch.

    Despite the ``batch_*`` attribute names, values are appended once per
    epoch (in ``on_epoch_end``), not once per batch.

    Attributes:
        batch_losses: values of ``logs['loss']``.
        batch_acc: values of ``logs['accuracy']``.
        batch_val_losses: values of ``logs['val_loss']``.
        batch_val_acc: values of ``logs['val_accuracy']``.
    """

    def __init__(self):
        # Let the base Callback initialise its own state first.
        super().__init__()
        self.batch_losses = []
        self.batch_val_losses = []
        self.batch_acc = []
        self.batch_val_acc = []

    def on_epoch_end(self, epoch, logs=None):
        """Append the metrics found in ``logs``, then reset the model's metrics.

        Args:
            epoch: index of the epoch that just finished (unused).
            logs: mapping of metric name to value; may be ``None``.
        """
        # The signature allows logs=None, so never subscript it directly.
        logs = logs or {}
        recorded = (
            ('loss', self.batch_losses),
            ('accuracy', self.batch_acc),
            ('val_loss', self.batch_val_losses),
            ('val_accuracy', self.batch_val_acc),
        )
        for key, history in recorded:
            # A key can be absent (e.g. no validation data was given to fit);
            # skip it instead of raising KeyError in the middle of training.
            if key in logs:
                history.append(logs[key])
        self.model.reset_metrics()

In [None]:
# Small CNN built layer by layer: three convolution layers (the first two each
# followed by max pooling), then a dense head with one output per class.
cnn_model = tf.keras.Sequential()

# Convolutional feature extractor; the input shape is taken from the sampled
# image batch with its leading (batch) dimension dropped.
cnn_model.add(tf.keras.layers.Conv2D(filters=32, kernel_size=(3, 3), activation='relu',
                                     input_shape=image_batch.shape[1:]))
cnn_model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
cnn_model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu'))
cnn_model.add(tf.keras.layers.MaxPooling2D(pool_size=(2, 2)))
cnn_model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=(3, 3), activation='relu'))

# Classifier head; the last layer has no activation, so it outputs raw scores.
cnn_model.add(tf.keras.layers.Flatten())
cnn_model.add(tf.keras.layers.Dense(units=64, activation='relu'))
cnn_model.add(tf.keras.layers.Dense(units=train_data.num_classes))

cnn_model.summary()

### 5. Train model

In [None]:
# Learning rate handed to the Adam optimizer.
base_learning_rate = 0.001

# from_logits=True because the model's final Dense layer has no activation.
cnn_loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)
cnn_optimizer = tf.keras.optimizers.Adam(learning_rate=base_learning_rate)

cnn_model.compile(
    loss=cnn_loss,
    optimizer=cnn_optimizer,
    metrics=['accuracy'],
)

In [None]:
# Number of batches needed to cover the whole training set once per epoch.
# np.ceil returns a float, while Keras documents steps_per_epoch as an
# integer, so convert explicitly.
steps_per_epoch = int(np.ceil(train_data.samples/train_data.batch_size))
epochs = 25*2
# Callback that records the per-epoch loss/accuracy values during training.
cnn_callback = CollectBatchStats()

cnn_history = cnn_model.fit(train_data,
                            epochs=epochs,
                            steps_per_epoch=steps_per_epoch,
                            validation_data=validation_data,
                            callbacks=[cnn_callback])