In [None]:
# 모듈 임포트
import tensorflow as tf
import numpy as np

## 데이터셋 준비

In [None]:
# Load the MNIST train / validation splits bundled with Keras.
(x_train, y_train), (x_valid, y_valid) = tf.keras.datasets.mnist.load_data()

# Append a trailing channel axis, convert to float32 and divide by 255.0.
x_train = tf.cast(x_train[..., tf.newaxis], dtype=tf.float32) / 255.0
x_valid = tf.cast(x_valid[..., tf.newaxis], dtype=tf.float32) / 255.0

# Shuffle with a buffer that spans the entire training set. A buffer smaller
# than the dataset only mixes samples inside a sliding window, so the
# resulting order is not a uniform permutation of the data.
train_data = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=x_train.shape[0])
    .batch(128)
)
# Validation data is batched only; it is never shuffled.
valid_data = tf.data.Dataset.from_tensor_slices((x_valid, y_valid)).batch(32)

Downloading data from https://storage.googleapis.com/tensorflow/tf-keras-datasets/mnist.npz


## 모델링

### Residual Unit 구현

In [None]:
# Pre-activation layout: BatchNorm -> ReLU -> Conv, applied twice.
class ResidualUnit(tf.keras.models.Model):
    """Residual unit made of two BN -> ReLU -> Conv stages plus a shortcut.

    Args:
        filter_in: Channel count of the incoming tensor; it is only compared
            against ``filter_out`` to decide which shortcut to use.
        filter_out: Number of filters of both main-path convolutions (and of
            the 1x1 shortcut convolution when one is created).
        kernel_size: Kernel size of both main-path convolutions.
    """

    def __init__(self, filter_in, filter_out, kernel_size):
        super().__init__()

        # First BN -> ReLU -> Conv stage.
        self.bn1 = tf.keras.layers.BatchNormalization()
        self.conv1 = tf.keras.layers.Conv2D(filter_out, kernel_size, padding='same')

        # Second BN -> ReLU -> Conv stage.
        self.bn2 = tf.keras.layers.BatchNormalization()
        self.conv2 = tf.keras.layers.Conv2D(filter_out, kernel_size, padding='same')

        # Shortcut branch: a 1x1 convolution when the filter counts differ,
        # otherwise the input is passed through untouched.
        if filter_in != filter_out:
            self.identity = tf.keras.layers.Conv2D(filter_out, (1, 1), padding='same')
        else:
            self.identity = lambda x: x

    def call(self, input_, training=False):
        """Return ``identity(input_)`` plus the output of the two-stage main path.

        Args:
            input_: Tensor fed to both the main path and the shortcut.
            training: Forwarded to both batch-normalization layers.
        """
        residual = self.conv1(tf.nn.relu(self.bn1(input_, training=training)))
        residual = self.conv2(tf.nn.relu(self.bn2(residual, training=training)))
        return self.identity(input_) + residual

### Residual Layer 구현

In [None]:
class ResidualLayer(tf.keras.models.Model):
    """A stack of ``ResidualUnit`` blocks applied one after another.

    Args:
        filter_in: Channel count handed to the first unit as its ``filter_in``.
        filters: Iterable of output filter counts, one entry per unit. Each
            unit after the first receives the previous entry as ``filter_in``.
        kernel_size: Kernel size forwarded to every unit.
    """

    def __init__(self, filter_in, filters, kernel_size):
        super(ResidualLayer, self).__init__()
        # Materialize once: the filter counts are needed both as inputs and as
        # outputs below, and a one-shot iterator would be exhausted after the
        # first pass, silently producing an empty stack.
        filters = list(filters)
        # Input channels of unit i are the output channels of unit i - 1.
        in_channels = [filter_in] + filters[:-1]
        self.sequence = [
            ResidualUnit(f_in, f_out, kernel_size)
            for f_in, f_out in zip(in_channels, filters)
        ]

    def call(self, x, training=False):
        """Pass ``x`` through every unit in order and return the result.

        Args:
            x: Input tensor for the first unit.
            training: Forwarded unchanged to every unit.
        """
        for unit in self.sequence:
            x = unit(x, training=training)
        return x

