In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.applications import imagenet_utils
import numpy as np
import matplotlib.pyplot as plt
import time

In [None]:
# Define the AlexNet architecture
def create_alexnet(num_classes=10, input_shape=(227, 227, 3)):
    """Build an un-compiled AlexNet classifier that outputs raw logits.

    Filter counts are 64-192-384-256-256 for the five conv layers, followed by
    two 4096-unit fully connected layers. The stem uses AlexNet's large,
    strided convolutions and overlapping 3x3/stride-2 pooling, so a 227x227
    input is reduced to a 6x6x256 feature map (9,216 features) before the
    classifier head.

    Args:
        num_classes: Number of units in the final Dense (logit) layer.
        input_shape: (height, width, channels) of the input images. With this
            stem each spatial side must be at least 131 pixels; smaller inputs
            shrink below a pooling window and fail at build time.

    Returns:
        A ``tf.keras.Sequential`` model. The last layer has no activation, so
        it must be trained with a ``from_logits=True`` loss.
    """
    # NOTE: a 3x3 / stride-1 stem with 2x2 pools would leave a 28x28x256 map
    # (200,704 features) at Flatten for a 227x227 input, i.e. ~822M weights in
    # the first Dense layer alone. The strided stem below keeps it at 9,216.
    model = models.Sequential([
        # Conv1: 11x11, stride 4 -> 55x55 for a 227x227 input
        layers.Conv2D(64, kernel_size=11, strides=4, padding='valid',
                      activation='relu', input_shape=input_shape),
        layers.MaxPooling2D(pool_size=3, strides=2),  # -> 27x27

        # Conv2
        layers.Conv2D(192, kernel_size=5, padding='same', activation='relu'),
        layers.MaxPooling2D(pool_size=3, strides=2),  # -> 13x13

        # Conv3
        layers.Conv2D(384, kernel_size=3, padding='same', activation='relu'),

        # Conv4
        layers.Conv2D(256, kernel_size=3, padding='same', activation='relu'),

        # Conv5
        layers.Conv2D(256, kernel_size=3, padding='same', activation='relu'),
        layers.MaxPooling2D(pool_size=3, strides=2),  # -> 6x6

        # Fully connected layers
        layers.Flatten(),
        layers.Dropout(0.5),
        layers.Dense(4096, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(4096, activation='relu'),
        layers.Dense(num_classes),  # logits; no softmax here
    ])

    return model

In [None]:
# Data loading and preprocessing
def load_cifar10(image_size=(227, 227)):
    """Load CIFAR-10, upsample it and apply ImageNet mean/std normalization.

    Args:
        image_size: (height, width) that the 32x32 CIFAR images are resized
            to. Defaults to 227x227, the input size ``create_alexnet`` expects.

    Returns:
        ``((x_train, y_train), (x_test, y_test))``. The images are float32
        tensors of shape ``(N, height, width, 3)``, scaled to [0, 1] and then
        standardized per RGB channel with the ImageNet mean/std. The labels are
        one-hot arrays of shape ``(N, 10)``.
    """
    (x_train, y_train), (x_test, y_test) = tf.keras.datasets.cifar10.load_data()

    # Convert to float32 but keep pixels in [0, 255]. ``preprocess_input``
    # expects raw 0-255 input in every mode. Dividing by 255 first and then
    # using the default 'caffe' mode would subtract ImageNet means of ~104-124
    # from values in [0, 1], producing a near-constant input in BGR order.
    x_train = x_train.astype('float32')
    x_test = x_test.astype('float32')

    # Resize images (227x227 by default, AlexNet's original input size).
    # NOTE: at 227x227 the float32 training set alone is ~31 GB
    # (50000 * 227 * 227 * 3 * 4 bytes). Pass a smaller ``image_size`` or move
    # to a tf.data pipeline if this does not fit in memory.
    x_train_resized = tf.image.resize(x_train, image_size)
    x_test_resized = tf.image.resize(x_test, image_size)

    # 'torch' mode scales to [0, 1], then standardizes each RGB channel with
    # the ImageNet mean/std.
    x_train_normalized = imagenet_utils.preprocess_input(x_train_resized, mode='torch')
    x_test_normalized = imagenet_utils.preprocess_input(x_test_resized, mode='torch')

    # Convert integer labels to one-hot vectors for CategoricalCrossentropy
    y_train = tf.keras.utils.to_categorical(y_train, 10)
    y_test = tf.keras.utils.to_categorical(y_test, 10)

    return (x_train_normalized, y_train), (x_test_normalized, y_test)

In [None]:
# Custom callback for timing epochs
class TimeHistory(tf.keras.callbacks.Callback):
    """Keras callback that records how long each training epoch takes.

    After ``fit`` returns, ``times`` holds one duration in seconds per
    completed epoch, in order.
    """

    def on_train_begin(self, logs=None):
        # Start a fresh list for every training run.
        self.times = []

    def on_epoch_begin(self, epoch, logs=None):
        self.epoch_time_start = time.perf_counter()

    def on_epoch_end(self, epoch, logs=None):
        # perf_counter is monotonic and high-resolution, so durations are not
        # distorted by system clock adjustments the way time.time() can be.
        self.times.append(time.perf_counter() - self.epoch_time_start)

In [None]:
# Set up GPU if available (query the device list once and reuse it)
gpus = tf.config.list_physical_devices('GPU')
print("Num GPUs Available: ", len(gpus))
device = "/GPU:0" if gpus else "/CPU:0"
print(f"Using device: {device}")

# Seed NumPy and TensorFlow so weight init, dropout and shuffling are more
# repeatable across Restart & Run All (GPU kernels may still be
# nondeterministic).
SEED = 42
np.random.seed(SEED)
tf.random.set_seed(SEED)

# Load and preprocess data
(x_train, y_train), (x_test, y_test) = load_cifar10()

In [None]:
# Create and compile model
BATCH_SIZE = 64
epochs = 5
initial_learning_rate = 0.01

with tf.device(device):
    model = create_alexnet()

    # Cosine-decay the SGD learning rate over exactly the number of optimizer
    # steps fit() will run. The final partial batch of each epoch is still a
    # step, so round steps-per-epoch up instead of flooring the total.
    steps_per_epoch = -(-len(x_train) // BATCH_SIZE)  # ceil division
    decay_steps = epochs * steps_per_epoch

    lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
        initial_learning_rate, decay_steps)

    optimizer = optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)

    # The model emits raw logits, hence from_logits=True.
    model.compile(optimizer=optimizer,
                  loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
                  metrics=['accuracy'])

    # Train model.
    # NOTE(review): the test split is used as validation data here and is
    # evaluated again afterwards. If these curves drive any tuning decisions,
    # the reported test accuracy is no longer from a truly held-out set.
    time_callback = TimeHistory()
    history = model.fit(x_train, y_train,
                        batch_size=BATCH_SIZE,
                        epochs=epochs,
                        validation_data=(x_test, y_test),
                        callbacks=[time_callback])

