
# Vision AI — Cats vs Dogs 





## 1) Setup

In [None]:

import os, sys, math, itertools, pathlib, time
import numpy as np
import tensorflow as tf
from tensorflow import keras
from tensorflow.keras import layers
import matplotlib.pyplot as plt

!pip -q install -U tensorflow-datasets scikit-learn
import tensorflow_datasets as tfds
from sklearn.metrics import classification_report, confusion_matrix, roc_curve, auc

# Report the TensorFlow build and any visible GPUs so results can be tied
# to the runtime they were produced on.
print("TensorFlow:", tf.__version__)
gpu_devices = tf.config.list_physical_devices("GPU")
print("TF GPU available:", gpu_devices)


## 2) Dataset — Cats vs Dogs (TFDS)

In [None]:

# Global configuration shared by the rest of the notebook.
IMG_SIZE = 224  # images are resized to IMG_SIZE x IMG_SIZE
BATCH = 32      # batch size used by every pipeline
SEED = 42       # seed for the training shuffle

# All three partitions are slices of the TFDS "train" split:
# first 80% for training, next 10% for validation, last 10% for testing.
split_spec = ["train[:80%]", "train[80%:90%]", "train[90%:]"]
splits, ds_info = tfds.load(
    "cats_vs_dogs",
    split=split_spec,
    as_supervised=True,
    with_info=True,
)
raw_train, raw_val, raw_test = splits

# Class metadata comes from the dataset's label feature.
label_feature = ds_info.features["label"]
NUM_CLASSES = label_feature.num_classes
CLASS_NAMES = label_feature.names

split_sizes = [
    tf.data.experimental.cardinality(split).numpy()
    for split in (raw_train, raw_val, raw_test)
]
print("Classes:", CLASS_NAMES, "| Train/Val/Test sizes:", *split_sizes)


## 3) Preprocessing & Visualisation

In [None]:

def format_example(image, label):
    """Resize one image to the model input size and rescale its pixel values.

    Args:
        image: Image tensor to resize to ``IMG_SIZE x IMG_SIZE``.
        label: Label for the image; returned unchanged.

    Returns:
        Tuple ``(image, label)`` where ``image`` is float32 and has been
        divided by 255 (i.e. in [0, 1] assuming 8-bit input pixels).
    """
    target_size = (IMG_SIZE, IMG_SIZE)
    resized = tf.image.resize(image, target_size)
    scaled = tf.cast(resized, tf.float32) / 255.0
    return scaled, label

# Apply the same resize/rescale step to every split; the results are still
# one example per element (batching happens later in `prepare`).
train, val, test = (
    split.map(format_example, num_parallel_calls=tf.data.AUTOTUNE)
    for split in (raw_train, raw_val, raw_test)
)

def prepare(ds, shuffle=False):
    """Turn a per-example dataset into a batched, prefetching pipeline.

    Args:
        ds: Dataset to batch.
        shuffle: When true, shuffle with a 2048-element buffer (seeded with
            ``SEED``, reshuffled every iteration) before batching.

    Returns:
        The dataset batched into groups of ``BATCH`` with prefetching enabled.
    """
    pipeline = (
        ds.shuffle(2048, seed=SEED, reshuffle_each_iteration=True)
        if shuffle
        else ds
    )
    batched = pipeline.batch(BATCH)
    return batched.prefetch(tf.data.AUTOTUNE)

# Only the training pipeline is shuffled; validation and test keep their order.
train_ds, val_ds, test_ds = (
    prepare(train, shuffle=True),
    prepare(val),
    prepare(test),
)

# Visualize a few samples.
# `train` has not been batched (batching only happens in `prepare`), so each
# element is already a single (image, label) pair and take(9) yields nine
# examples directly. Calling unbatch() here fails because the scalar label
# has no leading dimension to split.
plt.figure(figsize=(8, 8))
for i, (img, label) in enumerate(train.take(9)):
    plt.subplot(3, 3, i + 1)
    plt.imshow(img.numpy())  # pixel values were already divided by 255 in format_example
    plt.title(CLASS_NAMES[int(label.numpy())])
    plt.axis("off")
plt.show()


## 4) Data Augmentation

In [None]:

# Random augmentations applied inside the models: horizontal flip, rotation
# with factor 0.1 and zoom with factor 0.1.
augmentation_layers = [
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.1),
    layers.RandomZoom(0.1),
]
data_augmentation = keras.Sequential(augmentation_layers, name="augmentation")


