In [1]:
# ============================================================
# Fine-Tuning MobileNetV2 (Partial Unfreezing, TFDS, Local)
# ============================================================
#
# This notebook builds on 03a by demonstrating *fine-tuning*:
# instead of freezing the entire pretrained CNN, we selectively
# unfreeze the top layers of MobileNetV2 and continue training.
#
# IMPORTANT DESIGN CHOICE:
# ------------------------
# To keep runtime reasonable on CPU-only student laptops,
# we fine-tune ONLY the last N layers of MobileNetV2 rather
# than the entire network.
#
# This approach:
# - Preserves most pretrained features
# - Reduces computation dramatically
# - Clearly illustrates the concept of fine-tuning
#
# ============================================================


# ----------------------------
# Imports
# ----------------------------
import os
import tensorflow as tf
import tensorflow_datasets as tfds
import pandas as pd
import numpy as np

print("TensorFlow version:", tf.__version__)


# ----------------------------
# Configuration
# ----------------------------
# Input image geometry fed to MobileNetV2: height, width, RGB channels.
IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS = 224, 224, 3

# Optimisation settings.
BATCH_SIZE = 32
EPOCHS = 5               # Fewer epochs for CPU friendliness
LEARNING_RATE = 0.0001

# How many layers, counted from the top of MobileNetV2, get fine-tuned
# (Option A: fine-tune last N layers).
FINE_TUNE_AT = 20

# Class names in label-index order: label i maps to CLASS_NAMES[i].
CLASS_NAMES = [
    'daisy',
    'dandelion',
    'roses',
    'sunflowers',
    'tulips',
]

# Every CSV artifact produced by this notebook is written here.
OUTPUT_DIR = "03b_mobilenet_finetune_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)


# ============================================================
# Load Dataset (TFDS)
# ============================================================

# The single TFDS "train" split is cut 80/20 into training and validation
# subsets; as_supervised=True makes each element an (image, label) tuple.
(ds_train, ds_val), ds_info = tfds.load(
    name="tf_flowers",
    split=["train[:80%]", "train[80%:]"],
    with_info=True,
    as_supervised=True,
)

print("Dataset loaded:")
print(ds_info)


# ============================================================
# Preprocessing Pipeline
# ============================================================
#
# MobileNetV2 expects:
# - 224x224 RGB images
# - Float32 inputs
# - MobileNetV2-specific preprocessing (preprocess_input scales pixels to [-1, 1])
#
# ============================================================

def preprocess(image, label):
    """Resize one image to the model input size and apply MobileNetV2 scaling.

    The label is passed through untouched, so this function can be handed
    directly to ``tf.data.Dataset.map`` on an ``(image, label)`` dataset.
    """
    resized = tf.image.resize(image, size=(IMG_HEIGHT, IMG_WIDTH))
    return tf.keras.applications.mobilenet_v2.preprocess_input(resized), label


# Training pipeline: preprocess in parallel, shuffle with a 1000-element
# buffer, batch, and prefetch so input preparation overlaps training.
train_ds = ds_train.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
train_ds = train_ds.shuffle(1000).batch(BATCH_SIZE)
train_ds = train_ds.prefetch(tf.data.AUTOTUNE)

# Validation pipeline: same preprocessing, but no shuffling so evaluation
# order is stable.
val_ds = ds_val.map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
val_ds = val_ds.batch(BATCH_SIZE).prefetch(tf.data.AUTOTUNE)

print("Training and validation datasets ready.")


# ============================================================
# Build Fine-Tuned MobileNetV2 Model
# ============================================================
#
# Strategy:
# - Load MobileNetV2 pretrained on ImageNet
# - Freeze MOST layers
# - Unfreeze ONLY the last `FINE_TUNE_AT` layers
#
# This is called *partial fine-tuning* and is far more
# practical than full fine-tuning on CPU hardware.
#
# ============================================================

