# In [None]:
import tensorflow as tf
import numpy as np
from alexnet_model import AlexNet

# Prepare the training/testing data, then build, train, and save an AlexNet classifier
def train_alexnet(
    train_dir='/Users/maximilianstumpf/Desktop/UCLA/Math156 - Machine Learning/Project/Data/Train',
    test_dir='/Users/maximilianstumpf/Desktop/UCLA/Math156 - Machine Learning/Project/Data/Test',
    num_classes=None,
    epochs=20,
    batch_size=128,
    learning_rate=0.01,
):
    """Train AlexNet on an image-folder dataset and save the trained model.

    Images are read with ``flow_from_directory``, so ``train_dir`` and
    ``test_dir`` must each contain one sub-directory per class.

    Args:
        train_dir: Root directory of the training images.
        test_dir: Root directory of the images passed to ``fit`` as
            ``validation_data``.
        num_classes: Number of output classes. ``None`` (default) infers it
            from the class sub-directories found under ``train_dir``.
        epochs: Number of training epochs.
        batch_size: Images per batch for both data generators.
        learning_rate: Initial SGD learning rate. It is divided by 10 from
            epoch 10 and by 100 from epoch 15 onwards.

    Returns:
        ``(model, history)``: the trained model and the ``History`` object
        returned by ``model.fit``.

    Raises:
        ValueError: If ``train_dir`` and ``test_dir`` do not define the same
            classes, or an explicit ``num_classes`` disagrees with the data.
    """
    # Both generators only rescale pixel values to [0, 1] (no augmentation).
    train_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)
    test_datagen = tf.keras.preprocessing.image.ImageDataGenerator(rescale=1. / 255)

    flow_options = dict(
        target_size=(227, 227),    # match AlexNet's input resolution
        batch_size=batch_size,     # number of images per batch
        class_mode='categorical',  # one-hot labels for categorical_crossentropy
    )
    train_generator = train_datagen.flow_from_directory(train_dir, **flow_options)
    test_generator = test_datagen.flow_from_directory(test_dir, **flow_options)

    # Each generator derives its label indices from its own directory's
    # sub-folder names. If the two roots differ, the same index would denote
    # different classes in training and evaluation.
    if train_generator.class_indices != test_generator.class_indices:
        raise ValueError(
            'Train and test directories define different classes: '
            f'{train_generator.class_indices} vs {test_generator.class_indices}'
        )

    inferred_classes = train_generator.num_classes
    if num_classes is None:
        num_classes = inferred_classes
    elif num_classes != inferred_classes:
        raise ValueError(
            f'num_classes={num_classes} but {train_dir!r} contains '
            f'{inferred_classes} class directories'
        )

    model = AlexNet(num_classes=num_classes)

    # SGD with Nesterov momentum; momentum helps accelerate convergence.
    optimizer = tf.keras.optimizers.SGD(
        learning_rate=learning_rate,
        momentum=0.9,
        nesterov=True,
    )
    model.compile(
        optimizer=optimizer,
        loss='categorical_crossentropy',
        metrics=['accuracy'],
    )

    def step_decay(epoch):
        # Manual step decay by factors of 10, similar to the original AlexNet.
        if epoch < 10:
            return learning_rate
        if epoch < 15:
            return learning_rate / 10
        return learning_rate / 100

    # NOTE(review): the test set doubles as validation data, and the checkpoint
    # below keeps the epoch with the best test accuracy. That leaks the test
    # set into model selection, so consider holding out a separate validation
    # split.
    history = model.fit(
        train_generator,
        epochs=epochs,
        validation_data=test_generator,
        callbacks=[
            tf.keras.callbacks.LearningRateScheduler(step_decay),
            tf.keras.callbacks.ModelCheckpoint(
                'alexnet_checkpoint.keras',  # saved whenever val accuracy improves
                save_best_only=True,
                monitor='val_accuracy',
            ),
        ],
    )

    # Save the final model. `save_weights` cannot target a `.keras` path:
    # Keras 3 rejects names not ending in `.weights.h5`, and older tf.keras
    # writes a TF checkpoint with a misleading `.keras` prefix. A full-model
    # `.keras` archive keeps the same file name, includes the weights, and
    # uses the same format as the checkpoint above.
    model.save('alexnet_final.keras')
    return model, history

if __name__ == '__main__':
    # Runs when executed as a script, and also inside a notebook cell (where
    # __name__ is '__main__'). `model` and `history` stay bound at module
    # level so later cells can inspect them.
    (model, history) = train_alexnet()