## 5) Baseline CNN (from scratch)

In [None]:

def build_baseline(input_shape=(IMG_SIZE, IMG_SIZE, 3)):
    """Build and compile the from-scratch convolutional baseline.

    The network applies the shared augmentation pipeline, three
    Conv2D(3x3, ReLU) + MaxPooling2D stages with 32, 64 and 128 filters,
    then Flatten, Dropout(0.3), Dense(128, ReLU) and a single sigmoid unit.

    Args:
        input_shape: Shape of one input image, without the batch dimension.

    Returns:
        A compiled ``keras.Model`` named ``"baseline_cnn"`` (Adam at 1e-3,
        binary cross-entropy loss, accuracy metric).
    """
    inputs = layers.Input(shape=input_shape)
    features = data_augmentation(inputs)

    # Convolutional feature extractor: filter count doubles at each stage.
    for n_filters in (32, 64, 128):
        features = layers.Conv2D(n_filters, 3, activation="relu")(features)
        features = layers.MaxPooling2D()(features)

    # Classification head.
    features = layers.Flatten()(features)
    features = layers.Dropout(0.3)(features)
    features = layers.Dense(128, activation="relu")(features)
    outputs = layers.Dense(1, activation="sigmoid")(features)  # binary classification

    cnn = keras.Model(inputs, outputs, name="baseline_cnn")
    cnn.compile(
        loss="binary_crossentropy",
        optimizer=keras.optimizers.Adam(1e-3),
        metrics=["accuracy"],
    )
    return cnn

# Instantiate the baseline with the default input shape and print its
# layer-by-layer summary.
baseline = build_baseline()
baseline.summary()


### Train Baseline CNN

In [None]:

# Directory that receives the best-so-far checkpoints of every model.
ckpt_dir = "checkpoints"
os.makedirs(ckpt_dir, exist_ok=True)

baseline_ckpt_path = os.path.join(ckpt_dir, "baseline_best.keras")

# Stop once validation accuracy has not improved for 5 epochs and roll the
# model back to its best weights.
early_stopping = keras.callbacks.EarlyStopping(
    monitor="val_accuracy",
    patience=5,
    restore_best_weights=True,
)
# Keep only the checkpoint with the best validation accuracy on disk.
best_checkpoint = keras.callbacks.ModelCheckpoint(
    baseline_ckpt_path,
    monitor="val_accuracy",
    save_best_only=True,
)
callbacks = [early_stopping, best_checkpoint]

hist_baseline = baseline.fit(
    train_ds,
    epochs=20,
    validation_data=val_ds,
    callbacks=callbacks,
)


### Training Curves — Baseline

In [None]:

def plot_history(history, title="Training Curves"):
    """Plot training vs. validation accuracy and loss, one figure each.

    Args:
        history: Object whose ``history`` attribute maps ``"accuracy"``,
            ``"val_accuracy"``, ``"loss"`` and ``"val_loss"`` to per-epoch
            sequences (as returned by ``Model.fit``).
        title: Title used for both figures.
    """
    curves = history.history
    # (train key, train legend, validation key, validation legend, y-axis label)
    panels = (
        ("accuracy", "train_acc", "val_accuracy", "val_acc", "Accuracy"),
        ("loss", "train_loss", "val_loss", "val_loss", "Loss"),
    )
    for train_key, train_label, val_key, val_label, y_label in panels:
        plt.figure(figsize=(6, 4))
        plt.plot(curves[train_key], label=train_label)
        plt.plot(curves[val_key], label=val_label)
        plt.xlabel("Epoch")
        plt.ylabel(y_label)
        plt.title(title)
        plt.legend()
        plt.show()

# Accuracy and loss curves for the baseline training run.
plot_history(hist_baseline, "Baseline CNN")


## 6) Evaluation Utilities

In [None]:

def evaluate_model(model, ds, name="Model"):
    """Evaluate a binary classifier on a batched dataset and plot diagnostics.

    Prints a classification report, draws the confusion matrix and the ROC
    curve, and returns the headline metrics.

    Args:
        model: Model whose per-batch output is flattened to one score per
            example; scores >= 0.5 are counted as class 1.
        ds: Iterable yielding ``(images, labels)`` batches.
        name: Label used in the printed report and the plot titles.

    Returns:
        dict with ``"accuracy"`` (fraction of thresholded predictions equal
        to the labels) and ``"auc"`` (area under the ROC curve), both as
        plain Python floats.
    """
    # Collect predictions and labels batch by batch so they stay aligned.
    # predict_on_batch is used instead of predict(): predict() is meant for
    # large inputs and is not intended to be called once per batch in a loop.
    y_true = []
    y_prob = []
    for batch_images, batch_labels in ds:
        y_true.extend(batch_labels.numpy().tolist())
        batch_scores = np.asarray(model.predict_on_batch(batch_images))
        y_prob.extend(batch_scores.ravel().tolist())
    y_true = np.array(y_true)
    y_prob = np.array(y_prob)
    y_pred = (y_prob >= 0.5).astype(int)

    # Pin the label set so the report and matrix always cover every class,
    # even if one of them never occurs in the labels or the predictions.
    class_ids = list(range(len(CLASS_NAMES)))

    # Metrics
    print(f"\n{name} — Classification Report:")
    print(classification_report(
        y_true, y_pred, labels=class_ids, target_names=CLASS_NAMES, digits=4
    ))

    # Confusion Matrix (rows: true class, columns: predicted class)
    cm = confusion_matrix(y_true, y_pred, labels=class_ids)
    plt.figure(figsize=(4, 4))
    plt.imshow(cm, interpolation="nearest")
    plt.title(f"{name} — Confusion Matrix")
    plt.xticks(class_ids, CLASS_NAMES)
    plt.yticks(class_ids, CLASS_NAMES)
    for i, j in np.ndindex(cm.shape):
        plt.text(j, i, cm[i, j], ha="center", va="center")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.colorbar()
    plt.show()

    # ROC Curve
    fpr, tpr, _ = roc_curve(y_true, y_prob)
    roc_auc = auc(fpr, tpr)
    plt.figure(figsize=(5, 4))
    plt.plot(fpr, tpr, label=f"AUC = {roc_auc:.4f}")
    plt.plot([0, 1], [0, 1], "--")  # chance-level diagonal
    plt.xlabel("False Positive Rate")
    plt.ylabel("True Positive Rate")
    plt.title(f"{name} — ROC Curve")
    plt.legend(loc="lower right")
    plt.show()

    return {
        "accuracy": float((y_true == y_pred).mean()),
        "auc": float(roc_auc),
    }

# Evaluate the baseline on the held-out test pipeline and keep its metrics
# for the model comparison later on.
metrics_baseline = evaluate_model(baseline, test_ds, "Baseline CNN")
print("Baseline metrics:", metrics_baseline)


## 7) Transfer Learning — MobileNetV2

In [None]:

def build_mobilenet(input_shape=(IMG_SIZE, IMG_SIZE, 3), train_base=False):
    """Build and compile a MobileNetV2 transfer-learning classifier.

    The model chains the shared augmentation pipeline, MobileNetV2
    preprocessing, the ImageNet-pretrained backbone (without its top),
    global average pooling, Dropout(0.2) and a single sigmoid unit.

    Args:
        input_shape: Shape of one input image, without the batch dimension.
        train_base: Value assigned to the backbone's ``trainable`` flag;
            the default ``False`` keeps the backbone frozen.

    Returns:
        A compiled ``keras.Model`` named ``"mobilenetv2_transfer"`` (Adam at
        1e-3, binary cross-entropy loss, accuracy metric).
    """
    backbone = keras.applications.MobileNetV2(
        weights="imagenet",
        include_top=False,
        input_shape=input_shape,
    )
    backbone.trainable = train_base  # start frozen

    inputs = layers.Input(shape=input_shape)
    augmented = data_augmentation(inputs)
    # The values are multiplied back by 255 before MobileNetV2's own
    # preprocess_input, which maps them to the [-1, 1] range MobileNet expects.
    rescaled = augmented * 255.0
    preprocessed = keras.applications.mobilenet_v2.preprocess_input(rescaled)
    # The backbone is always called with training=False.
    features = backbone(preprocessed, training=False)
    pooled = layers.GlobalAveragePooling2D()(features)
    dropped = layers.Dropout(0.2)(pooled)
    outputs = layers.Dense(1, activation="sigmoid")(dropped)

    transfer_model = keras.Model(inputs, outputs, name="mobilenetv2_transfer")
    transfer_model.compile(
        loss="binary_crossentropy",
        optimizer=keras.optimizers.Adam(1e-3),
        metrics=["accuracy"],
    )
    return transfer_model

# Phase 1: train only the new classification head on top of the frozen backbone.
mobilenet = build_mobilenet()
mobilenet.summary()

