In [5]:
import numpy as np
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.utils import to_categorical
from tensorflow.keras.optimizers import Adam
from tensorflow.keras.callbacks import ModelCheckpoint, EarlyStopping
from tensorflow.keras.applications import VGG16, VGG19, ResNet50
from tensorflow.keras.models import Model
import efficientnet.tfkeras as efn
import matplotlib.pyplot as plt


In [2]:
# Load the CIFAR-10 train/test splits: image arrays plus integer class labels.
(x_train, y_train), (x_test, y_test) = cifar10.load_data()

# Cast both image splits to float32 and divide by 255 (presumably mapping the
# 0-255 pixel range onto [0, 1]).
# NOTE(review): the ImageNet-pretrained backbones used later were trained with
# their own `preprocess_input` normalisation; plain [0, 1] scaling may not match
# what they expect -- confirm whether per-model preprocessing is wanted.
x_train, x_test = (images.astype('float32') / 255. for images in (x_train, x_test))

# Turn the integer labels of both splits into 10-column one-hot vectors.
y_train, y_test = (to_categorical(labels, 10) for labels in (y_train, y_test))


In [3]:
def build_model(base_model, model_name, learning_rate=0.001, num_classes=10):
    """Attach a global-average-pooling + softmax head to `base_model` and compile it.

    The backbone's layers are not frozen here, so training updates them as well
    as the new head.

    Args:
        base_model: Keras model whose output is expected to be a 4-D feature map
            (e.g. an application built with `include_top=False`), since it is
            fed to `GlobalAveragePooling2D`.
        model_name: name given to the resulting Keras `Model`.
        learning_rate: step size for the Adam optimizer (default 0.001).
        num_classes: number of units in the softmax output layer (default 10).

    Returns:
        A compiled `Model` mapping `base_model.input` to class probabilities,
        trained with categorical cross-entropy and reporting accuracy.
    """
    pooled = tf.keras.layers.GlobalAveragePooling2D()(base_model.output)
    predictions = tf.keras.layers.Dense(num_classes, activation='softmax', name='output')(pooled)
    model = Model(inputs=base_model.input, outputs=predictions, name=model_name)

    # `lr` is a deprecated alias that newer Keras releases reject;
    # `learning_rate` is the supported keyword.
    optimizer = Adam(learning_rate=learning_rate)
    model.compile(optimizer=optimizer, loss='categorical_crossentropy', metrics=['accuracy'])

    return model


Metal device set to: Apple M1

systemMemory: 8.00 GB
maxCacheSize: 2.67 GB



2023-04-12 13:37:15.099319: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:305] Could not identify NUMA node of platform GPU ID 0, defaulting to 0. Your kernel may not have been built with NUMA support.
2023-04-12 13:37:15.099769: I tensorflow/core/common_runtime/pluggable_device/pluggable_device_factory.cc:271] Created TensorFlow device (/job:localhost/replica:0/task:0/device:GPU:0 with 0 MB memory) -> physical PluggableDevice (device: 0, name: METAL, pci bus id: <undefined>)


In [4]:
def train_model(model, model_name, batch_size=128, epochs=20, patience=10,
                validation_split=0.1, validation_data=None, restore_best=True):
    """Fit `model` on the global `x_train` / `y_train` with checkpointing and early stopping.

    Args:
        model: a compiled Keras model.
        model_name: prefix for the checkpoint file `{model_name}_best_model.h5`.
        batch_size: mini-batch size passed to `fit` (default 128).
        epochs: maximum number of epochs (default 20).
        patience: epochs without `val_loss` improvement before early stopping.
        validation_split: fraction of the training arrays held out for
            validation (Keras uses the *last* samples, taken before shuffling).
        validation_data: optional explicit `(x, y)` validation set; when given,
            Keras uses it instead of `validation_split`.
        restore_best: if True, reload the checkpoint with the best
            `val_accuracy` into `model` after training.

    Returns:
        The `History` object returned by `model.fit`.
    """
    checkpoint_path = f'{model_name}_best_model.h5'
    checkpoint = ModelCheckpoint(checkpoint_path, save_best_only=True, monitor='val_accuracy', mode='max', verbose=1)
    early_stopping = EarlyStopping(monitor='val_loss', mode='min', patience=patience, verbose=1)

    # Validate on held-out *training* data by default. Using (x_test, y_test) here
    # would let the test set drive checkpoint selection and early stopping,
    # biasing any test score reported afterwards.
    history = model.fit(x_train, y_train,
                        batch_size=batch_size,
                        epochs=epochs,
                        validation_split=validation_split,
                        validation_data=validation_data,
                        callbacks=[checkpoint, early_stopping])

    # The checkpoint is only written when validation metrics exist. Reload it so
    # the caller evaluates the best epoch rather than the last one, which may be
    # up to `patience` epochs past the best.
    if restore_best and 'val_accuracy' in history.history:
        model.load_weights(checkpoint_path)

    return history


