In [None]:
# auto-reload all helper files
# Loads IPython's autoreload extension; mode 2 re-imports modules before each
# cell executes, so edits to the imported helper files take effect without
# restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Imports
import config
import numpy as np
from tensorflow.keras.models import load_model
from data import get_cifar10_data
from model import build_model
from train import compile_model, train_model, evaluate_model
from metrics import confusion_matrix
import matplotlib.pyplot as plt
import seaborn as sns

In [None]:
# Load data: get_cifar10_data() hands back four arrays, unpacked here as the
# train inputs/labels and test inputs/labels used by every later cell.
# NOTE(review): any preprocessing (scaling, label encoding) happens inside
# get_cifar10_data and is not visible from this notebook - confirm in data.py.
x_train, y_train, x_test, y_test = get_cifar10_data()

In [None]:
# Build and compile in a single step.
# Run ONCE per model lifecycle: re-executing this cell rebinds `model` to a
# freshly built network, so any weights trained so far are lost.
model = compile_model(build_model())

In [None]:
# Fit the model on the training split.
# Batch size and epoch count are read from config so a run is tuned in one place.
print(">>> TRAINING STARTING <<<")

history = train_model(
    model, x_train, y_train,
    batch_size=config.BATCH_SIZE, epochs=config.EPOCHS,
)

In [None]:
# Score the current model on the test split and report both metrics.
test_loss, test_accuracy = evaluate_model(model, x_test, y_test)

# One print call, one metric per line (accuracy first, then loss).
print(
    f"Test accuracy: {test_accuracy:.2%}",
    f"Test Loss: {test_loss:.4f}",
    sep="\n",
)

In [None]:
# Model Evaluation Metrics

# Per-class scores predicted for every sample in the test split
y_pred_prob = model.predict(x_test)

# Selecting predicted class with highest score per sample
y_pred = np.argmax(y_pred_prob, axis=1)
# NOTE(review): labels are passed through unchanged - confirm y_test is already
# in the integer-label form that metrics.confusion_matrix expects.
y_true = y_test

# Confusion matrix of raw counts (true vs predicted labels)
cm = confusion_matrix(y_true, y_pred, config.NUM_CLASSES)

# Row-normalise so each cell is the share of a true class assigned to each
# predicted class. Rows whose total is zero (a class with no true samples) are
# left at 0.0 instead of dividing by zero and filling the plot with NaN.
row_totals = cm.sum(axis=1, keepdims=True)
cm_norm = np.divide(
    cm,
    row_totals,
    out=np.zeros(cm.shape, dtype=float),
    where=row_totals != 0,
)

# Heatmap annotated with per-cell percentages
plt.figure(figsize=(10, 8))
sns.heatmap(cm_norm, annot=True, fmt='.2%', cmap='Blues',
            xticklabels=config.CLASS_NAMES, yticklabels=config.CLASS_NAMES)
plt.xlabel('Predicted Labels')
plt.ylabel('True Labels')
plt.title('Confusion Matrix')
plt.xticks(rotation=45, ha='right')  # slant class names so they do not overlap
plt.yticks(rotation=0)
plt.tight_layout()
plt.show()


In [None]:
# Training history visualization (loss and accuracy) plus a text summary

# Per-epoch loss values recorded in the training history.
# NOTE(review): the 'val_*' keys are read unconditionally - confirm that
# train_model always trains with validation data, otherwise this raises KeyError.
val_loss = history.history['val_loss']
loss = history.history['loss']

# Dynamically generate epoch axis to support variable training length (e.g. EarlyStopping)
epoch_axis = range(1, len(loss) + 1)

plt.figure()
plt.plot(epoch_axis, loss, label="Training loss")
plt.plot(epoch_axis, val_loss, label="Validation loss")
plt.xlabel("Epochs")
plt.ylabel("Loss")
plt.legend()
plt.title("Training vs Validation Loss")
plt.show()

# Extract accuracy metrics from training history
accuracy = history.history['accuracy']
val_accuracy = history.history['val_accuracy']

plt.figure()
plt.plot(epoch_axis, accuracy, label="Training accuracy")
plt.plot(epoch_axis, val_accuracy, label="Validation accuracy")
plt.xlabel("Epochs")
plt.ylabel("Accuracy")
plt.legend()
plt.title("Training vs Validation Accuracy")
plt.show()


# Snapshot final epoch performance.
# Losses use '.4f' (four decimal places), matching the best-epoch report below;
# the previous '4f' spec only set a minimum width and printed six decimals.
print(f'Training Loss: {loss[-1]:.4f}')
print(f'Validation Loss: {val_loss[-1]:.4f}')
print(f'Training Accuracy: {accuracy[-1]:.2%}')
print(f'Validation Accuracy: {val_accuracy[-1]:.2%}')
print('\n')

# Index of the epoch where val_loss was at its minimum (0-based, hence the +1
# when it is reported)
best_epoch = np.argmin(val_loss)

# Diagnostic metrics: best epoch based on minimum validation loss
print(f"Best Epoch: {best_epoch + 1}")
print(f"Best Training Loss: {loss[best_epoch]:.4f}")
print(f"Best Validation Loss: {val_loss[best_epoch]:.4f}")
print(f"Best Training Accuracy: {accuracy[best_epoch]:.2%}")
print(f"Best Validation Accuracy: {val_accuracy[best_epoch]:.2%}")


In [None]:
# ============================
# SAVE MODEL SNAPSHOT (MANUAL)
# ============================
# Writes the current `model` to models/<RUN_NAME>_<timestamp>.keras, but only
# when MODEL_SAVE_FLAG is switched on, so running the whole notebook never
# saves a snapshot by accident.

import os
from datetime import datetime

# Opt-in switch: set to True and re-run this cell to save
MODEL_SAVE_FLAG = False

# Give this run a clear name (you control this)
RUN_NAME = "model-v5_deeper_progressive_conv_layers"

# Safety check: an empty name would produce an unidentifiable file
assert RUN_NAME, "RUN_NAME must be defined before saving the model"

# Ensure models directory exists
os.makedirs("models", exist_ok=True)

# Timestamp suffix (to the second) to avoid accidental overwrites
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

MODEL_PATH = f"models/{RUN_NAME}_{timestamp}.keras"

# Save exactly once (the snapshot was previously written twice per run)
if MODEL_SAVE_FLAG:
    model.save(MODEL_PATH)
    print(f"Model snapshot saved to: {MODEL_PATH}")
else:
    print("MODEL_SAVE_FLAG=False — model not saved")