### 모델 정의

In [None]:
class ResNet(tf.keras.models.Model):
    """Residual CNN: stem convolution, three residual stages with max pooling
    after the first two, then a dense head ending in a 10-unit softmax."""

    def __init__(self):
        super().__init__()

        # Extract features once with a plain convolution first; the residual
        # layers are more effective when applied after that.
        self.conv1 = tf.keras.layers.Conv2D(
            filters=8, kernel_size=(3, 3), padding='same', activation='relu'
        )

        # Stage 1: 8 -> 16 channels, followed by 2x2 max pooling with stride 2.
        self.res1 = ResidualLayer(8, (16, 16), (3, 3))
        self.maxpool1 = tf.keras.layers.MaxPool2D(pool_size=2, strides=2)

        # Stage 2: 16 -> 32 channels, followed by 2x2 max pooling with stride 2.
        self.res2 = ResidualLayer(16, (32, 32), (3, 3))
        self.maxpool2 = tf.keras.layers.MaxPool2D(pool_size=2, strides=2)

        # Stage 3: 32 -> 64 channels, no pooling afterwards.
        self.res3 = ResidualLayer(32, (64, 64), (3, 3))

        # Classification head.
        self.flatten = tf.keras.layers.Flatten()
        self.dense = tf.keras.layers.Dense(units=128, activation='relu')
        self.output_ = tf.keras.layers.Dense(units=10, activation='softmax')

    def call(self, input_, training=False):
        """Run the forward pass and return the softmax output of the last layer.

        Args:
            input_: Input tensor fed to the stem convolution.
            training: Forwarded to the three residual stages.
        """
        features = self.conv1(input_)

        # The first two residual stages are each followed by their pooling layer.
        pooled_stages = ((self.res1, self.maxpool1), (self.res2, self.maxpool2))
        for stage, pool in pooled_stages:
            features = pool(stage(features, training=training))

        features = self.res3(features, training=training)

        hidden = self.dense(self.flatten(features))
        return self.output_(hidden)

In [None]:
# Instantiate the network and call it once on a symbolic 28x28x1 input
# before asking for the summary.
model = ResNet()
input_ = tf.keras.layers.Input(shape=(28,28,1))
model(input_)

# Print the per-layer parameter table.
model.summary()

Model: "res_net"
_________________________________________________________________
Layer (type)                 Output Shape              Param #   
conv2d (Conv2D)              multiple                  80        
_________________________________________________________________
residual_layer (ResidualLaye multiple                  8496      
_________________________________________________________________
max_pooling2d (MaxPooling2D) multiple                  0         
_________________________________________________________________
residual_layer_1 (ResidualLa multiple                  33376     
_________________________________________________________________
max_pooling2d_1 (MaxPooling2 multiple                  0         
_________________________________________________________________
residual_layer_2 (ResidualLa multiple                  132288    
_________________________________________________________________
flatten (Flatten)            multiple                  0   

## 모델 컴파일, 학습

In [None]:
# Configure training: Adam optimizer, sparse categorical cross-entropy loss,
# and accuracy tracked under the metric name 'acc'.
model.compile(
    optimizer='adam',
    loss='sparse_categorical_crossentropy',
    metrics=['acc'],
)

In [None]:
# Train for 20 epochs on train_data, using valid_data as the validation set.
# The returned History object is kept so the per-epoch losses and metrics
# remain available after training; otherwise it is discarded and only its
# repr shows up as the cell output.
history = model.fit(
    train_data,
    validation_data=valid_data,
    epochs=20,
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20
Epoch 19/20
Epoch 20/20


<tensorflow.python.keras.callbacks.History at 0x7f6b06832b10>