In [1]:
import datetime
import os

import tensorboard
import tensorflow as tf

import tensorflow.keras.applications.resnet

from smot.jupyter import model_reports
from smot.problems.mnist import mnist_lib

In [2]:
# Load the MNIST train/test splits through the project helper; each split
# unpacks into an (x, y) pair.
mnist_train_split, mnist_test_split = mnist_lib.load_mnist_data_28x28x1()
x_train, y_train = mnist_train_split
x_test, y_test = mnist_test_split

In [4]:
class IdentityBlock(tf.keras.Model):
    """Residual block: two conv + batch-norm stages wrapped by a skip connection.

    The block input is added element-wise to the output of the second
    batch-norm, and a ReLU is applied to the sum. Both convolutions use
    ``padding="same"``.

    NOTE(review): the skip connection adds the raw input, so the input
    presumably needs ``filters`` channels for the shapes to line up --
    confirm against callers.

    Args:
        filters: Number of filters for both convolutions.
        kernel_size: Kernel size for both convolutions (defaults to 3).
    """

    def __init__(self, filters, kernel_size=3):
        super().__init__(name="")
        layers = tf.keras.layers

        # Both convolutions are configured identically.
        conv_kwargs = dict(filters=filters, kernel_size=kernel_size, padding="same")

        self.conv1 = layers.Conv2D(**conv_kwargs)
        self.bn1 = layers.BatchNormalization()

        self.conv2 = layers.Conv2D(**conv_kwargs)
        self.bn2 = layers.BatchNormalization()

        # A single ReLU layer object is reused for both activations in call().
        self.act = layers.Activation("relu")
        self.add = layers.Add()

    def call(self, input_tensor):
        """Run the residual branch and merge it with ``input_tensor``.

        Args:
            input_tensor: Tensor fed to the first convolution and to the skip
                connection.

        Returns:
            ReLU of (residual branch output + ``input_tensor``).
        """
        # First stage: conv -> batch-norm -> ReLU.
        residual = self.act(self.bn1(self.conv1(input_tensor)))

        # Second stage: conv -> batch-norm (activation comes after the add).
        residual = self.bn2(self.conv2(residual))

        # Skip connection followed by the final activation.
        return self.act(self.add([residual, input_tensor]))


class ResNetSmall(tf.keras.Model):
    """Small ResNet-style classifier.

    Pipeline: 7x7 conv (64 filters) -> batch-norm -> ReLU -> 3x3 max-pool ->
    two ``IdentityBlock(64, 3)`` blocks -> global average pooling -> dense
    softmax layer with ``num_classes`` units.

    Args:
        num_classes: Number of units in the softmax output layer.
        input_shape: Keyword-only; forwarded as ``input_shape`` to the first
            convolution.
    """

    def __init__(self, num_classes, *, input_shape):
        super().__init__()
        layers = tf.keras.layers

        # Stem: wide convolution, normalisation, activation, down-sampling.
        self.conv = layers.Conv2D(
            filters=64,
            kernel_size=7,
            padding="same",
            input_shape=input_shape,
        )
        self.bn = layers.BatchNormalization()
        self.act = layers.Activation("relu")
        self.max_pool = layers.MaxPool2D(pool_size=(3, 3))

        # Residual stage: two identity blocks at 64 filters.
        self.id1a = IdentityBlock(64, 3)
        self.id1b = IdentityBlock(64, 3)

        # Head: spatial average followed by the softmax classifier.
        self.global_pool = layers.GlobalAveragePooling2D()
        self.classifier = layers.Dense(units=num_classes, activation="softmax")

    def call(self, inputs):
        """Compute class probabilities for ``inputs``.

        Args:
            inputs: Tensor fed to the stem convolution.

        Returns:
            Output of the softmax dense layer (``num_classes`` values per
            example).
        """
        features = self.max_pool(self.act(self.bn(self.conv(inputs))))

        for residual_block in (self.id1a, self.id1b):
            features = residual_block(features)

        pooled = self.global_pool(features)
        return self.classifier(pooled)

