In [1]:
import os
from typing import Tuple

import numpy as np
import tensorflow as tf
from tensorflow.keras.callbacks import EarlyStopping
from tensorflow.keras.layers import Activation
from tensorflow.keras.layers import Add
from tensorflow.keras.layers import BatchNormalization
from tensorflow.keras.layers import Conv2D
from tensorflow.keras.layers import Dense
from tensorflow.keras.layers import GlobalAveragePooling2D
from tensorflow.keras.layers import Input
from tensorflow.keras.layers import MaxPool2D
from tensorflow.keras.models import Model
from tensorflow.keras.optimizers import Adam

from tf_utils.mnistData_advance import MNIST

In [2]:
# Pin the global TensorFlow and NumPy random seeds to the same fixed value.
tf.random.set_seed(seed=0)
np.random.seed(seed=0)

In [4]:
# Root folder for the logs written by this notebook.
# NOTE(review): machine-specific absolute path; adjust when running elsewhere.
LOGS_DIR = os.path.abspath("C:/Selbststudium/Udemy/Udemy_Tensorflow/logs")
# makedirs also creates missing parent folders (os.mkdir would raise
# FileNotFoundError there), and exist_ok=True keeps the cell safe to re-run
# without a separate, race-prone existence check.
os.makedirs(LOGS_DIR, exist_ok=True)

In [6]:
def relu_norm(x: tf.Tensor) -> tf.Tensor:
    """Apply a ReLU activation followed by batch normalization.

    Args:
        x: Input tensor.

    Returns:
        The tensor after the ReLU layer and then the BatchNormalization layer.
    """
    activated = Activation("relu")(x)
    return BatchNormalization()(activated)

In [7]:
def residual_block(
    x: tf.Tensor,
    filters: int,
    downsample: bool = False,
) -> tf.Tensor:
    """Build a residual block: two 3x3 convolutions plus a shortcut connection.

    Args:
        x: Input tensor of the block.
        filters: Number of filters used by every convolution in the block.
        downsample: If True, the first 3x3 convolution and the shortcut
            convolution use a stride of 2; otherwise a stride of 1.

    Returns:
        `relu_norm` applied to the element-wise sum of shortcut and main path.
    """
    stride = 2 if downsample else 1

    # Main path: conv -> ReLU/batch-norm -> conv.
    main = Conv2D(
        filters=filters, kernel_size=3, strides=stride, padding="same"
    )(x)
    main = relu_norm(main)
    main = Conv2D(
        filters=filters, kernel_size=3, strides=1, padding="same"
    )(main)

    # Shortcut path: the input is passed through unchanged unless its spatial
    # size (downsample) or channel count differs from the main path, in which
    # case a 1x1 convolution brings it to the matching shape for the addition.
    shortcut = x
    if downsample or x.shape[-1] != filters:
        shortcut = Conv2D(
            filters=filters, kernel_size=1, strides=stride, padding="same"
        )(x)

    return relu_norm(Add()([shortcut, main]))

In [8]:
def output_block(x: tf.Tensor, num_classes: int) -> tf.Tensor:
    """Build the classification head of the network.

    Args:
        x: Feature tensor fed into the head.
        num_classes: Number of units of the final dense layer.

    Returns:
        The softmax-activated output of global average pooling followed by a
        dense layer with `num_classes` units.
    """
    pooled = GlobalAveragePooling2D()(x)
    logits = Dense(units=num_classes)(pooled)
    return Activation("softmax")(logits)

In [9]:
def build_model_resnet(
    img_shape: Tuple[int, int, int],
    num_classes: int,
) -> Model:
    """Assemble and compile the residual classification network.

    Args:
        img_shape: Shape passed to the model's `Input` layer.
        num_classes: Number of output units of the classification head.

    Returns:
        A Keras `Model` compiled with a default-constructed Adam optimizer,
        categorical cross-entropy loss and the accuracy metric.
    """
    # (filters, downsample) of each residual block, in network order.
    block_specs = (
        (32, True),
        (64, False),
        (64, False),
        (128, True),
        (128, False),
    )

    input_img = Input(shape=img_shape)

    features = input_img
    for filters, downsample in block_specs:
        features = residual_block(
            x=features, filters=filters, downsample=downsample
        )
    y_pred = output_block(x=features, num_classes=num_classes)

    model = Model(inputs=[input_img], outputs=[y_pred])
    model.compile(
        optimizer=Adam(),
        loss="categorical_crossentropy",
        metrics=["accuracy"],
    )
    return model

In [10]:
# Instantiate the project's MNIST data helper.
data = MNIST()

# Fetch the train, validation and test splits from the helper (in that order).
train_dataset, val_dataset, test_dataset = (
    data.get_train_set(),
    data.get_val_set(),
    data.get_test_set(),
)

# Image shape and class count, used below to build the model.
img_shape, num_classes = data.img_shape, data.num_classes

In [11]:
# Training hyperparameters: number of epochs handed to `fit`, and the batch
# size handed to `fit` / `evaluate`.
epochs, batch_size = 100, 128

In [12]:
# Build and compile the residual network for the loaded data.
model = build_model_resnet(img_shape=img_shape, num_classes=num_classes)

In [13]:
# Print the layer-by-layer overview (output shapes, parameter counts, connections).
model.summary()

Model: "model"
__________________________________________________________________________________________________
 Layer (type)                   Output Shape         Param #     Connected to                     
 input_1 (InputLayer)           [(None, 28, 28, 1)]  0           []                               
                                                                                                  
 conv2d (Conv2D)                (None, 14, 14, 32)   320         ['input_1[0][0]']                
                                                                                                  
 activation (Activation)        (None, 14, 14, 32)   0           ['conv2d[0][0]']                 
                                                                                                  
 batch_normalization (BatchNorm  (None, 14, 14, 32)  128         ['activation[0][0]']             
 alization)                                                                                   

In [14]:
# Log folder for this model, located inside LOGS_DIR.
# NOTE(review): no later cell visible here uses this variable — confirm
# whether a logging callback was meant to be passed to `fit`.
model_log_dir = os.path.join(LOGS_DIR, "model_resnet_mnist")

In [None]:
# Stop training once val_accuracy has failed to improve by at least
# `min_delta` for `patience` consecutive epochs, then restore the weights of
# the best epoch.
es_callback = EarlyStopping(
    monitor="val_accuracy",
    min_delta=0.0005,
    patience=30,
    restore_best_weights=True,
    verbose=1,
)

# NOTE(review): `batch_size` is forwarded although the inputs are dataset
# objects from the MNIST helper; if those are already batched, Keras may
# ignore or reject the argument — confirm against the helper.
model.fit(
    train_dataset,
    validation_data=val_dataset,
    epochs=epochs,
    batch_size=batch_size,
    callbacks=[es_callback],
    verbose=1,
)

In [None]:
# Validation score. This number is optimistic by construction: early stopping
# monitored val_accuracy on this same split and restored the best weights.
scores = model.evaluate(
    val_dataset,
    verbose=0,
    batch_size=batch_size
)
print(f"Scores: {scores}")

# Held-out test score. The test split is not used by the training or
# early-stopping cells, so it gives the unbiased estimate to report.
test_scores = model.evaluate(
    test_dataset,
    verbose=0,
    batch_size=batch_size
)
print(f"Test scores: {test_scores}")