In [1]:
import tensorflow as tf
import tensorflow.keras.layers as layers
import keras
import tensorflow_datasets as tfds
import matplotlib.pyplot as plt
import numpy as np

2024-06-27 22:22:49.056551: I tensorflow/core/platform/cpu_feature_guard.cc:210] This TensorFlow binary is optimized to use available CPU instructions in performance-critical operations.
To enable the following instructions: AVX2 FMA, in other operations, rebuild TensorFlow with the appropriate compiler flags.
  from .autonotebook import tqdm as notebook_tqdm


# Create the ResNet18 model

In [4]:
def resnet_block(input_tensor, filters, strides=(1, 1)):
    """Basic two-convolution residual block.

    Main path:  3x3 conv (using ``strides``) -> BN -> ReLU -> 3x3 conv -> BN.
    Shortcut:   the input itself, or a 1x1 conv + BN projection when the block
                downsamples (``strides != (1, 1)``) or changes the channel count,
                so both paths have matching shapes for the addition.

    Args:
        input_tensor: Keras tensor whose last axis is the channel axis.
        filters: Number of output channels of the block.
        strides: Strides of the first 3x3 conv and of the projection shortcut.

    Returns:
        Keras tensor ``ReLU(main_path + shortcut)``.
    """
    # Main path (layers are created before the shortcut's, as in the original order).
    residual = layers.Conv2D(filters, (3, 3), strides=strides, padding='same')(input_tensor)
    residual = layers.BatchNormalization()(residual)
    residual = layers.Activation('relu')(residual)
    residual = layers.Conv2D(filters, (3, 3), padding='same')(residual)
    residual = layers.BatchNormalization()(residual)

    # Projection shortcut whenever the spatial size or channel count changes.
    shortcut = input_tensor
    needs_projection = strides != (1, 1) or input_tensor.shape[-1] != filters
    if needs_projection:
        shortcut = layers.Conv2D(filters, (1, 1), strides=strides, padding='same')(shortcut)
        shortcut = layers.BatchNormalization()(shortcut)

    merged = layers.Add()([residual, shortcut])
    return layers.Activation('relu')(merged)

def build_resnet18(input_shape, num_classes, include_head = True):
    """Build a ResNet-18-style network with the Keras functional API.

    Args:
        input_shape: Shape of a single input image without the batch axis,
            e.g. ``(224, 224, 3)``.
        num_classes: Number of units in the softmax classification head.
            Unused when ``include_head`` is False.
        include_head: If True, append global average pooling and a softmax
            ``Dense`` layer named ``'classifier_head'``. If False, the model's
            output is the last residual stage's feature map (no pooling).

    Returns:
        An uncompiled ``keras.Model``.
    """
    inputs = layers.Input(shape=input_shape)

    # Stem: 7x7/2 conv -> BN -> ReLU, followed by a further stride-2 downsampling.
    # NOTE(review): canonical ResNet-18 uses a 3x3/2 max-pool here; this uses a
    # 3x3/2 conv without BN/ReLU instead — confirm that is intended.
    x = layers.Conv2D(64, (7, 7), strides=(2, 2), padding='same')(inputs)
    x = layers.BatchNormalization()(x)
    x = layers.Activation('relu')(x)
    x = layers.Conv2D(64, (3, 3), strides=(2, 2), padding='same')(x)

    # Block 1 (no downsampling)
    x = resnet_block(x, 64)
    x = resnet_block(x, 64)

    # Block 2 (first residual block halves the spatial size)
    x = resnet_block(x, 128, strides=(2, 2))
    x = resnet_block(x, 128)

    # Block 3
    x = resnet_block(x, 256, strides=(2, 2))
    x = resnet_block(x, 256)

    # Block 4
    x = resnet_block(x, 512, strides=(2, 2))
    x = resnet_block(x, 512)

    if include_head:
        # Pool and classification head
        x = layers.GlobalAveragePooling2D()(x)
        outputs = layers.Dense(num_classes, activation='softmax', name = 'classifier_head')(x)
    else:
        # Headless: expose the final feature map. Previously `outputs` was left
        # unbound on this path, so building the model raised an error.
        outputs = x

    model = keras.Model(inputs, outputs)
    return model

In [5]:
# Create the ResNet-18 model
input_shape = (224, 224, 3)
num_classes = 10  # CIFAR-10; load_data() in the training cell re-reads this from the dataset metadata
model = build_resnet18(input_shape, num_classes)

# Compile the model. The tfds pipeline (as_supervised=True) yields integer class ids,
# not one-hot vectors, so the sparse cross-entropy is the matching loss (the training
# cell compiles with the same loss).
model.compile(optimizer='adam', loss='sparse_categorical_crossentropy', metrics=['accuracy'])