# Synchronous data-parallel training; with no arguments MirroredStrategy uses
# the devices TensorFlow can see.
strategy = tf.distribute.MirroredStrategy()

# Build AND compile inside the strategy scope so that every variable created
# for the model, the optimizer and the metrics is created under the strategy
# (the TensorFlow distributed-training guide requires both steps in scope).
with strategy.scope():
    model = ResNetSmall(10, input_shape=mnist_lib.INPUT_SHAPE)

    # NOTE(review): CategoricalCrossentropy expects one-hot encoded labels --
    # confirm that mnist_lib returns y_train / y_test in that form.
    model.compile(
        optimizer=tf.keras.optimizers.Adam(),
        loss=tf.keras.losses.CategoricalCrossentropy(),
        metrics=["accuracy"],
    )

# Print the model summary.
# NOTE(review): left disabled; a subclassed model presumably has to be built
# (called on data) before summary() can report shapes -- confirm.
# model.summary()

INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0', '/job:localhost/replica:0/task:0/device:GPU:1')
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).


In [5]:
batch_size = 64 * strategy.num_replicas_in_sync

datagen = tf.keras.preprocessing.image.ImageDataGenerator(
    rotation_range=10,
    shear_range=0.1,
    width_shift_range=0.1,
    height_shift_range=0.1,
    zoom_range=0.2,
    validation_split=0.2,
)
datagen.fit(x_train)

training_generator = datagen.flow(
    x_train,
    y_train,
    subset="training",
    batch_size=batch_size,
)
validation_generator = datagen.flow(
    x_train,
    y_train,
    subset="validation",
    batch_size=batch_size,
)

logdir = os.path.join("logs", datetime.datetime.now().strftime("%Y%m%d-%H%M%S"))
tensorboard_callback = tf.keras.callbacks.TensorBoard(logdir, histogram_freq=1)

%load_ext tensorboard
%tensorboard --logdir logs


history = model.fit(
    training_generator,
    validation_data=validation_generator,
    epochs=200,
    verbose=1,
    use_multiprocessing=True,
    workers=24,
    callbacks=[
        tf.keras.callbacks.EarlyStopping(monitor="loss", patience=5),
        tensorboard_callback,
    ],
)

# Evaluate the model with the test data.
test_loss, test_accuracy = model_reports.model_fit_report(
    model=model,
    history=history,
    test_data=(x_test, y_test),
)

Launching TensorBoard...

Epoch 1/200
INFO:tensorflow:batch_all_reduce: 22 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:batch_all_reduce: 22 all-reduces with algorithm = nccl, num_packs = 1
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/

Process Keras_worker_ForkPoolWorker-1029:
Process Keras_worker_ForkPoolWorker-1014:
Process Keras_worker_ForkPoolWorker-1010:
Process Keras_worker_ForkPoolWorker-1023:
Process Keras_worker_ForkPoolWorker-1032:
Process Keras_worker_ForkPoolWorker-1017:
Process Keras_worker_ForkPoolWorker-1016:
Process Keras_worker_ForkPoolWorker-1028:
Process Keras_worker_ForkPoolWorker-1009:
Process Keras_worker_ForkPoolWorker-1030:
Process Keras_worker_ForkPoolWorker-1024:
Process Keras_worker_ForkPoolWorker-1020:
Process Keras_worker_ForkPoolWorker-1012:
Process Keras_worker_ForkPoolWorker-1013:
Process Keras_worker_ForkPoolWorker-1015:
Process Keras_worker_ForkPoolWorker-1018:
Process Keras_worker_ForkPoolWorker-1019:
Process Keras_worker_ForkPoolWorker-1027:
Process Keras_worker_ForkPoolWorker-1026:
Traceback (most recent call last):
Process Keras_worker_ForkPoolWorker-1025:
Process Keras_worker_ForkPoolWorker-1011:
Process Keras_worker_ForkPoolWorker-1031:
Process Keras_worker_ForkPoolWorker-1021:

KeyboardInterrupt: 