In [1]:
# ============================================================
# Fine-Tuning Xception (TFDS, Local, CPU-Friendly)
# ============================================================
#
# This notebook demonstrates fine-tuning using a deeper
# architecture (Xception) and contrasts it with MobileNetV2.
#
# IMPORTANT:
# ----------
# - We use tf_flowers (5 classes) via TFDS
# - We DO NOT run learning-rate decay experiments
# - We fine-tune ONLY the top layers for performance
#
# This keeps the notebook runnable on student laptops
# while preserving the core learning objective.
#
# ============================================================


# ----------------------------
# Imports
# ----------------------------
import os
import tensorflow as tf
import tensorflow_datasets as tfds
import pandas as pd
import numpy as np

print("TensorFlow version:", tf.__version__)


# ----------------------------
# Configuration
# ----------------------------
IMG_HEIGHT = 224
IMG_WIDTH = 224
IMG_CHANNELS = 3

BATCH_SIZE = 16          # Reduced for Xception on CPU
EPOCHS = 5
LEARNING_RATE = 1e-4

# Number of top layers to fine-tune
FINE_TUNE_AT = 20

CLASS_NAMES = ['daisy', 'dandelion', 'roses', 'sunflowers', 'tulips']

OUTPUT_DIR = "03b_xception_finetune_outputs"
os.makedirs(OUTPUT_DIR, exist_ok=True)


# ============================================================
# Load Dataset (TFDS)
# ============================================================

(ds_train, ds_val), ds_info = tfds.load(
    "tf_flowers",
    split=["train[:80%]", "train[80%:]"],
    as_supervised=True,
    with_info=True
)

print("Dataset loaded:")
print(ds_info)


# ============================================================
# Preprocessing Pipeline
# ============================================================
#
# Xception expects:
# - 224x224 RGB images
# - Float32 inputs
# - Xception-specific preprocessing
#
# ============================================================

def preprocess(image, label):
    image = tf.image.resize(image, (IMG_HEIGHT, IMG_WIDTH))
    image = tf.keras.applications.xception.preprocess_input(image)
    return image, label


train_ds = (
    ds_train
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .shuffle(1000)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

val_ds = (
    ds_val
    .map(preprocess, num_parallel_calls=tf.data.AUTOTUNE)
    .batch(BATCH_SIZE)
    .prefetch(tf.data.AUTOTUNE)
)

print("Training and validation datasets ready.")


# ============================================================
# Build Fine-Tuned Xception Model
# ============================================================
#
# Strategy:
# - Load Xception pretrained on ImageNet
# - Freeze most layers
# - Fine-tune only the top N layers
#
# Xception is significantly heavier than MobileNetV2,
# so careful fine-tuning is essential on CPU.
#
# ============================================================

def build_finetuned_xception(
    fine_tune_at=20,
    learning_rate=1e-4
):
    base_model = tf.keras.applications.Xception(
        weights="imagenet",
        include_top=False,
        input_shape=(IMG_HEIGHT, IMG_WIDTH, IMG_CHANNELS)
    )

    for layer in base_model.layers[:-fine_tune_at]:
        layer.trainable = False

    for layer in base_model.layers[-fine_tune_at:]:
        layer.trainable = True

    model = tf.keras.Sequential([
        base_model,
        tf.keras.layers.GlobalAveragePooling2D(),
        tf.keras.layers.Dense(len(CLASS_NAMES), activation="softmax", name="flower_prob")
    ])

    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"]
    )

    return model


model = build_finetuned_xception(
    fine_tune_at=FINE_TUNE_AT,
    learning_rate=LEARNING_RATE
)

model.summary()


# ============================================================
# Train Model
# ============================================================

history = model.fit(
    train_ds,
    validation_data=val_ds,
    epochs=EPOCHS
)


# ============================================================
# Save Training History (CSV)
# ============================================================

def save_training_history(history, filename):
    df = pd.DataFrame(history.history)
    path = os.path.join(OUTPUT_DIR, filename)
    df.to_csv(path, index=False)
    print(f"Saved training history to {path}")


save_training_history(history, "xception_training_history.csv")


# ============================================================
# Evaluate Predictions and Confusion Matrices
# ============================================================

def evaluate_and_save(model, dataset):
    y_true = []
    y_pred = []

    for images, labels in dataset:
        probs = model.predict(images, verbose=0)
        preds = tf.argmax(probs, axis=1)

        y_true.extend(labels.numpy())
        y_pred.extend(preds.numpy())

    y_true = np.array(y_true)
    y_pred = np.array(y_pred)

    # Per-sample predictions
    pred_df = pd.DataFrame({
        "true_label": [CLASS_NAMES[i] for i in y_true],
        "predicted_label": [CLASS_NAMES[i] for i in y_pred]
    })

    pred_path = os.path.join(OUTPUT_DIR, "xception_predictions.csv")
    pred_df.to_csv(pred_path, index=False)

    # Confusion matrix (raw)
    cm = tf.math.confusion_matrix(
        y_true,
        y_pred,
        num_classes=len(CLASS_NAMES)
    ).numpy()

    cm_df = pd.DataFrame(cm, index=CLASS_NAMES, columns=CLASS_NAMES)
    cm_path = os.path.join(OUTPUT_DIR, "confusion_matrix_raw.csv")
    cm_df.to_csv(cm_path)

    # Normalized confusion matrix
    cm_norm = cm / cm.sum(axis=1, keepdims=True)
    cm_norm_df = pd.DataFrame(cm_norm, index=CLASS_NAMES, columns=CLASS_NAMES)
    cm_norm_path = os.path.join(OUTPUT_DIR, "confusion_matrix_normalized.csv")
    cm_norm_df.to_csv(cm_norm_path)

    print("Saved evaluation artifacts:")
    print("-", pred_path)
    print("-", cm_path)
    print("-", cm_norm_path)


evaluate_and_save(model, val_ds)


# ============================================================
# Summary
# ============================================================
#
# - Xception is deeper and more computationally expensive
#   than MobileNetV2
# - Partial fine-tuning balances performance and runtime
# - Architecture choice impacts training time and behavior
#
# This concludes the comparison of transfer learning and
# fine-tuning strategies in Chapter 3.
#
# ============================================================


TensorFlow version: 2.9.1
Dataset loaded:
tfds.core.DatasetInfo(
    name='tf_flowers',
    full_name='tf_flowers/3.0.1',
    description="""
    A large set of images of flowers
    """,
    homepage='https://www.tensorflow.org/tutorials/load_data/images',
    data_dir='C:\\Users\\Jason Eckert\\tensorflow_datasets\\tf_flowers\\3.0.1',
    file_format=tfrecord,
    download_size=218.21 MiB,
    dataset_size=221.83 MiB,
    features=FeaturesDict({
        'image': Image(shape=(None, None, 3), dtype=uint8),
        'label': ClassLabel(shape=(), dtype=int64, num_classes=5),
    }),
    supervised_keys=('image', 'label'),
    disable_shuffling=False,
    nondeterministic_order=False,
    splits={
        'train': <SplitInfo num_examples=3670, num_shards=2>,
    },
    citation="""@ONLINE {tfflowers,
    author = "The TensorFlow Team",
    title = "Flowers",
    month = "jan",
    year = "2019",
    url = "http://download.tensorflow.org/example_images/flower_photos.tgz" }""",
)
Training and