# Print model summary
model.summary()

# Train model on CIFAR-10

In [None]:
# Let tf.data tune parallelism / prefetch buffer sizes at runtime.
# tf.data.AUTOTUNE is the stable name; tf.data.experimental.AUTOTUNE is a deprecated alias.
AUTOTUNE = tf.data.AUTOTUNE
epochs = 50  # maximum number of epochs; EarlyStopping may end training earlier

def preprocess_image(image, label, image_size):
    """Resize an image to ``image_size`` x ``image_size`` and divide its pixels by 255.

    For 0-255 pixel input this maps values into [0, 1].

    Args:
        image: Image tensor from the tfds pipeline (presumably uint8 HWC — TODO confirm).
        label: Class label, returned unchanged.
        image_size: Target height and width in pixels.

    Returns:
        Tuple ``(resized_and_scaled_image, label)``.
    """
    resized = tf.image.resize(image, (image_size, image_size))
    return resized / 255.0, label

def load_data(image_size, batch_size):
    """Load CIFAR-10 from TFDS as batched, preprocessed ``tf.data`` pipelines.

    Args:
        image_size: Side length in pixels every image is resized to
            (see ``preprocess_image``).
        batch_size: Number of examples per batch.

    Returns:
        Tuple ``(train_ds, val_ds, ds_info, num_classes)``:
            train_ds: shuffled TFDS 'train' split; the last partial batch is
                dropped so every training batch has exactly ``batch_size`` examples.
            val_ds: TFDS 'test' split, unshuffled, including the final partial
                batch so evaluation covers every test example.
            ds_info: the TFDS ``DatasetInfo`` for CIFAR-10.
            num_classes: number of label classes, read from ``ds_info``.
    """
    (train_ds, val_ds), ds_info = tfds.load(
        'cifar10',
        split=['train', 'test'],
        as_supervised=True,
        with_info=True
    )

    num_classes = ds_info.features['label'].num_classes

    def _preprocess(image, label):
        # Bind image_size so the function matches map()'s (image, label) signature.
        return preprocess_image(image, label, image_size)

    # Shuffle the raw images *before* resizing: the 1000-element buffer then holds the
    # small original images instead of image_size x image_size float32 ones
    # (~600 MB at 224 px). Preprocessing runs in parallel; map preserves element
    # order by default.
    train_ds = (
        train_ds
        .shuffle(buffer_size=1000)
        .map(_preprocess, num_parallel_calls=AUTOTUNE)
        .batch(batch_size, drop_remainder=True)
        .prefetch(AUTOTUNE)
    )
    # Keep the final partial batch: dropping it would silently exclude test examples
    # from evaluation.
    val_ds = (
        val_ds
        .map(_preprocess, num_parallel_calls=AUTOTUNE)
        .batch(batch_size)
        .prefetch(AUTOTUNE)
    )

    return train_ds, val_ds, ds_info, num_classes

image_size = 224  # must match the spatial size of input_shape used to build the model
batch_size = 128
ds_train, ds_test, ds_info, num_classes = load_data(image_size, batch_size)

# The classifier head was built with a hard-coded class count; check it agrees with
# the dataset metadata before spending time on training.
assert model.outputs[0].shape[-1] == num_classes, (
    f"model head has {model.outputs[0].shape[-1]} units but the dataset has {num_classes} classes"
)

#-------------------------------------------------------------------------------------------------
# Prepare training arguments
# ------------------------------------------------------------------------------------------------
# NOTE(review): the TFDS 'test' split is used both for EarlyStopping model selection
# (restore_best_weights) and for the final reported "test accuracy", which makes that
# number optimistic. Consider holding out a validation split from 'train'.
early_stopping = keras.callbacks.EarlyStopping(monitor='val_loss', patience=10, restore_best_weights=True)

# Learning-rate schedule: linear warm-up from 0 to peak_learning_rate over the first
# warmup_epochs, then cosine decay from the peak down to final_learning_rate over the
# remaining epochs. In Keras' CosineDecay, `initial_learning_rate` is where the warm-up
# starts, `warmup_target` is the peak the warm-up reaches, and `alpha` is the final LR as
# a fraction of that peak. (Previously the peak was passed as the warm-up *start* and
# 1e-5 as the warm-up target, so the LR ramped down to 1e-5 and then decayed to 0, and
# the computed alpha was never used.)
warmup_epochs = 10
assert epochs > warmup_epochs, "need at least one epoch of cosine decay after warm-up"
peak_learning_rate = 1e-3
final_learning_rate = 1e-5
steps_per_epoch = len(ds_train)
warmup_steps = steps_per_epoch * warmup_epochs
decay_steps = steps_per_epoch * (epochs - warmup_epochs)
alpha = final_learning_rate / peak_learning_rate
lr_schedule = keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=0.0,
    decay_steps=decay_steps,
    alpha=alpha,
    warmup_target=peak_learning_rate,
    warmup_steps=warmup_steps,
)
# Use the same `keras` package that builds the model (keras.Model) rather than mixing in
# tf.keras objects, which are a different package on some TensorFlow versions.
optimizer = keras.optimizers.AdamW(learning_rate=lr_schedule, weight_decay=5e-5)

