In [None]:
%pip install tensorflow matplotlib numpy

In [None]:
import tensorflow as tf
from tensorflow.keras import layers, models, optimizers
from tensorflow.keras.datasets import cifar10
import matplotlib.pyplot as plt
import numpy as np
import time

In [None]:
# Define the AlexNet architecture using Keras
def create_alexnet(num_classes=10, input_shape=(224, 224, 3)):
    """Build an AlexNet-style convolutional classifier as a Keras Sequential model.

    The network stacks five 3x3 'same'-padded convolutions (64, 192, 384, 256,
    256 filters) with three 2x2/stride-2 max-pooling stages, followed by two
    4096-unit dense layers (each preceded by 50% dropout) and a softmax output.

    Args:
        num_classes: Number of output classes (width of the softmax layer).
        input_shape: (height, width, channels) of the input images. Defaults
            to (224, 224, 3), which matches the original hard-coded shape.

    Returns:
        An uncompiled ``models.Sequential`` instance.

    Note:
        Spatial size is only reduced by the three pooling stages (a factor of
        8 overall). With a 224x224 input the flattened feature vector has
        28 * 28 * 256 = 200,704 entries, so the first Dense layer alone holds
        roughly 822 million weights. Passing a smaller ``input_shape`` such as
        (32, 32, 3) shrinks the flattened vector to 4 * 4 * 256 = 4,096.
    """
    model = models.Sequential([
        # Explicit input layer instead of the `input_shape=` layer kwarg.
        layers.Input(shape=input_shape),

        # Conv1
        layers.Conv2D(64, kernel_size=3, strides=1, padding='same', activation='relu'),
        layers.MaxPooling2D(pool_size=2, strides=2),

        # Conv2
        layers.Conv2D(192, kernel_size=3, strides=1, padding='same', activation='relu'),
        layers.MaxPooling2D(pool_size=2, strides=2),

        # Conv3
        layers.Conv2D(384, kernel_size=3, padding='same', activation='relu'),

        # Conv4
        layers.Conv2D(256, kernel_size=3, padding='same', activation='relu'),

        # Conv5
        layers.Conv2D(256, kernel_size=3, padding='same', activation='relu'),
        layers.MaxPooling2D(pool_size=2, strides=2),

        # Flatten the final feature map into a vector for the classifier head
        layers.Flatten(),

        # Fully connected classifier head; dropout is applied to the input of
        # each 4096-unit layer, not to the output layer.
        layers.Dropout(0.5),
        layers.Dense(4096, activation='relu'),
        layers.Dropout(0.5),
        layers.Dense(4096, activation='relu'),
        layers.Dense(num_classes, activation='softmax')
    ])

    return model

In [None]:
# Data loading and preprocessing
def load_cifar10(image_size=(224, 224), train_limit=None, test_limit=None):
    """Load CIFAR-10, resize the images, scale pixels to [-1, 1], one-hot the labels.

    Args:
        image_size: (height, width) the images are resized to. Defaults to
            (224, 224), matching the original hard-coded size.
        train_limit: If given, keep only the first ``train_limit`` training
            examples. ``None`` (default) keeps the whole split.
        test_limit: If given, keep only the first ``test_limit`` test
            examples. ``None`` (default) keeps the whole split.

    Returns:
        ``((x_train, y_train), (x_test, y_test))`` where the ``x`` values are
        resized image tensors scaled to [-1, 1] and the ``y`` values are
        one-hot label arrays with 10 columns.

    Note:
        The whole split is resized eagerly and held in memory. Assuming the
        standard 50,000-image CIFAR-10 training split and float32 output from
        ``tf.image.resize`` (confirm against the TF docs), the default
        224x224 size needs roughly 30 GB for the training tensor alone. Use
        ``train_limit`` / ``test_limit`` or a smaller ``image_size`` when that
        does not fit.
    """
    (x_train, y_train), (x_test, y_test) = cifar10.load_data()

    # Optional subsetting happens before resizing so the large resized tensor
    # is never materialised for examples that will be discarded.
    if train_limit is not None:
        x_train, y_train = x_train[:train_limit], y_train[:train_limit]
    if test_limit is not None:
        x_test, y_test = x_test[:test_limit], y_test[:test_limit]

    # Resize images to the requested spatial size
    target_size = list(image_size)
    x_train_resized = tf.image.resize(x_train, target_size)
    x_test_resized = tf.image.resize(x_test, target_size)

    # Map pixel values from [0, 255] to [-1, 1]
    x_train_normalized = (x_train_resized / 127.5) - 1
    x_test_normalized = (x_test_resized / 127.5) - 1

    # Convert integer labels to one-hot encoding (10 classes)
    y_train = tf.keras.utils.to_categorical(y_train, 10)
    y_test = tf.keras.utils.to_categorical(y_test, 10)

    return (x_train_normalized, y_train), (x_test_normalized, y_test)

