In [1]:
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd

Set up TensorBoard

In [2]:
import os

root_logdir = os.path.join(os.curdir, "my_logs")

def get_run_logdir():
    import time
    run_id = time.strftime("run_%Y_%m_%d-%H_%M_%S")
    return os.path.join(root_logdir, run_id)

run_logdir = get_run_logdir() # e.g., './my_logs/run_2019_06_07-15_15_22'


In [3]:
tf.debugging.experimental.enable_dump_debug_info(run_logdir, tensor_debug_mode="FULL_HEALTH", circular_buffer_size=-1)

INFO:tensorflow:Enabled dumping callback in thread MainThread (dump root: .\my_logs\run_2022_10_02-11_13_32, tensor debug mode: FULL_HEALTH)


<tensorflow.python.debug.lib.debug_events_writer.DebugEventsWriter at 0x2b7f9eab190>

In [None]:
%load_ext tensorboard
%tensorboard --logdir=./my_logs --port=6006

Reusing TensorBoard on port 6006 (pid 11680), started 0:03:24 ago. (Use '!kill 11680' to kill it.)

Load CIFAR100 dataset

In [46]:
(x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar100.load_data(label_mode="coarse")


y_train = tf.keras.utils.to_categorical(y_train, num_classes=20)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=20)


In [16]:
x_train.shape # size, height, width, channels

(50000, 32, 32, 3)

In [47]:
y_train.shape # one-hot encoded

(50000, 20)

1. Simple CNN architecture

The idea is to tackle the superclass classification first, and then use this model to help a second model learning the fine classification.



In [48]:
model = tf.keras.models.Sequential()

# Input layer
model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=3,
                                 padding="same", activation="relu",
                                 input_shape=x_train.shape[1:]))
model.add(tf.keras.layers.Conv2D(filters=64, kernel_size=3,
                                 padding="same", activation="relu"))
                                 
model.add(tf.keras.layers.MaxPool2D(2))                                 
model.add(tf.keras.layers.BatchNormalization())

# Conv2D layer
model.add(tf.keras.layers.Conv2D(filters=128, kernel_size=3,
                                 padding="same", activation="relu"))

model.add(tf.keras.layers.Conv2D(filters=128, kernel_size=3,
                                 padding="same", activation="relu"))

model.add(tf.keras.layers.MaxPool2D(2))
model.add(tf.keras.layers.BatchNormalization())

# Fully-connected layer
model.add(tf.keras.layers.Flatten())
model.add(tf.keras.layers.Dense(512, activation="relu"))
model.add(tf.keras.layers.Dropout(rate=0.4))
model.add(tf.keras.layers.Dense(1024, activation="relu"))
model.add(tf.keras.layers.Dropout(rate=0.4))

# Output layer
model.add(tf.keras.layers.Dense(20, activation="softmax"))

In [49]:
n_epochs = 100
batch_size = 32

checkpoint_cb = tf.keras.callbacks.ModelCheckpoint(r"./cifar100-checkpoints/",
                                                   monitor="val_loss",
                                                   save_freq="epoch")

early_stopping_cb = tf.keras.callbacks.EarlyStopping(monitor="val_loss",
                                                     patience=5,
                                                     min_delta=1e-4)  

# Build and compile your model
tensorboard_cb = tf.keras.callbacks.TensorBoard(run_logdir)

# Schedule learning rate update during training
def scheduler(epoch, lr):
  if epoch < 10:
    return lr
  else:
    return lr * tf.math.exp(-0.05)

rate_scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler) 

callbacks = [early_stopping_cb, checkpoint_cb, rate_scheduler, tensorboard_cb]

In [52]:
adam = tf.keras.optimizers.Adam(learning_rate=1e-3)
accuracy_score = tf.keras.metrics.Accuracy()

model.compile(optimizer=adam, loss="categorical_crossentropy",
              metrics=[accuracy_score])

In [53]:
history = model.fit(x_train, y_train,
                    batch_size=batch_size, epochs=n_epochs,
                    callbacks=callbacks,
                    validation_split=0.1)

Epoch 1/100
  78/1407 [>.............................] - ETA: 16:46 - loss: 3.3658 - accuracy: 0.0000e+00

KeyboardInterrupt: 