In [None]:
# auto-reload all helper files
# `%autoreload 2` re-imports project modules (config, dataset, model, train,
# metrics) before each cell executes, so edits to those .py files take effect
# without restarting the kernel.
%load_ext autoreload
%autoreload 2

In [None]:
# Imports
import numpy as np
import pandas as pd
import tensorflow as tf
import matplotlib.pyplot as plt
import seaborn as sns
from sklearn.metrics import classification_report

from keras.callbacks import EarlyStopping

import config

from dataset import get_cifar10_datasets
from model import build_model
from train import train_model, compile_model
from metrics import (
    evaluate_model,
    merge_histories,
    predict_classes,
    confusion_matrix
)

# Seed RNGs for reproducible training runs; left commented out for evaluation-only sessions (uncomment before training)
# tf.random.set_seed(config.RANDOM_SEED)
# np.random.seed(config.RANDOM_SEED)

In [None]:
# Load the CIFAR-10 splits (train / validation / test), batched with the configured batch size.
train_ds, val_ds, test_ds = get_cifar10_datasets(batch_size=config.BATCH_SIZE)

In [None]:
# Build the model for the configured strategy.
# Run ONCE per model lifecycle: re-running this cell replaces `model` and
# resets its weights. Compilation is not done here; train_model receives
# compile_model through its compile_fn argument.
model = build_model(
    num_classes=config.NUM_CLASSES,
    input_shape=config.INPUT_SHAPE,
    strategy=config.MODEL_STRATEGY,
)

In [None]:
# Sanity check: print the model architecture (Keras `Model.summary`) to
# confirm the configured strategy / input shape / class count before training.
model.summary()

In [None]:
# Training phases consumed by train_model, run in list order.
#   - "transfer": two phases on the same model. Feature extraction runs with the
#     backbone frozen; fine-tuning then unfreezes the backbone at a 10x lower
#     learning rate.
#   - any other strategy (scratch): a single phase driven by config.EPOCHS and
#     config.LEARNING_RATE, with early stopping on validation loss.
# NOTE(review): how "backbone_trainable" and "base_callbacks" are applied is
# decided by train_model. Confirm there.

# Early stopping for the transfer phases is opt-in (off by default, matching
# the runs made so far); the scratch phase always uses it.
TRANSFER_EARLY_STOPPING = False
TRANSFER_ES_PATIENCE = 3
SCRATCH_ES_PATIENCE = 5


def _early_stopping(patience):
    """Return a new EarlyStopping callback on val_loss that restores the best weights."""
    return EarlyStopping(
        monitor="val_loss",
        patience=patience,
        restore_best_weights=True,
    )


if config.MODEL_STRATEGY == "transfer":
    PHASES = [
        {
            "name": "feature_extraction",
            "epochs": 3,
            "learning_rate": 1e-4,
            "backbone_trainable": False,
        },
        {
            "name": "fine_tuning",
            "epochs": 3,
            "learning_rate": 1e-5,
            "backbone_trainable": True,
        },
    ]
    if TRANSFER_EARLY_STOPPING:
        # One callback instance per phase rather than a shared one.
        for phase in PHASES:
            phase["base_callbacks"] = [_early_stopping(TRANSFER_ES_PATIENCE)]
else:  # scratch
    PHASES = [
        {
            "name": "training",
            "epochs": config.EPOCHS,
            "learning_rate": config.LEARNING_RATE,
            "base_callbacks": [_early_stopping(SCRATCH_ES_PATIENCE)],
        }
    ]


In [None]:
# Train model: run every phase in PHASES in order on the train/val splits.
# train_model returns the per-phase histories, which are merged in the next cell.
print(">>> TRAINING STARTING <<<")

histories = train_model(
    model=model,
    train_data=train_ds,
    val_data=val_ds,
    compile_fn=compile_model,
    phases=PHASES,
)

In [None]:
# Combine the per-phase training histories into a single record for plotting.
# NOTE(review): assumed to return a dict-like keyed by metric name
# ("accuracy", "val_accuracy", "loss", "val_loss" are read in the next cell);
# confirm in metrics.merge_histories.
history = merge_histories(histories)

In [None]:
# Learning curves: training vs. validation accuracy and loss per epoch.
# One titled, labelled figure per metric so each plot stands on its own.
for metric, label in (("accuracy", "Accuracy"), ("loss", "Loss")):
    fig, ax = plt.subplots()
    ax.plot(history[metric], label=f"Train {label}")
    ax.plot(history[f"val_{metric}"], label=f"Validation {label}")
    ax.set(title=f"{label} per epoch", xlabel="Epoch", ylabel=label)
    ax.legend()
    plt.show()

In [None]:
# Evaluate on the held-out test split. This cell also works on a model
# restored by the load cell at the end of the notebook.
test_loss, test_acc = evaluate_model(model, test_ds)
for summary_line in (
    f"Test loss: {test_loss:.4f}",
    f"Test accuracy: {test_acc:.2%}",
):
    print(summary_line)


