In [2]:
import tensorflow as tf
import numpy as np
from alexnet_model import AlexNet
import matplotlib.pyplot as plt

# Train AlexNet on the Train1/Val/Test image folders, save weights, and report test metrics.

# Root of the original experiment's image folders. Pass explicit directories to
# ``train_alexnet`` instead of editing this, so the notebook runs on other machines.
_DEFAULT_DATA_ROOT = '/Users/maximilianstumpf/Desktop/UCLA/Math156 - Machine Learning/Project/Data'


def _image_batches(directory, batch_size, shuffle, seed=None):
    """Return a directory iterator of (pixels rescaled to [0, 1], one-hot label) batches.

    Class labels are inferred by Keras from the sub-folder names of ``directory``.
    """
    rescaler = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1./255)
    return rescaler.flow_from_directory(
        directory,
        batch_size=batch_size,
        class_mode='categorical',  # one-hot labels, as categorical_crossentropy expects
        shuffle=shuffle,
        seed=seed,
    )


def _check_same_classes(reference, **splits):
    """Raise ValueError if any split's class->index mapping differs from ``reference``'s."""
    for split_name, generator in splits.items():
        # A missing/extra class folder shifts the sorted index assignment, which
        # would silently mislabel every sample of that split.
        if generator.class_indices != reference.class_indices:
            raise ValueError(
                f"Class folders in the {split_name} split differ from the training split; "
                "labels would be misaligned."
            )


def _build_compiled_alexnet(num_classes, initial_learning_rate, decay_steps, decay_rate):
    """Create AlexNet and compile it with Nesterov-momentum SGD on a staircase exponential LR decay."""
    lr_schedule = tf.keras.optimizers.schedules.ExponentialDecay(
        initial_learning_rate=initial_learning_rate,
        decay_steps=decay_steps,
        decay_rate=decay_rate,
        staircase=True,  # multiply the LR by decay_rate once every decay_steps batches
    )
    optimizer = tf.keras.optimizers.SGD(
        learning_rate=lr_schedule,
        momentum=0.9,
        nesterov=True,
    )
    model = AlexNet(num_classes=num_classes)
    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )
    return model


def train_alexnet(
    train_dir=_DEFAULT_DATA_ROOT + '/Train1',
    val_dir=_DEFAULT_DATA_ROOT + '/Val',
    test_dir=_DEFAULT_DATA_ROOT + '/Test',
    batch_size=128,
    epochs=20,
    initial_learning_rate=0.01,
    lr_decay_rate=0.96,
    lr_decay_steps=None,
    seed=None,
):
    """Train AlexNet on image folders, save its weights, and print test-set metrics.

    Args:
        train_dir, val_dir, test_dir: Directories containing one sub-folder per
            class. All three must contain the same class sub-folders.
        batch_size: Images per batch, used for all three splits.
        epochs: Maximum number of training epochs (early stopping may end sooner).
        initial_learning_rate: Starting SGD learning rate.
        lr_decay_rate: Factor the learning rate is multiplied by at each decay.
        lr_decay_steps: Optimizer steps (batches) between decays. ``None`` means
            once per epoch. The previous hard-coded 10000 exceeded the total
            number of steps in training (~237 batches/epoch x 20 epochs), so the
            learning rate never actually decayed.
        seed: Optional seed for the TF/NumPy/Python RNGs and training-set shuffling.

    Returns:
        (model, history): the trained model and the ``History`` returned by ``fit``.

    Raises:
        ValueError: if the class sub-folders differ between the splits.

    Side effects:
        Writes 'alexnet_checkpoint.keras' and 'alexnet_final.weights.h5' to the
        current working directory and prints the test loss/accuracy.
    """
    if seed is not None:
        tf.keras.utils.set_random_seed(seed)

    train_generator = _image_batches(train_dir, batch_size, shuffle=True, seed=seed)
    # Evaluation splits gain nothing from shuffling; a fixed order keeps them reproducible.
    validation_generator = _image_batches(val_dir, batch_size, shuffle=False)
    test_generator = _image_batches(test_dir, batch_size, shuffle=False)
    _check_same_classes(train_generator, validation=validation_generator, test=test_generator)

    steps_per_epoch = len(train_generator)
    model = _build_compiled_alexnet(
        num_classes=train_generator.num_classes,  # derived from the data, not hard-coded
        initial_learning_rate=initial_learning_rate,
        decay_steps=steps_per_epoch if lr_decay_steps is None else lr_decay_steps,
        decay_rate=lr_decay_rate,
    )

    history = model.fit(
        train_generator,
        epochs=epochs,
        validation_data=validation_generator,
        callbacks=[
            # NOTE: the checkpoint tracks val_accuracy while early stopping tracks
            # val_loss, so the saved checkpoint and any weights restored by
            # EarlyStopping may come from different epochs.
            tf.keras.callbacks.ModelCheckpoint(
                'alexnet_checkpoint.keras',
                save_best_only=True,
                monitor='val_accuracy'
            ),
            tf.keras.callbacks.EarlyStopping(monitor="val_loss", patience=5, restore_best_weights=True),
        ],
    )

    # Saves whatever weights the model holds after fit (possibly restored by EarlyStopping).
    model.save_weights('alexnet_final.weights.h5')

    test_loss, test_accuracy = model.evaluate(test_generator)
    print(f"Test Loss: {test_loss:.4f}")
    print(f"Test Accuracy: {test_accuracy:.4f}")

    return model, history

# Entry point: train when this code is executed directly (a script, or a notebook
# kernel where __name__ is "__main__"), but not when it is imported as a module.
_is_entry_point = __name__ == "__main__"
if _is_entry_point:
    model, history = train_alexnet()

Found 30331 images belonging to 50 classes.
Found 7608 images belonging to 50 classes.
Found 9485 images belonging to 50 classes.
Epoch 1/20
[1m237/237[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m432s[0m 2s/step - accuracy: 0.0218 - loss: 5.4023 - val_accuracy: 0.0355 - val_loss: 5.3050
Epoch 2/20
[1m237/237[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m439s[0m 2s/step - accuracy: 0.0427 - loss: 5.2515 - val_accuracy: 0.0392 - val_loss: 5.2196
Epoch 3/20
[1m237/237[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m407s[0m 2s/step - accuracy: 0.0505 - loss: 5.1217 - val_accuracy: 0.0577 - val_loss: 5.0455
Epoch 4/20
[1m237/237[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m434s[0m 2s/step - accuracy: 0.0662 - loss: 4.9798 - val_accuracy: 0.0814 - val_loss: 4.8915
Epoch 5/20
[1m237/237[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m449s[0m 2s/step - accuracy: 0.0924 - loss: 4.8164 - val_accuracy: 0.1092 - val_loss: 4.7281
Epoch 6/20
[1m237/237[0m [32m━━━━━━━━━━━━━━━━━━━━[0

KeyboardInterrupt: 