In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
from keras.models import Model, Sequential
from keras.layers import Dense, MaxPool2D, AvgPool2D, Conv2D, BatchNormalization, Softmax, ReLU
from keras.optimizers import RMSprop
from keras.losses import SparseCategoricalCrossentropy
from keras.callbacks import ModelCheckpoint
from keras.datasets import 

In [2]:
class ResBlock(Model):
    """Residual block: ``conv_count`` 3x3 convolutions plus a shortcut connection.

    Residual branch: Conv2D -> BatchNorm -> ReLU for every conv, except that the
    last ReLU is applied *after* the shortcut is added (the post-activation
    residual unit of He et al., 2015). The shortcut is the identity when the
    branch preserves the input shape. Otherwise (any stride != 1, or a change in
    channel count) it is a strided 1x1 Conv2D + BatchNorm projection, so the
    addition is always shape-compatible.

    Args:
        conv_count: number of 3x3 convolutions in the residual branch (>= 1).
        filters: an int used for every conv, or a tuple/list with one entry per conv.
        strides: an int used for every conv, or a tuple/list with one entry per conv
            (entries may be ints or (row, col) pairs, as accepted by Conv2D).

    Raises:
        ValueError: if conv_count < 1, or a tuple/list argument's length differs
            from conv_count.
        TypeError: if filters or strides is not an int, tuple or list.
    """

    def __init__(self, conv_count=2, filters=64, strides=1):
        super().__init__()

        if conv_count < 1:
            raise ValueError('conv_count must be at least 1.')

        if isinstance(filters, int):
            filters = [filters] * conv_count
        elif isinstance(filters, (tuple, list)):
            if len(filters) != conv_count:
                raise ValueError('len(filters) not equals conv_count.')
        else:
            raise TypeError('Unsupported value type for filters.')

        if isinstance(strides, int):
            strides = [strides] * conv_count
        elif isinstance(strides, (tuple, list)):
            if len(strides) != conv_count:
                raise ValueError('len(strides) not equals conv_count.')
        else:
            raise TypeError('Unsupported value type for strides.')

        # 'same' padding maps H -> ceil(H / stride), so the branch output size is
        # predictable and can be matched by the shortcut (the default 'valid'
        # padding would shrink the map and break the residual addition).
        self.convs = [Conv2D(filters=f, kernel_size=3, strides=s, padding='same')
                      for f, s in zip(filters, strides)]
        self.activations = [ReLU() for _ in range(conv_count)]
        self.bn = [BatchNormalization() for _ in range(conv_count)]

        # Overall down-sampling of the branch per spatial axis. With 'same'
        # padding, ceil(ceil(H/a)/b) == ceil(H/(a*b)), so a single 1x1 conv with
        # this stride reproduces the branch's spatial size.
        total_strides = [1, 1]
        for s in strides:
            s_row, s_col = (s, s) if isinstance(s, int) else s
            total_strides[0] *= s_row
            total_strides[1] *= s_col
        self._total_strides = tuple(total_strides)
        self._out_channels = filters[-1]

        # Created in build() only if the branch changes the tensor shape.
        self.shortcut = None
        self.shortcut_bn = None

    def build(self, input_shape):
        """Create a projection shortcut when the branch changes stride or channels.

        The input channel count is only known here (not in __init__), so the
        identity-vs-projection decision is deferred to the first call.
        """
        if self._total_strides != (1, 1) or input_shape[-1] != self._out_channels:
            self.shortcut = Conv2D(filters=self._out_channels, kernel_size=1,
                                   strides=self._total_strides, padding='same')
            self.shortcut_bn = BatchNormalization()
        self.built = True

    def call(self, x):
        """Apply the residual branch, add the (possibly projected) shortcut, then ReLU."""
        y = x
        last = len(self.convs) - 1
        for i, (conv, bn, act) in enumerate(zip(self.convs, self.bn, self.activations)):
            y = bn(conv(y))
            # The final activation is deferred until after the residual addition.
            if i != last:
                y = act(y)
        shortcut = x if self.shortcut is None else self.shortcut_bn(self.shortcut(x))
        return self.activations[last](y + shortcut)

In [None]:
class ConvBlock(Model):
    """One ResNet stage: a chain of ``resblock_count`` ResBlocks with ``filters`` channels.

    Unless ``first_layer`` is True, the opening block is built with
    ``strides=(2, 1)`` (stride 2 in its first conv, 1 in its second). Every
    other block uses stride 1.
    """

    def __init__(self, resblock_count=3, filters=64, first_layer=False):
        super().__init__()
        downsample_first = not first_layer
        self.blocks = [
            ResBlock(filters=filters,
                     strides=(2, 1) if (downsample_first and idx == 0) else 1)
            for idx in range(resblock_count)
        ]

    def call(self, x):
        """Feed ``x`` through every block in order and return the final output."""
        out = x
        for res_block in self.blocks:
            out = res_block(out)
        return out

In [None]:
class ResNet(Model):
    """ResNet-18-style image classifier (He et al., 2015).

    The stem is a 7x7 stride-2 conv -> BatchNorm -> ReLU -> 3x3 stride-2 max-pool.
    It is followed by four stages of two residual blocks each (64, 128, 256 and
    512 filters; stages conv3_x..conv5_x use stride 2 in their first block).
    Global average pooling and a dense softmax head finish the network.

    Args:
        num_classes: number of output classes (default 10, the original
            hard-coded value).

    Returns (from call):
        Class probabilities of shape (batch, num_classes).
    """

    def __init__(self, num_classes=10):
        super().__init__()
        # Local import: the notebook's import cell does not bring this layer into scope.
        from keras.layers import GlobalAveragePooling2D

        self.conv1 = Conv2D(filters=64, kernel_size=7, strides=2, padding='same')
        self.bn1 = BatchNormalization()
        self.relu1 = ReLU()
        # 'same' padding matches the paper's 112 -> 56 reduction for 224x224 inputs.
        self.pool1 = MaxPool2D(pool_size=3, strides=2, padding='same')
        self.conv2_x = ConvBlock(resblock_count=2, filters=64, first_layer=True)
        self.conv3_x = ConvBlock(resblock_count=2, filters=128)
        self.conv4_x = ConvBlock(resblock_count=2, filters=256)
        self.conv5_x = ConvBlock(resblock_count=2, filters=512)
        # Global (not 2x2 local) average pooling collapses H and W, so the dense
        # head sees a (batch, channels) tensor and works for any input size.
        self.pool2 = GlobalAveragePooling2D()
        self.dense = Dense(num_classes)
        # Softmax output pairs with a loss using from_logits=False.
        self.softmax = Softmax()

    def call(self, x):
        """Run the stem, the four residual stages and the classification head."""
        x = self.relu1(self.bn1(self.conv1(x)))
        x = self.pool1(x)
        x = self.conv2_x(x)
        x = self.conv3_x(x)
        x = self.conv4_x(x)
        x = self.conv5_x(x)
        x = self.pool2(x)
        x = self.dense(x)
        return self.softmax(x)