In [1]:
import tensorflow as tf

from tensorflow.keras import datasets, layers, models
import matplotlib.pyplot as plt

import tensorflow_addons as tfa

In [2]:
(train_images, train_labels), (test_images, test_labels) = datasets.cifar10.load_data()

# Scale pixel values from 0-255 to [0, 1]. Cast to float32 first: dividing the
# integer arrays by 255.0 directly would produce float64 arrays (twice the
# memory), while the model runs in float32 anyway.
train_images = train_images.astype("float32") / 255.0
test_images = test_images.astype("float32") / 255.0

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [50]:
def get_model(input_shape=None, num_classes=10):
    """Build a randomly initialised ResNet50V2 classifier that outputs logits.

    Args:
        input_shape: (height, width, channels) of one input image. Defaults to
            the shape of a single image in the global ``train_images`` array.
        num_classes: number of units in the classification head.

    Returns:
        An uncompiled Keras model with no pretrained weights (``weights=None``).
        Its outputs are raw logits (no softmax), which matches the
        ``SparseCategoricalCrossentropy(from_logits=True)`` loss the training
        cells compile it with.
    """
    if input_shape is None:
        input_shape = train_images.shape[1:]

    # Keras exposes this architecture as `ResNet50V2` (no underscore).
    return tf.keras.applications.ResNet50V2(
        include_top=True,
        weights=None,
        input_tensor=None,
        input_shape=input_shape,
        pooling=None,
        classes=num_classes,
        # None => return logits. Applying softmax here while the loss uses
        # from_logits=True would make the loss treat probabilities as logits.
        classifier_activation=None,
    )

In [51]:
def plot_acc_loss(history, title=None):
    """Plot per-epoch accuracy and loss curves from a Keras ``History``.

    Draws two separate figures: accuracy first, then loss. Each figure shows the
    training curve. The matching validation curve (key prefixed with ``val_``)
    is added when it is present in ``history.history``.

    Args:
        history: object returned by ``model.fit``. ``history.history`` must
            contain the keys 'accuracy' and 'loss'.
        title: optional label (e.g. the optimizer name) prefixed to each
            figure title.
    """
    metrics = history.history
    # (metric key, y-axis label, legend location). Loss typically falls during
    # training, so its legend goes top-right to stay clear of the curve tails.
    # Accuracy typically rises, so bottom-right is usually free.
    panels = [
        ("accuracy", "Accuracy", "lower right"),
        ("loss", "Loss", "upper right"),
    ]
    for key, ylabel, legend_loc in panels:
        fig, ax = plt.subplots()
        ax.plot(metrics[key], label=f"train {key}")
        val_key = f"val_{key}"
        if val_key in metrics:  # missing when fit() ran without validation data
            ax.plot(metrics[val_key], label=f"val {key}")
        ax.set_xlabel("Epoch")
        ax.set_ylabel(ylabel)
        ax.set_title(f"{title}: {ylabel}" if title else ylabel)
        ax.legend(loc=legend_loc)
        plt.show()
        plt.close(fig)

In [52]:
# Seed TensorFlow's global RNG before building the model so weight
# initialisation (and fit's shuffling) is reproducible. Reusing the same seed
# before building each compared model gives them identical initial weights, so
# differences in the curves come from the optimizer rather than the init.
# GPU kernels may still introduce some nondeterminism.
tf.random.set_seed(42)
adam_model = get_model()

adam_model.compile(
    optimizer="adam",
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# NOTE: the CIFAR-10 test split is used as validation data. The "val" curves are
# therefore test-set curves: acceptable for comparing optimizers, not for tuning.
history_adam = adam_model.fit(
    train_images,
    train_labels,
    batch_size=512,
    epochs=20,
    validation_data=(test_images, test_labels),
)

Epoch 1/20
Epoch 2/20
Epoch 3/20
Epoch 4/20
Epoch 5/20
Epoch 6/20
Epoch 7/20
Epoch 8/20
Epoch 9/20
Epoch 10/20
Epoch 11/20
Epoch 12/20
Epoch 13/20
Epoch 14/20
Epoch 15/20
Epoch 16/20
Epoch 17/20
Epoch 18/20

KeyboardInterrupt: ignored

In [None]:
# Adam run: accuracy and loss curves (the "val" curves use the test split).
plot_acc_loss(history=history_adam)

In [None]:
# Seed TensorFlow's global RNG before building the model so weight
# initialisation (and fit's shuffling) is reproducible. Reusing the same seed
# before building each compared model gives them identical initial weights.
tf.random.set_seed(42)
lookahead_model = get_model()

# Lookahead wraps an inner Adam optimizer. sync_period and slow_step_size are
# Lookahead's k and alpha hyperparameters.
# The inner optimizer is passed by name so tfa resolves it itself. On TF >= 2.11,
# tfa optimizers only accept legacy Keras optimizer instances, so a
# `tf.keras.optimizers.Adam()` instance would be rejected there.
lookahead = tfa.optimizers.Lookahead(optimizer="adam", sync_period=5, slow_step_size=0.6)
lookahead_model.compile(
    optimizer=lookahead,
    loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
    metrics=["accuracy"],
)

# Same batch size, epoch count and validation data as the plain Adam run, so
# the two histories are directly comparable.
history_lookahead = lookahead_model.fit(
    train_images,
    train_labels,
    batch_size=512,
    epochs=20,
    validation_data=(test_images, test_labels),
)

In [None]:
# Lookahead(Adam) run: accuracy and loss curves (the "val" curves use the test split).
plot_acc_loss(history=history_lookahead)