mobilenet_ckpt_path = os.path.join(ckpt_dir, "mobilenet_best.keras")
callbacks_tl = [
    # Stop after 5 epochs without validation-accuracy improvement and
    # restore the best weights.
    keras.callbacks.EarlyStopping(
        monitor="val_accuracy",
        patience=5,
        restore_best_weights=True,
    ),
    # Keep only the checkpoint with the best validation accuracy.
    keras.callbacks.ModelCheckpoint(
        mobilenet_ckpt_path,
        monitor="val_accuracy",
        save_best_only=True,
    ),
]

history_tl = mobilenet.fit(
    train_ds,
    epochs=10,
    validation_data=val_ds,
    callbacks=callbacks_tl,
)

# Optional fine-tuning: unfreeze the top of the MobileNetV2 backbone.
# The backbone has to be located explicitly. Layer index 1 of `mobilenet` is
# the augmentation pipeline (index 0 is the input layer), so addressing the
# backbone by that index would leave it frozen. The backbone is the only
# nested model in `mobilenet` other than the augmentation pipeline.
unfreeze_from = 100
base_model = next(
    layer
    for layer in mobilenet.layers
    if isinstance(layer, keras.Model) and layer is not data_augmentation
)
base_model.trainable = True
# Keep the first `unfreeze_from` backbone layers frozen; only the layers
# after them are updated during fine-tuning.
for layer in base_model.layers[:unfreeze_from]:
    layer.trainable = False

# Recompile after changing the trainable flags, with a 10x smaller learning
# rate than the frozen-base phase.
mobilenet.compile(optimizer=keras.optimizers.Adam(1e-4),
                  loss="binary_crossentropy",
                  metrics=["accuracy"])

history_ft = mobilenet.fit(
    train_ds,
    validation_data=val_ds,
    epochs=10,
    callbacks=callbacks_tl
)


### Training Curves — Transfer Learning

In [None]:

# Curves for both training phases, in the order they were run.
transfer_runs = (
    (history_tl, "MobileNetV2 (frozen base)"),
    (history_ft, "MobileNetV2 (fine‑tuned)"),
)
for run_history, run_title in transfer_runs:
    plot_history(run_history, run_title)

# Test-set metrics of the model after both phases, kept for the comparison.
metrics_mobilenet = evaluate_model(
    mobilenet, test_ds, name="MobileNetV2 (fine‑tuned)"
)
print("Transfer learning metrics:", metrics_mobilenet)


## 8) Compare Models

In [None]:

# Side-by-side test metrics; MobileNetV2 wins ties on accuracy.
print("\n=== Summary ===")

baseline_scores = (metrics_baseline["accuracy"], metrics_baseline["auc"])
print("Baseline — Acc: %.4f  AUC: %.4f" % baseline_scores)

mobilenet_scores = (metrics_mobilenet["accuracy"], metrics_mobilenet["auc"])
print("MobileNetV2 — Acc: %.4f  AUC: %.4f" % mobilenet_scores)

if metrics_mobilenet["accuracy"] >= metrics_baseline["accuracy"]:
    best = "MobileNetV2"
else:
    best = "Baseline"
print("Best model:", best)


## 9) Save Artifacts & Inference Demo

In [None]:

# Persist both trained models under a single artifacts directory.
save_dir = "artifacts"
os.makedirs(save_dir, exist_ok=True)

models_to_save = (
    (baseline, "baseline_cnn.keras"),
    (mobilenet, "mobilenetv2_finetuned.keras"),
)
for trained_model, filename in models_to_save:
    trained_model.save(os.path.join(save_dir, filename))

# Inference demo: show up to nine images from the first test batch with
# their true label, predicted label and predicted score.
for images, labels in test_ds.take(1):
    probs = mobilenet.predict(images).ravel()
    preds = (probs >= 0.5).astype(int)
    # The batch may hold fewer than nine images (e.g. a small BATCH), so cap
    # the grid at the number of predictions instead of indexing past the end.
    n_show = min(9, len(probs))
    plt.figure(figsize=(10, 10))
    for i in range(n_show):
        plt.subplot(3, 3, i + 1)
        plt.imshow(images[i].numpy())
        t = f"True: {CLASS_NAMES[int(labels[i])]}\nPred: {CLASS_NAMES[int(preds[i])]} ({probs[i]:.2f})"
        plt.title(t)
        plt.axis("off")
    plt.show()