In [None]:
# Score the trained network on the CIFAR-10 test split; evaluate() returns
# [loss, accuracy] in the order given to compile().
test_loss, test_accuracy = model.evaluate(x=x_test, y=y_test)
print(f"\nTest accuracy: {test_accuracy*100:.2f}%")

In [None]:
# Plot training curves.
# Take the x-axis from the recorded history rather than the `epochs` setting,
# so the plot stays correct if fit() ran a different number of epochs.
epoch_range = range(1, len(history.history['loss']) + 1)

fig, axes = plt.subplots(1, 2, figsize=(12, 5))
panels = [
    ('loss', 'Training and Validation Loss', 'Loss'),
    ('accuracy', 'Training and Validation Accuracy', 'Accuracy'),
]
for ax, (metric, title, ylabel) in zip(axes, panels):
    ax.plot(epoch_range, history.history[metric], 'r-', label='Training')
    ax.plot(epoch_range, history.history[f'val_{metric}'], 'b-', label='Validation')
    ax.set(title=title, xlabel='Epochs', ylabel=ylabel)
    ax.legend()

fig.tight_layout()
plt.show()

# Print average epoch time
avg_epoch_time = np.mean(time_callback.times)
print(f"\nAverage epoch time: {avg_epoch_time:.2f} seconds")