In [None]:
# Ground-truth labels and predicted classes for the whole test split.
# NOTE(review): the labels and the predictions come from two separate passes
# over test_ds. They only line up if test_ds is not reshuffled between
# iterations; confirm in dataset.get_cifar10_datasets.
y_true = np.concatenate([labels for _images, labels in test_ds])
y_pred = predict_classes(model, test_ds)

In [None]:
# Confusion matrix on the test split, shown as a heatmap.
cm = confusion_matrix(
    y_true,
    y_pred,
    num_classes=len(config.CLASS_NAMES)
)

# Normalize to percentages per true class (each row sums to 100%).
# A class with no true samples would divide by zero (NaN + RuntimeWarning);
# such rows are left at 0% instead.
cm_counts = np.asarray(cm, dtype=float)
row_totals = cm_counts.sum(axis=1, keepdims=True)
cm_percent = np.divide(
    cm_counts,
    row_totals,
    out=np.zeros_like(cm_counts),
    where=row_totals != 0,
) * 100

# Cell annotations with a % sign (fmt="" below because these are strings)
annot = np.array([[f"{v:.1f}%" for v in row] for row in cm_percent])

fig, ax = plt.subplots(figsize=(10, 8))
sns.heatmap(
    cm_percent,
    annot=annot,
    fmt="",
    cmap="Blues",
    xticklabels=config.CLASS_NAMES,
    yticklabels=config.CLASS_NAMES,
    cbar=True,
    ax=ax,
)
ax.set(xlabel="Predicted label", ylabel="True label", title="Confusion Matrix (%)")
fig.tight_layout()
plt.show()

In [None]:
# Per-class precision / recall / F1 on the test split.
report = classification_report(
    y_true,
    y_pred,
    target_names=config.CLASS_NAMES,
    output_dict=True
)

# Keep only the per-class rows. Selecting by class name instead of slicing off
# the last three rows does not depend on which summary rows sklearn emits
# ("accuracy" vs "micro avg", "macro avg", "weighted avg").
df_report = pd.DataFrame(report).T.loc[list(config.CLASS_NAMES)]

# Plot
ax = df_report[["precision", "recall", "f1-score"]].plot(
    kind="bar",
    figsize=(12, 5)
)
ax.set(ylabel="Score", ylim=(0, 1), title="Classification Report Metrics per Class")
plt.xticks(rotation=45, ha="right")
plt.tight_layout()
plt.show()

# Save & Load Models

The cells below save the current model to disk and load a saved model back.
The save cell is guarded by `MODEL_SAVE_FLAG`, which is `False` by default; set it to `True` to save the current model.
The load cell loads a model file stored under `models/<model_filename>.keras` (relative to the notebook's working directory).
After loading, re-running the evaluation cells reports the loaded model's results and weights.
Running `build_model()` or `compile_model()` afterwards will reset the loaded model.

Note: Model files are intentionally excluded from version control. Each user is expected to train and save models locally.

In [None]:
# ============================
# SAVE MODEL SNAPSHOT (MANUAL)
# ============================
# Writes the in-memory `model` to models/<RUN_NAME>_<timestamp>.keras, but only
# when MODEL_SAVE_FLAG is set to True, so "Run All" never saves by accident.

import os
from datetime import datetime

# Prevent misclick saves [Change to True to Save]
MODEL_SAVE_FLAG = False

# Give this run a clear name
RUN_NAME = "model-v9_vgg16_finetuned_upsampling_cifar10_best"

# Safety check
assert RUN_NAME, "RUN_NAME must be defined before saving the model"

# Timestamp suffix avoids overwriting earlier snapshots of the same run.
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")

MODEL_PATH = f"models/{RUN_NAME}_{timestamp}.keras"

# Save model snapshot [Set MODEL_SAVE_FLAG = True at top to run]
if MODEL_SAVE_FLAG:
    # Ensure models directory exists (only needed when actually saving)
    os.makedirs("models", exist_ok=True)
    model.save(MODEL_PATH)
    print(f"Model snapshot saved to: {MODEL_PATH}")
else:
    print("MODEL_SAVE_FLAG=False — model not saved")

In [None]:
# Load a previously saved model snapshot, replacing the in-memory `model`.
# Model files are not in version control, so MODEL_PATH must point to a
# snapshot saved locally with the save cell above.

import os

from keras.models import load_model

# Path to saved model
MODEL_PATH = "models/model-v9_vgg16_finetuned_upsampling_cifar10_best_20260130_001736.keras"  # adjust path/name

# Fail early with an actionable message before handing the path to Keras.
if not os.path.isfile(MODEL_PATH):
    raise FileNotFoundError(
        f"No saved model at {MODEL_PATH!r}. Train and save a model first "
        "(set MODEL_SAVE_FLAG = True in the save cell), then update MODEL_PATH."
    )

# Load model
model = load_model(MODEL_PATH)

print("Model loaded successfully")
model.summary()