model.compile(
    optimizer=optimizer,
    # tfds with as_supervised=True yields integer class ids, hence the sparse loss
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train the model
history = model.fit(ds_train, epochs=epochs, validation_data=ds_test, callbacks = [early_stopping])

# Check training results and plot metrics

In [None]:
def plot_training_history(history):
    """Draw training vs. validation accuracy (left) and loss (right) per epoch.

    Args:
        history: The object returned by ``model.fit``; its ``history`` dict must hold
            the ``accuracy``, ``val_accuracy``, ``loss`` and ``val_loss`` series.
    """
    curves = history.history
    # (train key, validation key, panel title, y label, legend location), left to right
    panels = (
        ('accuracy', 'val_accuracy', 'Model accuracy', 'Accuracy', 'upper left'),
        ('loss', 'val_loss', 'Model loss', 'Loss', 'upper right'),
    )

    fig, axes = plt.subplots(1, 2, figsize=(12, 6))
    for ax, (train_key, val_key, title, ylabel, legend_loc) in zip(axes, panels):
        ax.plot(curves[train_key])
        ax.plot(curves[val_key])
        ax.set_title(title)
        ax.set_ylabel(ylabel)
        ax.set_xlabel('Epoch')
        ax.legend(['Train', 'Validation'], loc=legend_loc)

    fig.tight_layout()
    plt.show()

# Score the trained model on the held-out split and report its accuracy
loss, accuracy = model.evaluate(x=ds_test)
print(f"Test accuracy: {accuracy:.2f}")

# Show how accuracy and loss evolved over the training epochs
plot_training_history(history)

# Test on some random samples from the dataset

In [None]:
def random_predictions(model, dataset, ds_info, num_images=8):
    """Show a grid of dataset images titled with their true and predicted class names.

    Titles are green when the prediction matches the label and red otherwise.

    Args:
        model: Keras classifier whose ``predict`` output has one score per class on the
            last axis (softmax probabilities or logits; only the argmax is used).
        dataset: Batched ``tf.data.Dataset`` of ``(image, label)`` pairs whose length
            (number of batches) is known. Images are expected to be scaled to [0, 1].
        ds_info: TFDS ``DatasetInfo``; class names come from ``features['label'].names``.
        num_images: Maximum number of images to show (fewer if the sample runs past
            the end of the dataset). Laid out 4 per row.

    Raises:
        ValueError: if ``num_images`` < 1 or the dataset is empty.
    """
    if num_images < 1:
        raise ValueError("num_images must be at least 1")
    class_names = ds_info.features['label'].names

    num_batches = len(dataset)
    if num_batches == 0:
        raise ValueError("dataset is empty")
    # len(dataset) counts *batches*, so pick a random batch to start from (the old code
    # used it as an image offset, so only the first ~len(dataset) images were reachable).
    # skip() may still run the input pipeline for the skipped batches, so this can take
    # a moment on large datasets.
    start_batch = np.random.randint(num_batches)
    ds_subset = dataset.skip(start_batch).unbatch().take(num_images).batch(num_images)

    images, labels = next(iter(ds_subset))
    images = images.numpy()
    labels = labels.numpy()
    num_shown = len(images)

    # argmax is invariant under softmax, so no extra softmax is applied (the head built
    # by build_resnet18 already outputs softmax probabilities).
    predicted_classes = np.argmax(model.predict(images), axis=-1)

    num_cols = 4
    num_rows = -(-num_shown // num_cols)  # ceiling division
    plt.figure(figsize=(15, 5 * num_rows))
    for i in range(num_shown):
        plt.subplot(num_rows, num_cols, i + 1)
        # Clip before the uint8 cast so out-of-range values cannot wrap around.
        plt.imshow((np.clip(images[i], 0.0, 1.0) * 255).astype(np.uint8))
        true_class = int(labels[i])
        predicted_class = int(predicted_classes[i])
        color = 'green' if predicted_class == true_class else 'red'
        plt.title(f'True: {class_names[true_class]}\nPredicted: {class_names[predicted_class]}', color=color)
        plt.axis('off')
    plt.tight_layout()
    plt.show()

# Use held-out images: training images were seen during fit and overstate prediction quality.
random_predictions(model, ds_test, ds_info, num_images=8)