In [1]:
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
import tensorflow as tf
import tensorflow_datasets as tfds
from tensorflow import keras
from tensorflow.keras import layers
from keras import callbacks

In [2]:
# Mini-batch size setting.
batch_size = 32

# Side length (in pixels) of the square RGB input the model is built for.
IMG_SIZE = 32

# Per-axis aliases of the same size; not referenced by the visible cells.
img_height = IMG_SIZE
img_width = IMG_SIZE

In [3]:
# Load the CIFAR-10 train/test splits (images plus integer class labels).
(x_train, y_train), (x_test, y_test) = keras.datasets.cifar10.load_data()

# Scale pixel values to [0, 1]. Cast to float32 first: dividing the integer
# image arrays directly promotes them to float64, which doubles memory for
# data Keras will compute on in float32 by default anyway.
x_train = x_train.astype('float32') / 255.0
x_test = x_test.astype('float32') / 255.0

# Number of distinct labels in the training split; sizes the classifier head.
num_classes = np.size(np.unique(y_train))

# One-hot encode the labels. The width is passed explicitly so both splits are
# guaranteed to match the classifier head instead of being inferred per split.
y_train = tf.keras.utils.to_categorical(y_train, num_classes=num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes=num_classes)

Downloading data from https://www.cs.toronto.edu/~kriz/cifar-10-python.tar.gz


In [4]:
# Pretrained MobileNet used as a feature extractor; the ImageNet classifier
# head is excluded so a new head can be attached below.
base_model = keras.applications.MobileNet(
    include_top=False,
    weights='imagenet',
    input_shape=(IMG_SIZE, IMG_SIZE, 3),
)

# Freeze base_model so only the new top layers are trained
base_model.trainable = False

# Setup inputs based on input image shape
inputs = keras.Input(shape=(IMG_SIZE, IMG_SIZE, 3))

# Light augmentation: random horizontal flips
data_augmentation = keras.Sequential([layers.RandomFlip('horizontal')])

x = data_augmentation(inputs)

# MobileNet expects inputs in [-1, 1]. The images were already divided by 255
# when they were loaded, so they arrive here in [0, 1] and are mapped with
# 2*x - 1. The stock `mobilenet.preprocess_input` assumes raw [0, 255] pixels
# (it computes x / 127.5 - 1) and would squash [0, 1] inputs into roughly
# [-1, -0.992], leaving the pretrained features almost input-independent.
x = layers.Rescaling(scale=2.0, offset=-1.0)(x)

# Keep base model batch normalization layers in inference mode (instead of training mode)
x = base_model(x, training=False)

# Rebuild top layers
x = layers.GlobalAveragePooling2D()(x)  # Collapse the spatial dimensions
x = layers.BatchNormalization()(x)  # Introduce batch norm
x = layers.Dropout(0.2)(x)  # Regularize with dropout

# Final dense softmax classifier with one unit per class (multi-class classification)
outputs = layers.Dense(num_classes, activation='softmax')(x)

# Instantiate final Keras model with updated inputs and outputs
model = keras.Model(inputs, outputs)



In [None]:
# One-hot targets with a softmax output, hence categorical cross-entropy
# and categorical accuracy.
model.compile(
    optimizer=keras.optimizers.Adam(),
    loss=keras.losses.CategoricalCrossentropy(),
    metrics=[keras.metrics.CategoricalAccuracy()],
)

# Stop once val_loss has failed to improve for 5 consecutive epochs and
# restore the weights from the best epoch.
earlystopping = callbacks.EarlyStopping(
    monitor='val_loss',
    mode='min',
    patience=5,
    restore_best_weights=True,
)

# Upper bound on epochs; early stopping may end training sooner.
EPOCHS = 25

# NOTE(review): the test split doubles as validation data here, so early
# stopping and restore_best_weights select the model on test data and the
# reported val_* metrics are optimistic. Consider holding out part of the
# training set (e.g. `validation_split`) and keeping the test set for a
# final evaluation only.
history = model.fit(
    x_train, y_train,
    batch_size=batch_size,  # use the configured value instead of the implicit Keras default
    epochs=EPOCHS,
    validation_data=(x_test, y_test),
    verbose=1,
    callbacks=[earlystopping],
)

Epoch 1/25
Epoch 2/25
Epoch 3/25
Epoch 4/25
Epoch 5/25
Epoch 6/25
Epoch 7/25

In [None]:
def plot_hist(hist, metric='categorical_accuracy', ylabel=None):
    """Plot a per-epoch training metric, plus its validation curve if present.

    Args:
        hist: Object exposing a ``history`` mapping from metric name to a
            per-epoch sequence of values (as returned by ``model.fit``).
        metric: Key of the training metric to plot. The validation series is
            looked up under ``'val_' + metric``.
        ylabel: Y-axis label. Defaults to the last underscore-separated word
            of ``metric`` (``'accuracy'`` for the default metric).

    Raises:
        KeyError: If ``metric`` is not present in ``hist.history``.
    """
    records = hist.history
    val_key = 'val_' + metric
    legend_labels = ['train']

    plt.plot(records[metric])
    # Runs without validation data have no 'val_*' series; plot it only if present.
    if val_key in records:
        plt.plot(records[val_key])
        legend_labels.append('validation')

    # 'categorical_accuracy' -> 'Categorical accuracy'
    plt.title(metric.replace('_', ' ').capitalize())
    plt.ylabel(metric.rsplit('_', 1)[-1] if ylabel is None else ylabel)
    plt.xlabel('epoch')
    plt.legend(legend_labels, loc='upper left')
    plt.show()

# Plot training vs. validation categorical accuracy per epoch for the run stored in `history`.
plot_hist(history)