In [None]:
from tensorflow.keras.applications import MobileNetV2  # AlexNet is not directly available, can adapt similar networks
from tensorflow.keras import layers, models
from tensorflow.keras.utils import to_categorical
import tensorflow as tf

# Preprocess the dataset
def _resize_for_alexnet(images, size=(224, 224)):
    """Resize a batch of images to the spatial size the network expects.

    ``tf.image.resize`` treats a rank-3 tensor as a single image laid out as
    (height, width, channels). A batch of grayscale images stored without a
    channel axis, i.e. (batch, height, width), would therefore be resized as
    one image instead of per sample. Such input gets a trailing channel axis
    first so that the result is always (batch, size[0], size[1], channels).

    Args:
        images: Batch of images, either (batch, height, width) or
            (batch, height, width, channels).
        size: Target (height, width).

    Returns:
        A rank-4 tensor of resized images (float32 with the default
        bilinear method).
    """
    images = tf.convert_to_tensor(images)
    if images.shape.rank == 3:
        # Grayscale batch without channel axis -> (batch, height, width, 1).
        images = tf.expand_dims(images, axis=-1)
    return tf.image.resize(images, size)


# Resize MNIST to 224x224.
# NOTE(review): each split is materialized eagerly as a dense float32 tensor
# at 224x224; for a dataset the size of full MNIST that is several GB of
# memory. If that becomes a problem, resize per batch inside a tf.data
# pipeline instead.
data_train = _resize_for_alexnet(data_train)
data_val = _resize_for_alexnet(data_val)
data_test = _resize_for_alexnet(data_test)

# Convert integer class labels to one-hot vectors over the 10 classes, as
# required by the categorical cross-entropy loss.
labels_train_cat = to_categorical(labels_train, 10)
labels_val_cat = to_categorical(labels_val, 10)
labels_test_cat = to_categorical(labels_test, 10)

# Build an AlexNet-like architecture for single-channel 224x224 input.
# The input shape is declared with an explicit Input layer instead of the
# `input_shape=` keyword on the first Conv2D, which newer Keras releases
# warn about for Sequential models.
#
# NOTE(review): every convolution uses the default 'valid' padding, so the
# feature map shrinks 224 -> 54 -> 26 -> 22 -> 10 -> 8 -> 6 -> 4 -> 1 and
# Flatten only sees 1x1x256 features. Canonical AlexNet pads conv2-conv5;
# left unchanged here so results stay comparable with earlier runs.
alexnet_model = models.Sequential([
    layers.Input(shape=(224, 224, 1)),
    # Convolutional feature extractor
    layers.Conv2D(96, kernel_size=11, strides=4, activation='relu'),
    layers.MaxPooling2D(pool_size=3, strides=2),
    layers.Conv2D(256, kernel_size=5, activation='relu'),
    layers.MaxPooling2D(pool_size=3, strides=2),
    layers.Conv2D(384, kernel_size=3, activation='relu'),
    layers.Conv2D(384, kernel_size=3, activation='relu'),
    layers.Conv2D(256, kernel_size=3, activation='relu'),
    layers.MaxPooling2D(pool_size=3, strides=2),
    # Fully connected classifier head with dropout regularization
    layers.Flatten(),
    layers.Dense(4096, activation='relu'),
    layers.Dropout(0.5),
    layers.Dense(4096, activation='relu'),
    layers.Dropout(0.5),
    # One softmax probability per class (10 classes)
    layers.Dense(10, activation='softmax'),
])

# Categorical cross-entropy matches the one-hot encoded labels.
alexnet_model.compile(
    optimizer='adam',
    loss='categorical_crossentropy',
    metrics=['accuracy'],
)

# Fit AlexNet for 20 epochs in mini-batches of 128 samples, tracking the
# validation split alongside the training metrics.
history_alexnet = alexnet_model.fit(
    x=data_train,
    y=labels_train_cat,
    batch_size=128,
    epochs=20,
    validation_data=(data_val, labels_val_cat),
)

# Score the trained network on the test split and report its accuracy.
test_loss_alexnet, test_acc_alexnet = alexnet_model.evaluate(
    x=data_test,
    y=labels_test_cat,
    verbose=2,
)
print(f"Test Accuracy (AlexNet): {test_acc_alexnet:.4f}")

# Plot AlexNet results: per-epoch training and validation loss on one figure.
plt.figure(figsize=(10, 6))
for history_key, curve_label in (
    ('loss', 'Training Loss'),
    ('val_loss', 'Validation Loss'),
):
    # Both curves share the same dashed-line, round-marker styling.
    plt.plot(
        history_alexnet.history[history_key],
        label=curve_label,
        linestyle='--',
        marker='o',
    )
plt.title('Training and Validation Loss for AlexNet')
plt.xlabel('Epoch')
plt.ylabel('Loss')
plt.legend()
plt.grid(True)
plt.show()
