In [11]:
import tensorflow as tf
from   tensorflow.keras.layers import (
    Activation, Add, BatchNormalization, Conv2D, Dense, 
    GlobalAveragePooling2D, Layer, MaxPool2D)
from   tensorflow.keras import Model
import tensorflow_datasets as tfds

In [17]:
class IdentityBlock(Model):
    """ResNet identity (residual) block: two same-padded convolutions plus a skip.

    The residual branch is conv -> BN -> ReLU -> conv -> BN. Its output is
    added to the unchanged input, and the sum goes through a final ReLU.

    Args:
        filters: Number of output channels of both convolutions. The skip
            path adds ``input_tensor`` without any projection, so the input
            must already have ``filters`` channels or the ``Add`` will fail.
        kernel_size: Kernel size of both convolutions. ``padding='same'``
            with Conv2D's default stride of 1 keeps the spatial size
            unchanged, so the two addends line up.
        name: Optional Keras name for the block.
    """

    def __init__(self, filters, kernel_size, name=''):
        super().__init__(name=name)
        self.conv1 = Conv2D(filters, kernel_size, padding='same')
        self.bn1 = BatchNormalization()
        self.conv2 = Conv2D(filters, kernel_size, padding='same')
        self.bn2 = BatchNormalization()
        # One stateless ReLU layer is reused for every activation in the block.
        self.relu = Activation('relu')
        self.add = Add()

    def call(self, input_tensor):
        """Apply the residual block; output has the same shape as ``input_tensor``."""
        x = self.conv1(input_tensor)
        x = self.bn1(x)
        x = self.relu(x)
        x = self.conv2(x)
        x = self.bn2(x)
        # Deliberately no ReLU before the addition. Activating here would clamp
        # the residual to be non-negative, so the block could only ever
        # increase its input. The ResNet design activates after the add.
        x = self.add([x, input_tensor])
        return self.relu(x)

In [19]:
class ResNet(Model):
    """Small ResNet-style image classifier.

    Layout: 7x7 conv (64 filters, same padding) -> BN -> ReLU -> 3x3 max-pool
    -> two 64-filter identity blocks -> global average pooling -> softmax Dense.

    Args:
        n_classes: Number of output units (classes) of the softmax classifier.
    """

    def __init__(self, n_classes):
        super().__init__()
        # Keep this assignment order: Keras tracks sub-layers, and orders
        # their weights, by the order in which attributes are set.
        # --- stem ---
        self.conv = Conv2D(filters=64, kernel_size=7, padding='same')
        self.bn = BatchNormalization()
        self.relu = Activation('relu')
        # No explicit strides, so Keras uses the pool size (3) as the stride.
        self.pool = MaxPool2D(pool_size=(3, 3))
        # --- residual stage: 64 channels, matching the stem's output ---
        self.id1a = IdentityBlock(filters=64, kernel_size=3)
        self.id1b = IdentityBlock(filters=64, kernel_size=3)
        # --- head ---
        self.global_pool = GlobalAveragePooling2D()
        self.classifier = Dense(units=n_classes, activation='softmax')

    def call(self, inputs):
        """Run the feature extractor, then return softmax class probabilities."""
        feature_stages = (
            self.conv, self.bn, self.relu, self.pool,
            self.id1a, self.id1b,
            self.global_pool,
        )
        features = inputs
        for stage in feature_stages:
            features = stage(features)
        return self.classifier(features)

In [20]:
def preprocess(features):
    """Turn one tfds example dict into an ``(image, label)`` training pair.

    The ``'image'`` entry is cast to float32 and divided by 255. The
    ``'label'`` entry is passed through untouched.
    """
    image = tf.cast(features['image'], tf.float32)
    scaled_image = image / 255.
    label = features['label']
    return scaled_image, label

In [21]:
# Seed TensorFlow's global RNG so that weight initialisation (and other
# seeded ops) repeats across Restart & Run All. GPU kernels may still be
# nondeterministic.
SEED = 42
N_CLASSES = 10  # MNIST digit classes 0-9
tf.random.set_seed(SEED)

resnet = ResNet(N_CLASSES)
# Integer labels + softmax outputs -> sparse categorical cross-entropy
# (from_logits defaults to False, matching the softmax head).
resnet.compile(optimizer='adam',
               loss='sparse_categorical_crossentropy',
               metrics=['accuracy'])

In [22]:
BATCH_SIZE = 32
EPOCHS = 3
SHUFFLE_BUFFER = 10_000

# Raw MNIST training split: a tf.data pipeline of feature dicts.
mnist_train_raw = tfds.load('mnist', split=tfds.Split.TRAIN)

# tfds.load does not shuffle examples by default (shuffle_files=False). Without
# .shuffle, every epoch would see the examples in the same fixed order.
# Shuffle before map so the buffer holds compact uint8 images. Prefetch lets
# input preparation overlap with training steps.
data_set = (mnist_train_raw
            .shuffle(SHUFFLE_BUFFER)
            .map(preprocess, num_parallel_calls=tf.data.experimental.AUTOTUNE)
            .batch(BATCH_SIZE)
            .prefetch(tf.data.experimental.AUTOTUNE))

history = resnet.fit(data_set, epochs=EPOCHS)

Epoch 1/3
Epoch 2/3
Epoch 3/3


<tensorflow.python.keras.callbacks.History at 0x7fe9ca9111c0>