In [None]:
# Training function
def train_model(model, train_data, epochs=10, batch_size=32, learning_rate=0.01):
    """Compile and fit ``model`` with SGD+momentum and a per-epoch cosine LR decay.

    Args:
        model: Keras model to train; it is compiled in place by this function.
        train_data: ``(x_train, y_train)`` pair. Labels are expected to be
            one-hot encoded because the loss is categorical cross-entropy.
        epochs: Number of training epochs; also the length of the cosine decay.
        batch_size: Mini-batch size passed to ``model.fit``.
        learning_rate: Initial learning rate for SGD and the starting value of
            the cosine schedule. Defaults to 0.01, the original hard-coded value.

    Returns:
        The ``History`` object returned by ``model.fit``. The last 10% of the
        training data is held out for validation via ``validation_split``.
    """
    x_train, y_train = train_data

    # Compile the model
    model.compile(
        optimizer=optimizers.SGD(learning_rate=learning_rate, momentum=0.9),
        loss='categorical_crossentropy',
        metrics=['accuracy']
    )

    # Guard against division by zero when epochs is 0.
    decay_epochs = max(int(epochs), 1)

    def cosine_schedule(epoch, current_lr=None):
        """Return the cosine-decayed learning rate for a 0-based epoch index.

        LearningRateScheduler calls this as ``schedule(epoch, lr)`` and
        requires a plain float back, so the schedule is computed here rather
        than with an optimizer schedule object (there is no
        ``tf.keras.callbacks.CosineDecay``).
        """
        progress = min(epoch, decay_epochs) / decay_epochs
        return float(0.5 * learning_rate * (1.0 + np.cos(np.pi * progress)))

    # Train the model
    history = model.fit(
        x_train, y_train,
        batch_size=batch_size,
        epochs=epochs,
        validation_split=0.1,
        callbacks=[
            tf.keras.callbacks.LearningRateScheduler(cosine_schedule)
        ]
    )

    return history

In [None]:
# Evaluation function
def evaluate_model(model, test_data):
    """Score ``model`` on held-out data and report accuracy as a percentage.

    Args:
        model: Object exposing ``evaluate(x, y)`` that returns exactly two
            values, unpacked here as (loss, accuracy).
        test_data: ``(x_test, y_test)`` pair forwarded to ``model.evaluate``.

    Returns:
        The accuracy value multiplied by 100. The loss is discarded.
    """
    features, labels = test_data
    # Two-value unpacking is kept so an unexpected metric count still raises.
    _, accuracy_fraction = model.evaluate(features, labels)
    accuracy_percent = accuracy_fraction * 100
    return accuracy_percent

In [None]:
# Set parameters
epochs = 5
batch_size = 32
seed = 42

# Seed NumPy and TensorFlow so weight initialisation and data shuffling are
# repeatable across Restart & Run All (hardware-level nondeterminism may
# still cause small differences between runs).
np.random.seed(seed)
tf.random.set_seed(seed)

print("Loading CIFAR-10 dataset...")
train_data, test_data = load_cifar10()

print("\nCreating AlexNet model...")
model = create_alexnet(num_classes=10)
model.summary()

print("\nTraining custom AlexNet...")
history = train_model(model, train_data, epochs=epochs, batch_size=batch_size)

print("\nEvaluating model...")
# evaluate_model already returns a percentage, so it is printed as-is.
test_accuracy = evaluate_model(model, test_data)
print(f"Test Accuracy: {test_accuracy:.2f}%")

In [None]:
# Plot training curves
# One x value per recorded epoch, numbered from 1 so the axis reads as epochs.
epoch_numbers = range(1, len(history.history['loss']) + 1)

# The history stores accuracy as a fraction in [0, 1]; scale it so the values
# agree with the "Accuracy (%)" axis label.
train_accuracy_pct = [100 * value for value in history.history['accuracy']]
val_accuracy_pct = [100 * value for value in history.history['val_accuracy']]

plt.figure(figsize=(12, 5))

# Left panel: loss
plt.subplot(1, 2, 1)
plt.plot(epoch_numbers, history.history['loss'], 'r-', label='Training Loss')
plt.plot(epoch_numbers, history.history['val_loss'], 'b-', label='Validation Loss')
plt.title('Training and Validation Loss')
plt.xlabel('Epochs')
plt.ylabel('Loss')
plt.legend()

# Right panel: accuracy in percent
plt.subplot(1, 2, 2)
plt.plot(epoch_numbers, train_accuracy_pct, 'r-', label='Training Accuracy')
plt.plot(epoch_numbers, val_accuracy_pct, 'b-', label='Validation Accuracy')
plt.title('Training and Validation Accuracy')
plt.xlabel('Epochs')
plt.ylabel('Accuracy (%)')
plt.legend()

plt.tight_layout()
plt.show()