In [62]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
from tensorflow.keras.layers import Conv2D, BatchNormalization, Activation, GlobalAveragePooling2D, Dense, MaxPool2D
from tensorflow.keras.models import Model

In [63]:
class ResUnit(Model):
    """Residual unit: a three-convolution branch added to a 1x1-projected shortcut.

    The main branch applies a 1x1 convolution (``filter_in`` filters), a
    ``kernel_size`` convolution (``filter_in`` filters, 'same' padding) and a
    1x1 convolution (``filter_out`` filters), followed by batch normalization
    and ReLU. The shortcut is always a 1x1 convolution to ``filter_out``
    filters (never a plain identity), so the two branches can be summed.

    Args:
        filter_in: Number of filters of the first two convolutions.
        filter_out: Number of filters of the last convolution and of the
            shortcut convolution.
        kernel_size: Kernel size of the middle convolution.
    """

    def __init__(self, filter_in, filter_out, kernel_size):
        super(ResUnit, self).__init__()

        # Layers of the main branch, applied in list order by call().
        self.sequence = list()

        # Shortcut: 1x1 convolution that matches the input's channel depth to
        # the main branch so ResNet's element-wise addition is possible.
        self.identity = Conv2D(filter_out, (1, 1), padding='valid')

        # Bottleneck: 1x1 convolution to filter_in channels
        self.sequence.append(Conv2D(filter_in, (1, 1), padding='valid'))

        # Conv
        self.sequence.append(Conv2D(filter_in, kernel_size, padding='same'))

        # Bottleneck: 1x1 convolution to filter_out channels
        self.sequence.append(Conv2D(filter_out, (1, 1), padding='valid'))

        # BN, Activation
        self.sequence.append(BatchNormalization())
        self.sequence.append(Activation('relu'))

    def call(self, images, training=False):
        """Run the forward pass of the unit.

        Implemented as ``call`` rather than ``__call__`` so that invoking the
        unit goes through the Keras base-class ``__call__`` instead of
        replacing it. ``unit(images, training)`` keeps working for callers.

        Args:
            images: Input tensor fed to both the main branch and the shortcut.
            training: Forwarded to the BatchNormalization layer only; the
                other layers are called without it.

        Returns:
            ``self.identity(images)`` plus the output of the main branch.
        """
        # 1x1 conv -> conv -> 1x1 conv -> BN -> Activation
        h = images
        for unit in self.sequence:
            if isinstance(unit, BatchNormalization):
                h = unit(h, training=training)
            else:
                h = unit(h)

        # Add
        return self.identity(images) + h

In [64]:
class ResLayer(Model):
    """A stack of ``iter_count`` ResUnit blocks applied one after another.

    Args:
        filter_in: Passed to every ResUnit as ``filter_in``.
        filter_out: Passed to every ResUnit as ``filter_out``.
        kernel_size: Passed to every ResUnit as ``kernel_size``.
        iter_count: Number of ResUnit blocks to stack.
    """

    def __init__(self, filter_in, filter_out, kernel_size, iter_count):
        super(ResLayer, self).__init__()

        # Stack iter_count ResUnits.
        # Filter counts were chosen to follow the ResNet paper, with reference
        # to https://eremo2002.tistory.com/76
        self.sequence = [
            ResUnit(filter_in, filter_out, kernel_size)
            for _ in range(iter_count)
        ]

    def call(self, images, training=False):
        """Feed ``images`` through every stacked unit in order.

        Implemented as ``call`` rather than ``__call__`` so that invoking the
        layer goes through the Keras base-class ``__call__`` instead of
        replacing it. ``layer(images, training)`` keeps working for callers.

        Args:
            images: Input tensor.
            training: Forwarded to every ResUnit.

        Returns:
            The output of the last unit (the input itself if the stack is empty).
        """
        h = images
        for unit in self.sequence:
            h = unit(h, training=training)
        return h

In [65]:
class ResNet(Model):
    """ResNet-style classifier assembled from ResLayer stages.

    Pipeline: 7x7 stride-2 convolution -> 3x3 stride-2 max pooling -> four
    ResLayer stages with 3, 4, 6 and 3 units -> global average pooling ->
    softmax Dense layer with ``output_size`` units.

    Args:
        output_size: Number of units of the final softmax layer.
    """

    def __init__(self, output_size):
        super(ResNet, self).__init__()

        # ResNet model: layers are applied in list order by call().
        self.sequence = list()
        self.sequence.append(Conv2D(64, (7, 7), (2, 2), padding='same'))
        self.sequence.append(MaxPool2D((3, 3), (2, 2)))
        self.sequence.append(ResLayer(64, 256, (3, 3), 3))
        self.sequence.append(ResLayer(128, 512, (3, 3), 4))
        self.sequence.append(ResLayer(256, 1024, (3, 3), 6))
        self.sequence.append(ResLayer(512, 2048, (3, 3), 3))
        self.sequence.append(GlobalAveragePooling2D())
        self.sequence.append(Dense(output_size, activation='softmax'))

    def call(self, images, training=False):
        """Run the forward pass of the whole network.

        Implemented as ``call`` rather than ``__call__`` so that invoking the
        model goes through the Keras base-class ``__call__`` instead of
        replacing it. ``model(images, True)`` keeps working for callers.

        Args:
            images: Input tensor.
            training: Forwarded to the ResLayer stages only; the remaining
                layers are called without it.

        Returns:
            The output of the final softmax Dense layer.
        """
        h = images
        for layer in self.sequence:
            if isinstance(layer, ResLayer):
                h = layer(h, training=training)
            else:
                h = layer(h)
        return h

