## 데이터셋 로드

In [None]:
import tensorflow as tf
from tensorflow.keras.layers import Dense, Flatten, Input
from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.callbacks import ModelCheckpoint

In [None]:
# Load the MNIST dataset; load_data() yields a (train, validation) pair of
# (images, labels) tuples.
mnist = tf.keras.datasets.mnist
train_split, valid_split = mnist.load_data()
x_train, y_train = train_split
x_valid, y_valid = valid_split

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


In [None]:
# Inspect the image array shapes of the training and validation splits.
x_train.shape, x_valid.shape

((60000, 28, 28), (10000, 28, 28))

In [None]:
# Inspect the label array shapes of the training and validation splits.
y_train.shape, y_valid.shape

((60000,), (10000,))

## Sequential

(Flatten)           (None, 784)                
_________________________________________________________________
(Dense)             (None, 256)            
_________________________________________________________________
(Dense)             (None, 128)             
_________________________________________________________________
(Dense)             (None, 64)             
_________________________________________________________________
(Dense)             (None, 32)             
_________________________________________________________________
(Dense)             (None, 10)             

In [None]:
# Build the classifier layer by layer: flatten each 28x28 image, pass it
# through four ReLU hidden layers, and finish with a 10-way softmax.
model = Sequential()
model.add(Flatten(input_shape=(28, 28)))
for units in (256, 128, 64, 32):
    model.add(Dense(units, activation='relu'))
model.add(Dense(10, activation='softmax'))

In [None]:
# Print each layer's output shape and parameter count.
model.summary()

Model: "sequential"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten (Flatten)            (None, 784)               0         
_________________________________________________________________
dense (Dense)                (None, 256)               200960    
_________________________________________________________________
dense_1 (Dense)              (None, 128)               32896     
_________________________________________________________________
dense_2 (Dense)              (None, 64)                8256      
_________________________________________________________________
dense_3 (Dense)              (None, 32)                2080      
_________________________________________________________________
dense_4 (Dense)              (None, 10)                330       
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
__________________________________________________

## 함수형 (Functional API)

(Flatten)           (None, 784)                
_________________________________________________________________
(Dense)             (None, 256)            
_________________________________________________________________
(Dense)             (None, 128)             
_________________________________________________________________
(Dense)             (None, 64)             
_________________________________________________________________
(Dense)             (None, 32)             
_________________________________________________________________
(Dense)             (None, 10)             

In [None]:
# Functional API: start from a symbolic 28x28 input and thread it through
# the layers. `input_` and `x` are the graph's two endpoints.
input_ = Input(shape=(28, 28))
x = Flatten()(input_)
# Hidden ReLU layers, widest first.
for units in (256, 128, 64, 32):
    x = Dense(units, activation='relu')(x)
# 10-way softmax output.
x = Dense(10, activation='softmax')(x)

In [None]:
# Wrap the layer graph between `input_` and `x` into a model.
model = Model(inputs=input_, outputs=x)

In [None]:
# Print each layer's output shape and parameter count.
model.summary()

Model: "model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
input_1 (InputLayer)         [(None, 28, 28)]          0         
_________________________________________________________________
flatten_1 (Flatten)          (None, 784)               0         
_________________________________________________________________
dense_5 (Dense)              (None, 256)               200960    
_________________________________________________________________
dense_6 (Dense)              (None, 128)               32896     
_________________________________________________________________
dense_7 (Dense)              (None, 64)                8256      
_________________________________________________________________
dense_8 (Dense)              (None, 32)                2080      
_________________________________________________________________
dense_9 (Dense)              (None, 10)                330   

## Sub-Classing

(Flatten)           (None, 784)                
_________________________________________________________________
(Dense)             (None, 256)            
_________________________________________________________________
(Dense)             (None, 128)             
_________________________________________________________________
(Dense)             (None, 64)             
_________________________________________________________________
(Dense)             (None, 32)             
_________________________________________________________________
(Dense)             (None, 10)             

In [None]:
class MyModel(Model):
    """Dense classifier written with the model sub-classing API.

    The input is flattened, passed through ReLU layers of 256, 128, 64 and
    32 units, and finished with a 10-unit softmax layer.
    """

    def __init__(self):
        super().__init__()
        hidden_activation = 'relu'
        self.flatten = Flatten()
        self.dense1 = Dense(256, activation=hidden_activation)
        self.dense2 = Dense(128, activation=hidden_activation)
        self.dense3 = Dense(64, activation=hidden_activation)
        self.dense4 = Dense(32, activation=hidden_activation)
        self.dense5 = Dense(10, activation='softmax')

    def call(self, x):
        """Apply every layer in definition order and return the softmax output."""
        pipeline = (
            self.flatten,
            self.dense1,
            self.dense2,
            self.dense3,
            self.dense4,
            self.dense5,
        )
        for layer in pipeline:
            x = layer(x)
        return x


In [None]:
# Instantiate the sub-classed model.
model = MyModel()
# Call it once on the symbolic `input_` tensor from the functional-API cell.
# NOTE(review): presumably this call is what creates the layer weights so
# that summary() can report parameter counts - confirm.
model(input_)

<KerasTensor: shape=(None, 10) dtype=float32 (created by layer 'my_model')>

In [None]:
# Print the layers and their parameter counts.
model.summary()

