In [1]:
# 모듈 임포트
import tensorflow as tf
import numpy as np

## 데이터셋 준비

In [2]:
(x_train,y_train),(x_valid,y_valid) = tf.keras.datasets.mnist.load_data()

x_train = x_train[...,tf.newaxis].astype(np.float32)/255.0
x_valid = x_valid[...,tf.newaxis].astype(np.float32)/255.0

train_data = tf.data.Dataset.from_tensor_slices((x_train,y_train)).shuffle(1000).batch(128)
valid_data = tf.data.Dataset.from_tensor_slices((x_valid,y_valid)).batch(32)

## Vanilla CNN 모델링

In [3]:
class ConvNet(tf.keras.models.Model):
    def __init__(self):
        super(ConvNet,self).__init__()
        self.cnn1 = tf.keras.layers.Conv2D(16,(3,3),padding='same',activation='relu')
        self.cnn2 = tf.keras.layers.Conv2D(16,(3,3),padding='same',activation='relu')
        self.maxpool1 = tf.keras.layers.MaxPool2D(2,2)
        self.cnn3 = tf.keras.layers.Conv2D(32,(3,3),padding='same',activation='relu')
        self.cnn4 = tf.keras.layers.Conv2D(32,(3,3),padding='same',activation='relu')
        self.maxpool2 = tf.keras.layers.MaxPool2D(2,2)
        self.cnn5 = tf.keras.layers.Conv2D(32,(3,3),padding='same',activation='relu')
        self.cnn6 = tf.keras.layers.Conv2D(32,(3,3),padding='same',activation='relu')
        self.flatten = tf.keras.layers.Flatten()
        self.Dense = tf.keras.layers.Dense(128,activation='relu')
        self.output_ = tf.keras.layers.Dense(10,activation='softmax')

    def call(self,input_):
        x = self.cnn1(input_)
        x = self.cnn2(x)
        x = self.maxpool1(x)
        x = self.cnn3(x)
        x = self.cnn4(x)
        x = self.maxpool2(x)
        x = self.cnn5(x)
        x = self.cnn6(x)
        x = self.flatten(x)
        x = self.Dense(x)
        x = self.output_(x)
        return x


In [4]:
model = ConvNet()
input_ = tf.keras.layers.Input(shape=(28,28,1))
model(input_)

model.summary()

Model: "conv_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              multiple                  160       
_________________________________________________________________
conv2d_1 (Conv2D)            multiple                  2320      
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple                  0         
_________________________________________________________________
conv2d_2 (Conv2D)            multiple                  4640      
_________________________________________________________________
conv2d_3 (Conv2D)            multiple                  9248      
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 multiple                  0         
_________________________________________________________________
conv2d_4 (Conv2D)            multiple                  924

## GradientTape 커스텀

In [5]:
# loss,optimizer,accuracy 정의 loss,optimizer,accuracy 정의
loss_function = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

train_loss = tf.keras.metrics.Mean(name='train_loss')
train_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='train_acc')
valid_loss = tf.keras.metrics.Mean(name='valid_loss')
valid_acc = tf.keras.metrics.SparseCategoricalAccuracy(name='valid_acc')

In [6]:
# 학습 루프 정의
@tf.function
def train_step(image,label):
    with tf.GradientTape() as tape:
        prediction = model(image,training=True)
        loss = loss_function(label,prediction)
    
    gradients = tape.gradient(loss,model.trainable_variables)
    optimizer.apply_gradients(zip(gradients,model.trainable_variables))

    train_loss(loss)
    train_acc(label,prediction)

In [7]:
#검증 루프 정의
@tf.function
def valid_step(image,label):
    prediction = model(image,training=False)
    loss = loss_function(label,prediction)

    valid_loss(loss)
    valid_acc(label,prediction)

## 학습 루프 동작

In [8]:
EPOCHS = 20

train_loss.reset_states()
train_acc.reset_states()
valid_loss.reset_states()
valid_acc.reset_states()

for epoch in range(EPOCHS):
    
    for image,label in train_data:
        train_step(image,label)
    
    for image,label in valid_data:
        valid_step(image,label)
    
    template = 'epoch: {}, loss: {}, acc: {}, val_loss: {}, val_acc: {}'
    print(template.format(epoch+1,train_loss.result(),train_acc.result(),valid_loss.result(),valid_acc.result()))

epoch: 1, loss: 0.20778584480285645, acc: 0.9350000023841858, val_loss: 0.0596526563167572, val_acc: 0.98089998960495
epoch: 2, loss: 0.12918813526630402, acc: 0.9595000147819519, val_loss: 0.054524291306734085, val_acc: 0.9828000068664551
epoch: 3, loss: 0.09789904952049255, acc: 0.9693111181259155, val_loss: 0.04951520264148712, val_acc: 0.9842333197593689
epoch: 4, loss: 0.080332912504673, acc: 0.9747666716575623, val_loss: 0.045701347291469574, val_acc: 0.9856250286102295
epoch: 5, loss: 0.06873320043087006, acc: 0.9783700108528137, val_loss: 0.04499353840947151, val_acc: 0.9859399795532227
epoch: 6, loss: 0.06035727262496948, acc: 0.9809499979019165, val_loss: 0.042893193662166595, val_acc: 0.9866999983787537
epoch: 7, loss: 0.05380138009786606, acc: 0.9829738140106201, val_loss: 0.04241083189845085, val_acc: 0.9871000051498413
epoch: 8, loss: 0.04893072322010994, acc: 0.9844833612442017, val_loss: 0.04200979322195053, val_acc: 0.9871500134468079
epoch: 9, loss: 0.0446754172444343