In [66]:
# ResNet test code.
# Trains the model to classify MNIST; the original author reports 97% accuracy.
# The cell body below is wrapped in a triple-quoted string, so it is disabled:
# nothing inside it executes. If enabled, it loads MNIST, appends a channel
# axis, casts the images to float32 (no rescaling is applied) and builds
# batched train/test datasets.
'''
BATCH_SIZE = 32
mnist = tf.keras.datasets.mnist
(train_x, train_y), (test_x, test_y) = mnist.load_data()
train_x = train_x[..., tf.newaxis].astype(np.float32)
test_x = test_x[..., tf.newaxis].astype(np.float32)

train_ds = tf.data.Dataset.from_tensor_slices((train_x, train_y)).shuffle(1000).batch(BATCH_SIZE)
test_ds = tf.data.Dataset.from_tensor_slices((test_x, test_y)).batch(BATCH_SIZE)
'''

In [67]:
# Disabled cell (the body is a string literal, so nothing here executes).
# If enabled, it creates a ResNet with 10 output units, a sparse categorical
# cross-entropy loss, an Adam optimizer, and mean-loss / sparse-categorical
# accuracy metric objects for the train and test splits.
'''
model = ResNet(10)
loss_object = tf.keras.losses.SparseCategoricalCrossentropy()
optimizer = tf.keras.optimizers.Adam()

train_loss = tf.keras.metrics.Mean("train_loss")
train_accuracy = tf.keras.metrics.SparseCategoricalAccuracy("train_accuracy")

test_loss = tf.keras.metrics.Mean("test_loss")
test_accuracy = tf.keras.metrics.SparseCategoricalAccuracy("test_accuracy")
'''

In [68]:
# Disabled cell (the body is a string literal, so nothing here executes).
# If enabled, it defines two tf.function-wrapped steps:
#   train_step - forward pass with training=True under a GradientTape, applies
#                the gradients with the optimizer, then updates the train metrics.
#   test_step  - forward pass with training=False, then updates the test metrics.
'''
@tf.function
def train_step(model, images, labels, loss_object, optimizer, train_loss, train_accuracy):
    with tf.GradientTape() as tape:
        predictions = model(images, True)
        loss = loss_object(labels, predictions)
    gradients = tape.gradient(loss, model.trainable_variables)
    optimizer.apply_gradients(zip(gradients, model.trainable_variables))
    train_loss(loss)
    train_accuracy(labels, predictions)

@tf.function
def test_step(model, images, labels, loss_object, test_loss, test_accuracy):
    predictions = model(images, False)
    loss = loss_object(labels, predictions)
    test_loss(loss)
    test_accuracy(labels, predictions)
'''

In [70]:
# Disabled cell (the body is a string literal, so nothing here executes).
# If enabled, it runs 10 epochs: one pass of train_step over train_ds, one pass
# of test_step over test_ds, then prints the test loss and accuracy.
# NOTE(review): the metric objects are not reset anywhere in this loop, so the
# printed test loss/accuracy accumulate over all epochs so far rather than
# describing the current epoch alone. The train metrics are updated but never
# printed.
'''
EPOCHS = 10
for epoch in range(EPOCHS):
    for images, labels in train_ds:
        train_step(model, images, labels, loss_object, optimizer, train_loss, train_accuracy)

    for images, labels in test_ds:
        test_step(model, images, labels, loss_object, test_loss, test_accuracy)

    print('Epoch {}, Test loss : {}, Test accuracy: {}'.format(epoch + 1, test_loss.result(), test_accuracy.result()))
'''

Epoch 1, Test loss : 0.24728307127952576, Test accuracy: 0.9351181983947754
Epoch 2, Test loss : 0.24831706285476685, Test accuracy: 0.9345999956130981
Epoch 3, Test loss : 0.23544898629188538, Test accuracy: 0.9376461505889893
Epoch 4, Test loss : 0.22442807257175446, Test accuracy: 0.9403785467147827
Epoch 5, Test loss : 0.21675211191177368, Test accuracy: 0.9419866800308228
Epoch 6, Test loss : 0.20636604726314545, Test accuracy: 0.9445499777793884
Epoch 7, Test loss : 0.19734534621238708, Test accuracy: 0.9469352960586548
Epoch 8, Test loss : 0.18974082171916962, Test accuracy: 0.9487666487693787
Epoch 9, Test loss : 0.1818980723619461, Test accuracy: 0.9507052898406982
Epoch 10, Test loss : 0.17549504339694977, Test accuracy: 0.9523299932479858