Model: "my_model"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
flatten_2 (Flatten)          multiple                  0         
_________________________________________________________________
dense_10 (Dense)             multiple                  200960    
_________________________________________________________________
dense_11 (Dense)             multiple                  32896     
_________________________________________________________________
dense_12 (Dense)             multiple                  8256      
_________________________________________________________________
dense_13 (Dense)             multiple                  2080      
_________________________________________________________________
dense_14 (Dense)             multiple                  330       
Total params: 244,522
Trainable params: 244,522
Non-trainable params: 0
____________________________________________________

## 데이터셋 준비

In [None]:
# Training pipeline: shuffle through a 1000-element buffer so batches differ
# between epochs, then group into batches of 32.
train_data = tf.data.Dataset.from_tensor_slices((x_train, y_train)).shuffle(1000).batch(32)
# Validation pipeline: batches of 32 in the stored order. Evaluation does not
# benefit from shuffling; a fixed order avoids the buffer cost and keeps the
# batch composition (including the final, smaller batch) identical every epoch.
valid_data = tf.data.Dataset.from_tensor_slices((x_valid, y_valid)).batch(32)

## 학습 방법 (Train)

### optimizer와 loss_function 정의

In [None]:
# Adam optimizer with its default settings.
optimizer = tf.keras.optimizers.Adam()
# Cross-entropy loss for integer class labels (labels are not one-hot encoded).
loss_function = tf.keras.losses.SparseCategoricalCrossentropy()

### (기록을 위한) Metric 정의

In [None]:
# Running metrics used only for logging: a mean of the batch losses and a
# sparse categorical accuracy, kept separately for training and validation.
keras_metrics = tf.keras.metrics

train_loss = keras_metrics.Mean(name='train_loss')
train_accuracy = keras_metrics.SparseCategoricalAccuracy(name='train_accuracy')

valid_loss = keras_metrics.Mean(name='valid_loss')
valid_accuracy = keras_metrics.SparseCategoricalAccuracy(name='valid_accuracy')

### train_step 정의

In [None]:
@tf.function
def train_step(images, labels):
    """Run one optimization step on a batch and record its training metrics.

    Uses the module-level ``model``, ``loss_function`` and ``optimizer``, and
    updates ``train_loss`` / ``train_accuracy``. Returns nothing.

    Args:
        images: Batch of input images passed to ``model``.
        labels: Batch of target labels passed to ``loss_function``.
    """
    # Record the forward pass so the loss can be differentiated.
    with tf.GradientTape() as tape:
        prediction = model(images, training=True)
        loss = loss_function(labels, prediction)

    # Back-propagate and apply the update to every trainable weight.
    weights = model.trainable_variables
    grads = tape.gradient(loss, weights)
    optimizer.apply_gradients(zip(grads, weights))

    # Accumulate this batch into the running training metrics.
    train_loss.update_state(loss)
    train_accuracy.update_state(labels, prediction)

In [None]:
@tf.function
def valid_step(images, labels):
    """Evaluate one batch and record its validation metrics.

    Uses the module-level ``model`` and ``loss_function`` and updates
    ``valid_loss`` / ``valid_accuracy``. No weights are changed. Returns
    nothing.

    Args:
        images: Batch of input images passed to ``model``.
        labels: Batch of target labels passed to ``loss_function``.
    """
    # Forward pass in inference mode, then the batch loss.
    prediction = model(images, training=False)
    batch_loss = loss_function(labels, prediction)

    # Accumulate this batch into the running validation metrics.
    valid_loss.update_state(batch_loss)
    valid_accuracy.update_state(labels, prediction)

### 학습 (train)

In [None]:
# Per-epoch log line; loop-invariant, so it is defined once.
template = 'epoch: {}, loss: {:.3f}, acc: {:.3f}, val_loss: {:.3f}, val_acc: {:.3f}'

# Iterate over epochs
for epoch in range(10):
    # Clear the running metrics at the start of every epoch. Resetting only
    # once before the loop would make each log line an average over all
    # epochs so far instead of the values for this epoch alone.
    for metric in (train_loss, train_accuracy, valid_loss, valid_accuracy):
        metric.reset_states()

    # One pass over the training batches, updating the weights.
    for images, labels in train_data:
        train_step(images, labels)

    # One pass over the validation batches, metrics only.
    for images, labels in valid_data:
        valid_step(images, labels)

    # Log this epoch's results; accuracies are shown as percentages.
    print(template.format(
        epoch + 1,
        train_loss.result(),
        train_accuracy.result() * 100,
        valid_loss.result(),
        valid_accuracy.result() * 100,
    ))

epoch: 1, loss: 0.026, acc: 99.335, val_loss: 0.122, val_acc: 97.750
epoch: 2, loss: 0.028, acc: 99.333, val_loss: 0.130, val_acc: 97.815
epoch: 3, loss: 0.027, acc: 99.342, val_loss: 0.128, val_acc: 97.880
epoch: 4, loss: 0.026, acc: 99.369, val_loss: 0.138, val_acc: 97.840
epoch: 5, loss: 0.028, acc: 99.359, val_loss: 0.138, val_acc: 97.838
epoch: 6, loss: 0.027, acc: 99.368, val_loss: 0.138, val_acc: 97.822
epoch: 7, loss: 0.027, acc: 99.387, val_loss: 0.138, val_acc: 97.839
epoch: 8, loss: 0.026, acc: 99.396, val_loss: 0.139, val_acc: 97.845
epoch: 9, loss: 0.026, acc: 99.405, val_loss: 0.146, val_acc: 97.849
epoch: 10, loss: 0.026, acc: 99.413, val_loss: 0.149, val_acc: 97.836
