# Custom Model - Subclassing

In [None]:
## Import the required libraries
import os
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
from tensorflow import keras

## Check the TensorFlow and Keras versions
print(tf.__version__)
print(keras.__version__)

In [None]:
# Seed the TensorFlow and NumPy random number generators with the same value.
tf.random.set_seed(777)
np.random.seed(777)

## Coding Tips

#### 1. Hyper Parameter 정하기
#### 2. Data 준비(불러오기 or download 등)
#### 3. Dataset 구성 (tf.data.Dataset 이용)
#### 4. Model 만들기 (Neural Network model)
#### 5. Loss function 정의, Optimizer 선택
#### 6. Training (Train, Test function 만들기 포함)
#### 7. Validation(or Test) 결과 확인

In [None]:
## Hyper-parameters
learning_rate = 1e-3  # step size handed to the Adam optimizer
N_EPOCHS = 20         # number of epochs requested from model.fit
N_BATCH = 100         # batch size used when building the datasets
N_CLASS = 10          # number of label classes

In [None]:
## Dataset selection: MNIST is active, Fashion MNIST is kept as a commented alternative.

## MNIST Dataset #########################################################
mnist = keras.datasets.mnist
# Class names are the digit strings '0' .. '9'.
class_names = [str(digit) for digit in range(10)]
##########################################################################

## Fashion MNIST Dataset (uncomment the two lines below to switch) ########
#mnist = keras.datasets.fashion_mnist
#class_names = ['T-shirt/top', 'Trouser', 'Pullover', 'Dress', 'Coat', 'Sandal', 'Shirt', 'Sneaker', 'Bag', 'Ankle boot']
##########################################################################

In [None]:
## Load the selected dataset as (images, labels) pairs for the train and test splits
(
    (train_images, train_labels),
    (test_images, test_labels),
) = mnist.load_data()

In [None]:
# Number of samples in each split (length of the first axis).
N_TRAIN = len(train_images)
N_TEST = len(test_images)

In [None]:
# Scale pixel values into the 0~1 range
train_images = train_images.astype(np.float32) / 255.
test_images = test_images.astype(np.float32) / 255.
# Add a trailing channel axis of size 1 (3-D -> 4-D) so the images can be fed to the CNN
train_images = train_images[..., tf.newaxis]
test_images = test_images[..., tf.newaxis]
# One-hot encode the labels. The width comes from the N_CLASS hyper-parameter
# instead of a hard-coded 10, so the two cannot drift apart.
train_labels = keras.utils.to_categorical(train_labels, N_CLASS)
test_labels = keras.utils.to_categorical(test_labels, N_CLASS)

In [None]:
## Build the tf.data input pipelines
# Training split: shuffled through a 100000-element buffer, then batched.
train_dataset = (
    tf.data.Dataset.from_tensor_slices((train_images, train_labels))
    .shuffle(buffer_size=100000)
    .batch(N_BATCH)
)
# Test split: batched only, no shuffling.
test_dataset = (
    tf.data.Dataset.from_tensor_slices((test_images, test_labels))
    .batch(N_BATCH)
)

In [None]:
class MyModel(keras.Model):
  """CNN classifier defined by subclassing ``keras.Model``.

  Layer stack, in call order: three Conv2D (3x3, ReLU, 'SAME' padding) +
  MaxPool2D stages with 32, 64 and 128 filters, then Flatten,
  Dense(256, ReLU), Dropout(0.4) and a 10-unit softmax output layer.
  """

  def __init__(self):
    super().__init__()
    # Convolution / pooling stages
    self.conv1 = keras.layers.Conv2D(filters=32, kernel_size=[3, 3], padding='SAME', activation='relu')
    self.pool1 = keras.layers.MaxPool2D(padding='SAME')
    self.conv2 = keras.layers.Conv2D(filters=64, kernel_size=[3, 3], padding='SAME', activation='relu')
    self.pool2 = keras.layers.MaxPool2D(padding='SAME')
    self.conv3 = keras.layers.Conv2D(filters=128, kernel_size=[3, 3], padding='SAME', activation='relu')
    self.pool3 = keras.layers.MaxPool2D(padding='SAME')
    # Classifier head
    self.pool3_flat = keras.layers.Flatten()
    self.dense4 = keras.layers.Dense(units=256, activation='relu')
    self.drop4 = keras.layers.Dropout(rate=0.4)
    self.dense5 = keras.layers.Dense(units=10, activation='softmax')

  def call(self, x, training=False):
    """Run the forward pass.

    Args:
      x: Batch of input images (assumed 4-D with a trailing channel axis,
        as produced by the preprocessing cell -- confirm for other callers).
      training: Whether the model is being called in training mode. It is
        forwarded to the dropout layer.

    Returns:
      The output of the final softmax layer (10 values per sample).
    """
    x = self.conv1(x)
    x = self.pool1(x)
    x = self.conv2(x)
    x = self.pool2(x)
    x = self.conv3(x)
    x = self.pool3(x)
    x = self.pool3_flat(x)
    x = self.dense4(x)
    # Pass the mode explicitly instead of relying on Keras' implicit
    # propagation, so dropout honours the `training` argument of this call.
    x = self.drop4(x, training=training)
    return self.dense5(x)

In [None]:
# Instantiate the subclassed CNN model
model = MyModel()

In [None]:
## Compile the model: Adam optimizer, categorical cross-entropy loss, accuracy metric
model.compile(
    optimizer=keras.optimizers.Adam(learning_rate=learning_rate),
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

In [None]:
## Parameters for training
# Number of whole batches in each split (floor division drops any remainder).
steps_per_epoch, validation_steps = N_TRAIN // N_BATCH, N_TEST // N_BATCH
print(steps_per_epoch, validation_steps)

In [None]:
## Training
# steps_per_epoch / validation_steps are intentionally not passed: both
# datasets are finite tf.data pipelines, so Keras runs every epoch (and every
# validation pass) until the dataset is exhausted. Floor-divided step counts
# only match the data when N_BATCH divides the sample counts exactly;
# otherwise the last partial batch is skipped and training can run out of data.
history = model.fit(train_dataset, epochs=N_EPOCHS,
                    validation_data=test_dataset)

In [None]:
# Make sure the output folder exists, then remember the current working
# directory that later paths are joined onto.
save_dir_name = 'saved_models'
os.makedirs(save_dir_name, exist_ok=True)
cur_dir = os.getcwd()

In [None]:
## saved_model format
# Target path: <cur_dir>/<save_dir_name>/mnist_cnn_subclass_pb (no file extension)
saved_model_path = os.path.join(cur_dir, save_dir_name, 'mnist_cnn_subclass_pb')
# Write the trained model to that path
model.save(saved_model_path)

In [None]:
# Reload the model from the path it was saved to and evaluate it on the test dataset
new_model7 = keras.models.load_model(saved_model_path)
new_model7.evaluate(test_dataset)