In [7]:
def plot_history(history, model_name):
    """Plot per-epoch accuracy and loss curves from a Keras training run.

    Draws two side-by-side panels (accuracy, loss). Each panel shows the
    training curve and, when the run recorded them, the matching `val_*`
    validation curve.

    Args:
        history: the object returned by `model.fit`; its `.history` dict is read
            for 'accuracy' and 'loss', and optionally 'val_accuracy' and 'val_loss'.
        model_name: label used in the panel titles.
    """
    metrics = history.history
    fig, (acc_ax, loss_ax) = plt.subplots(1, 2, figsize=(12, 4))

    for ax, key, label in ((acc_ax, 'accuracy', 'Accuracy'), (loss_ax, 'loss', 'Loss')):
        ax.plot(metrics[key], label='Train')
        # Metrics computed on validation data carry a 'val_' prefix; they are
        # absent when `fit` ran without validation, so plot them only if present.
        val_key = f'val_{key}'
        if val_key in metrics:
            ax.plot(metrics[val_key], label='Validation')
        ax.set_title(f'{model_name} Model {label.lower()}')
        ax.set_ylabel(label)
        ax.set_xlabel('Epoch')
        ax.legend(loc='best')

    fig.tight_layout()
    plt.show()


In [8]:
# Define the models to compare: ImageNet-pretrained backbones without their
# classifier heads, sized to the loaded CIFAR-10 images.
input_shape = x_train.shape[1:]
models_to_compare = {
    'VGG16': VGG16(input_shape=input_shape, include_top=False, weights='imagenet'),
    'VGG19': VGG19(input_shape=input_shape, include_top=False, weights='imagenet'),
    'ResNet50': ResNet50(input_shape=input_shape, include_top=False, weights='imagenet'),
    'EfficientNetB0': efn.EfficientNetB0(input_shape=input_shape, include_top=False, weights='imagenet')
}

# Per-model final metrics, collected for the comparison summary below.
results = {}

# Train each model and print the results
for model_name, base_model in models_to_compare.items():
    print(f"Training {model_name}...")
    model = build_model(base_model, model_name)
    history = train_model(model, model_name)

    # `evaluate` returns [loss, accuracy] because the model is compiled with
    # metrics=['accuracy'].
    train_loss, train_acc = model.evaluate(x_train, y_train)
    test_loss, test_acc = model.evaluate(x_test, y_test)
    results[model_name] = {
        'train_loss': train_loss, 'train_acc': train_acc,
        'test_loss': test_loss, 'test_acc': test_acc,
    }

    # The first value from `evaluate` is the loss. The error rate is reported
    # separately as 1 - accuracy.
    print(f"{model_name} Results:")
    print(f"Training Loss: {train_loss:.4f}")
    print(f"Training Error (1 - accuracy): {1 - train_acc:.4f}")
    print(f"Training Accuracy: {train_acc:.4f}")
    print(f"Testing Loss: {test_loss:.4f}")
    print(f"Testing Error (1 - accuracy): {1 - test_acc:.4f}")
    print(f"Testing Accuracy: {test_acc:.4f}")

    # Plot the training and validation loss/accuracy over epochs
    plot_history(history, model_name)
    print("\n")

# Side-by-side comparison of all trained models.
print("Summary:")
for model_name, metrics in results.items():
    print(f"{model_name:>15}: test accuracy {metrics['test_acc']:.4f}, "
          f"test loss {metrics['test_loss']:.4f}")


Downloading data from https://storage.googleapis.com/tensorflow/keras-applications/vgg16/vgg16_weights_tf_dim_ordering_tf_kernels_notop.h5

KeyboardInterrupt: 