def build_finetuned_model(
    fine_tune_at=20,
    learning_rate=1e-4
):
    """Build and compile a MobileNetV2 classifier with its top layers unfrozen.

    An ImageNet-pretrained MobileNetV2 backbone (``include_top=False``) is
    frozen except for its last ``fine_tune_at`` layers. A small head is
    stacked on top: global average pooling -> Dense(16, relu) ->
    Dense(softmax over CLASS_NAMES).

    Args:
        fine_tune_at: Number of backbone layers, counted from the top, whose
            weights are trained. ``0`` freezes the whole backbone (pure
            feature extraction). A value >= the number of backbone layers
            unfreezes all of them. BatchNormalization layers always stay
            frozen (see the comment in the loop below).
        learning_rate: Learning rate for the Adam optimizer.

    Returns:
        A ``tf.keras.Sequential`` model compiled with sparse categorical
        cross-entropy loss and an accuracy metric.

    Raises:
        ValueError: If ``fine_tune_at`` is negative.
    """
    if fine_tune_at < 0:
        raise ValueError(f"fine_tune_at must be >= 0, got {fine_tune_at}")

    base_model = tf.keras.applications.MobileNetV2(
        weights="imagenet",
        include_top=False,
        input_shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
    )

    # Index of the first layer to fine-tune. Computing it explicitly avoids
    # the `layers[:-0]` slicing trap, where fine_tune_at=0 would freeze
    # nothing and leave the entire backbone trainable.
    first_trainable = max(len(base_model.layers) - fine_tune_at, 0)

    for index, layer in enumerate(base_model.layers):
        if isinstance(layer, tf.keras.layers.BatchNormalization):
            # Keep BatchNormalization frozen even in the unfrozen region: a
            # non-trainable BN layer runs in inference mode, so its
            # pretrained moving mean/variance are not overwritten by
            # statistics from small training batches.
            layer.trainable = False
        else:
            layer.trainable = index >= first_trainable

    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(16, activation="relu", name="dense_hidden"),
        tf.keras.layers.Dense(len(CLASS_NAMES), activation="softmax", name="flower_prob")
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"]
    )

    return model


# Instantiate the partially unfrozen model from the notebook configuration
# and print its layer / parameter overview.
model = build_finetuned_model(
    learning_rate=LEARNING_RATE,
    fine_tune_at=FINE_TUNE_AT,
)
model.summary()


# ============================================================
# Train Model
# ============================================================

# `history.history` keeps the per-epoch training and validation metrics.
history = model.fit(train_ds, epochs=EPOCHS, validation_data=val_ds)


# ============================================================
# Save Training History (CSV)
# ============================================================

def save_training_history(history, filename):
    """Write the per-epoch metrics recorded by ``model.fit`` to a CSV file.

    Args:
        history: Object returned by ``model.fit``; its ``history`` dict maps
            metric names to lists of per-epoch values, which become one CSV
            column each.
        filename: Name of the CSV file, created inside ``OUTPUT_DIR``.
    """
    out_path = os.path.join(OUTPUT_DIR, filename)
    pd.DataFrame(history.history).to_csv(out_path, index=False)
    print(f"Saved training history to {out_path}")


save_training_history(history, "finetune_training_history.csv")


# ============================================================
# Evaluate Predictions and Confusion Matrix
# ============================================================
#
# We export:
# - Per-sample predictions
# - Raw confusion matrix
# - Normalized confusion matrix
#
# All results are saved as CSV files.
#
# ============================================================

def evaluate_and_save(model, dataset):
    """Predict on ``dataset`` and export evaluation CSVs to ``OUTPUT_DIR``.

    Three files are written:

    - ``finetune_predictions.csv``: one row per sample with the true and
      predicted class names.
    - ``confusion_matrix_raw.csv``: prediction counts. Rows are true classes
      and columns are predicted classes, both labelled with ``CLASS_NAMES``.
    - ``confusion_matrix_normalized.csv``: the raw matrix with each row
      divided by its total (per-class recall on the diagonal). A class with
      no samples in ``dataset`` gets an all-zero row instead of NaNs.

    Args:
        model: Keras classifier producing one score per entry of
            ``CLASS_NAMES``; the predicted class is the argmax.
        dataset: Iterable of ``(images, labels)`` batches whose integer labels
            index ``CLASS_NAMES``.
    """
    y_true = []
    y_pred = []

    # Labels and predictions are collected in the same pass so they stay
    # aligned even if the dataset reshuffled between iterations. Calling the
    # model directly avoids the per-call setup overhead of model.predict(),
    # which is intended for whole datasets, not single batches in a loop.
    for images, labels in dataset:
        probs = model(images, training=False)
        preds = tf.argmax(probs, axis=1)

        y_true.extend(labels.numpy())
        y_pred.extend(preds.numpy())

    # Explicit integer dtype keeps label indexing and the confusion matrix
    # valid for an empty dataset (np.array([]) would default to float64).
    y_true = np.asarray(y_true, dtype=np.int64)
    y_pred = np.asarray(y_pred, dtype=np.int64)

    # Save per-sample predictions
    pred_df = pd.DataFrame({
        "true_label": [CLASS_NAMES[i] for i in y_true],
        "predicted_label": [CLASS_NAMES[i] for i in y_pred]
    })

    pred_path = os.path.join(OUTPUT_DIR, "finetune_predictions.csv")
    pred_df.to_csv(pred_path, index=False)

    # Confusion matrix (raw)
    cm = tf.math.confusion_matrix(
        y_true,
        y_pred,
        num_classes=len(CLASS_NAMES)
    ).numpy()

    cm_df = pd.DataFrame(cm, index=CLASS_NAMES, columns=CLASS_NAMES)
    cm_path = os.path.join(OUTPUT_DIR, "confusion_matrix_raw.csv")
    cm_df.to_csv(cm_path)

    # Confusion matrix (normalized by row). A class absent from the dataset
    # has a zero row total; dividing would yield NaN, so such rows stay 0.
    row_totals = cm.sum(axis=1, keepdims=True)
    cm_norm = np.divide(
        cm,
        row_totals,
        out=np.zeros(cm.shape, dtype=np.float64),
        where=row_totals > 0,
    )
    cm_norm_df = pd.DataFrame(cm_norm, index=CLASS_NAMES, columns=CLASS_NAMES)
    cm_norm_path = os.path.join(OUTPUT_DIR, "confusion_matrix_normalized.csv")
    cm_norm_df.to_csv(cm_norm_path)

    print("Saved evaluation artifacts:")
    print("-", pred_path)
    print("-", cm_path)
    print("-", cm_norm_path)


evaluate_and_save(model, val_ds)


# ============================================================
# Summary
# ============================================================
#
# - We fine-tuned ONLY the top layers of MobileNetV2
# - Most pretrained weights remained frozen
# - Training took longer than pure transfer learning
# - This approach balances performance and practicality
#
# In practice, selective fine-tuning is often preferable
# to full fine-tuning, especially with limited compute.
#
# ============================================================


TensorFlow version: 2.9.1
Dataset loaded:
tfds.core.DatasetInfo(
    name='tf_flowers',
    full_name='tf_flowers/3.0.1',
    description="""
    A large set of images of flowers
    """,
    homepage='https://www.tensorflow.org/tutorials/load_data/images',
    data_dir='C:\\Users\\Jason Eckert\\tensorflow_datasets\\tf_flowers\\3.0.1',
    file_format=tfrecord,
    download_size=218.21 MiB,
    dataset_size=221.83 MiB,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=5),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    nondeterministic_order=False,
    splits={
        'train': <SplitInfo num_examples=3670, num_shards=2>,
    },
    citation="""@ONLINE {tfflowers,
    author = "The TensorFlow Team",
    title = "Flowers",
    month = "jan",
    year = "2019",
    url = "http://download.tensorflow.org/example_images/flower_photos.tgz" }""",
)
Training and