# AI Product Photo Detector -- Training Notebook

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nolancacheux/AI-Product-Photo-Detector/blob/main/notebooks/train_colab.ipynb)

Train an EfficientNet-B0 binary classifier to distinguish **real product photos** from **AI-generated images**.

**Requirements:** GPU runtime (T4 recommended). Go to `Runtime > Change runtime type > T4 GPU`.

---

## 1. Setup

In [None]:
# Install the third-party packages this notebook imports (-q keeps pip's output quiet).
# google-cloud-storage is only used by the optional GCS upload cell in section 7.
!pip install -q torch torchvision timm datasets pillow scikit-learn matplotlib tqdm google-cloud-storage

In [None]:
import os
import json
import time
import random
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset, random_split
from torchvision import transforms
import timm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report, roc_curve, auc
)
from tqdm.auto import tqdm
from PIL import Image
from datasets import load_dataset

# Reproducibility: seed every RNG this notebook draws from (Python, NumPy,
# PyTorch on CPU and on all CUDA devices).
SEED = 42
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)
torch.cuda.manual_seed_all(SEED)

# Prefer the GPU when one is visible; later cells move tensors to `device`.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Device: {device}")
if torch.cuda.is_available():
    gpu_props = torch.cuda.get_device_properties(0)
    print(f"GPU: {torch.cuda.get_device_name(0)}")
    # The attribute is `total_memory` (in bytes). The previous `total_mem`
    # does not exist and raised AttributeError on GPU runtimes.
    print(f"Memory: {gpu_props.total_memory / 1e9:.1f} GB")

## 2. Configuration

Edit these parameters to customize training.

In [None]:
# -- Training Configuration --
# Single source of truth for every tunable; later cells look values up by key.
CONFIG = dict(
    # Dataset
    dataset_name="date3k2/raw_real_fake_images",  # HuggingFace dataset ID
    max_samples_per_class=4000,  # Per-class cap (None = use all)
    val_ratio=0.15,
    test_ratio=0.15,
    image_size=224,
    # Model
    model_name="efficientnet_b0",
    pretrained=True,
    dropout=0.3,
    # Training
    epochs=15,
    batch_size=64,
    learning_rate=1e-4,
    weight_decay=1e-5,
    patience=5,  # Early stopping patience
    num_workers=2,
    # Output
    output_dir="./training_output",
    model_filename="best_model.pt",
)

# Make sure the output folder exists, then echo the settings for the run log.
os.makedirs(CONFIG["output_dir"], exist_ok=True)
print(json.dumps(CONFIG, indent=2))

## 3. Dataset

Download the dataset from Hugging Face and prepare the train/val/test splits.

In [None]:
# Fetch the "train" split of the configured dataset and report its size.
dataset_id = CONFIG["dataset_name"]
print(f"Loading dataset: {dataset_id}...")
raw_dataset = load_dataset(dataset_id, split="train")
print(f"Total samples: {len(raw_dataset)}")

# Inspect label distribution
labels = raw_dataset["label"]
label_names = raw_dataset.features["label"].names  # index -> human-readable class name
unique, counts = np.unique(labels, return_counts=True)
for class_id, n_samples in zip(unique, counts):
    print(f"  {label_names[class_id]}: {n_samples} samples")

In [None]:
class HFImageDataset(Dataset):
    """PyTorch ``Dataset`` view over a HuggingFace-style dataset of images.

    Each underlying record must expose an ``"image"`` entry with a
    ``convert`` method and a ``"label"`` entry. Items come back as
    ``(image, label)`` with the image converted to RGB.

    Args:
        hf_dataset: Indexable collection of records.
        transform: Optional callable applied to the RGB image.
        label_map: Optional ``{original_label: new_label}`` dict; labels
            missing from it are returned unchanged.
    """

    # Map dataset labels to binary: 0 = real, 1 = ai_generated
    # Not read by this class; the mapping in use is the `label_map` argument.
    LABEL_MAP = None

    def __init__(self, hf_dataset, transform=None, label_map=None):
        self.dataset = hf_dataset
        self.transform = transform
        # A missing (or empty) mapping means "pass labels through untouched".
        self.label_map = label_map if label_map else {}

    def __len__(self):
        return len(self.dataset)

    def __getitem__(self, idx):
        record = self.dataset[idx]
        rgb_image = record["image"].convert("RGB")
        original_label = record["label"]
        target = self.label_map.get(original_label, original_label)

        if not self.transform:
            return rgb_image, target
        return self.transform(rgb_image), target


# Auto-detect label mapping: real=0, ai/fake=1
# Matching is by lowercase substring. The "real" keywords are tried first, and a
# class name matching neither list keeps its original index.
_REAL_KEYWORDS = ("real", "authentic", "genuine", "natural")
_FAKE_KEYWORDS = ("fake", "ai", "generated", "synthetic", "art")


def _binary_target(index, lowered_name):
    """Return 0 for a real-looking name, 1 for an AI-looking one, else `index`."""
    if any(kw in lowered_name for kw in _REAL_KEYWORDS):
        return 0  # real
    if any(kw in lowered_name for kw in _FAKE_KEYWORDS):
        return 1  # ai_generated
    return index  # keep original


label_names_lower = [class_name.lower() for class_name in label_names]
label_map = {
    index: _binary_target(index, lowered)
    for index, lowered in enumerate(label_names_lower)
}

print(f"Label mapping: {label_map}")
for index, class_name in enumerate(label_names):
    target = label_map[index]
    description = "AI-generated" if target == 1 else "Real"
    print(f"  {class_name} (original={index}) -> {target} ({description})")

In [None]:
# Balance and cap dataset
if CONFIG["max_samples_per_class"]:
    # Bucket row indices by their mapped (binary) class.
    # The label column is read once. Indexing `raw_dataset[idx]` row by row
    # fetches the full record -- image included -- only to read one integer,
    # which is presumably far slower on an image dataset (NOTE(review): confirm
    # against the datasets docs on image decoding).
    indices_by_class = {0: [], 1: []}
    for idx, original_label in enumerate(raw_dataset["label"]):
        mapped_label = label_map[original_label]
        if mapped_label not in indices_by_class:
            # Previously surfaced as a bare KeyError; say what to fix instead.
            raise ValueError(
                f"Label {original_label} maps to {mapped_label}, expected 0 (real) or 1 "
                "(AI-generated); adjust the label mapping for this dataset."
            )
        indices_by_class[mapped_label].append(idx)

    # Keep at most `cap` randomly chosen rows per class.
    cap = CONFIG["max_samples_per_class"]
    selected = []
    for cls, indices in indices_by_class.items():
        random.shuffle(indices)
        selected.extend(indices[:cap])
        print(f"  Class {cls}: {min(len(indices), cap)} samples (from {len(indices)})")

    # Shuffle once more so the two classes are interleaved in the kept subset.
    random.shuffle(selected)
    raw_dataset = raw_dataset.select(selected)
    print(f"Balanced dataset: {len(raw_dataset)} samples")

In [None]:
# Transforms
IMG_SIZE = CONFIG["image_size"]

# Per-channel mean/std handed to Normalize in both pipelines.
_NORM_MEAN = [0.485, 0.456, 0.406]
_NORM_STD = [0.229, 0.224, 0.225]

# Training pipeline: resize, light random augmentation, tensor conversion, normalisation.
_train_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(10),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
train_transform = transforms.Compose(_train_steps)

# Validation/test pipeline: the same preprocessing without the random steps.
_eval_steps = [
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(_NORM_MEAN, _NORM_STD),
]
val_transform = transforms.Compose(_eval_steps)

# Split dataset
# The wrapper is created without a transform so each split can get its own afterwards.
full_dataset = HFImageDataset(raw_dataset, transform=None, label_map=label_map)
total = len(full_dataset)
val_size, test_size = (
    int(total * CONFIG[ratio_key]) for ratio_key in ("val_ratio", "test_ratio")
)
train_size = total - val_size - test_size  # whatever is left goes to training

# A dedicated, seeded generator keeps the partition the same from run to run.
split_rng = torch.Generator().manual_seed(SEED)
train_dataset, val_dataset, test_dataset = random_split(
    full_dataset,
    [train_size, val_size, test_size],
    generator=split_rng,
)

# Apply different transforms
class TransformSubset(Dataset):
    """Dataset wrapper that applies `transform` to PIL images taken from `subset`.

    Args:
        subset: Indexable dataset yielding ``(image, label)`` pairs.
        transform: Callable applied to the image, or a falsy value to skip.
    """

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        sample, target = self.subset[idx]
        # Only PIL images are transformed; anything else is passed through as-is.
        should_transform = isinstance(sample, Image.Image) and self.transform
        if should_transform:
            sample = self.transform(sample)
        return sample, target

# Pair each split with its pipeline: augmentation for training, plain
# preprocessing for validation and test.
train_ds = TransformSubset(train_dataset, train_transform)
val_ds, test_ds = (
    TransformSubset(part, val_transform) for part in (val_dataset, test_dataset)
)

print(f"Train: {len(train_ds)} | Val: {len(val_ds)} | Test: {len(test_ds)}")

# Dataloaders
# Only the training loader shuffles; every other option is shared.
_loader_options = dict(
    batch_size=CONFIG["batch_size"],
    num_workers=CONFIG["num_workers"],
    pin_memory=True,
)
train_loader = DataLoader(train_ds, shuffle=True, **_loader_options)
val_loader = DataLoader(val_ds, shuffle=False, **_loader_options)
test_loader = DataLoader(test_ds, shuffle=False, **_loader_options)

In [None]:
# Visualize samples
# A 2 x 8 grid: row 0 is filled with label-0 (real) images, row 1 with label-1 (AI) images.
fig, axes = plt.subplots(2, 8, figsize=(20, 6))
fig.suptitle("Sample Images (Top: Real, Bottom: AI-Generated)", fontsize=14)

# Undo the Normalize step in two stages: first multiply by std, then add the mean.
inv_normalize = transforms.Compose([
    transforms.Normalize(mean=[0, 0, 0], std=[1/0.229, 1/0.224, 1/0.225]),
    transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1, 1, 1]),
])


def _draw_sample(ax, picture):
    """Show one image on `ax` with the axis decorations hidden."""
    ax.imshow(picture)
    ax.axis("off")


real_shown, fake_shown = 0, 0
for batch_images, batch_labels in train_loader:
    for tensor, target in zip(batch_images, batch_labels):
        # CHW tensor -> HWC array clipped to the displayable [0, 1] range.
        picture = inv_normalize(tensor).permute(1, 2, 0).clamp(0, 1).numpy()
        class_id = target.item()
        if class_id == 0 and real_shown < 8:
            _draw_sample(axes[0][real_shown], picture)
            real_shown += 1
        elif class_id == 1 and fake_shown < 8:
            _draw_sample(axes[1][fake_shown], picture)
            fake_shown += 1
        if min(real_shown, fake_shown) >= 8:
            break
    else:
        # Batch exhausted with a row still incomplete: fetch the next batch.
        continue
    break  # both rows are full

plt.tight_layout()
plt.savefig(f"{CONFIG['output_dir']}/sample_images.png", dpi=150, bbox_inches="tight")
plt.show()

## 4. Model

EfficientNet-B0 backbone with a custom binary classification head.

In [None]:
class AIImageDetector(nn.Module):
    """Binary classifier for AI-generated image detection.

    A timm backbone produces a feature vector that a small MLP head turns
    into a single logit per image.

    Args:
        model_name: timm model identifier for the backbone.
        pretrained: Forwarded to ``timm.create_model``.
        dropout: Dropout probability used inside the head.
    """

    def __init__(self, model_name="efficientnet_b0", pretrained=True, dropout=0.3):
        super().__init__()
        # num_classes=0 requests the backbone without its own classification layer.
        self.backbone = timm.create_model(model_name, pretrained=pretrained, num_classes=0)
        self.feature_dim = self.backbone.num_features

        hidden_units = 512
        head_layers = [
            nn.Linear(self.feature_dim, hidden_units),
            nn.BatchNorm1d(hidden_units),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(hidden_units, 1),  # one logit per image
        ]
        self.classifier = nn.Sequential(*head_layers)

    def forward(self, x):
        """Return the head's output for the backbone features of `x`."""
        return self.classifier(self.backbone(x))


# Build the detector from CONFIG and move it to the selected device.
model = AIImageDetector(
    **{key: CONFIG[key] for key in ("model_name", "pretrained", "dropout")}
).to(device)

# Tally all parameters and the trainable ones in a single pass.
total_params = 0
trainable_params = 0
for param in model.parameters():
    size = param.numel()
    total_params += size
    if param.requires_grad:
        trainable_params += size

print(f"Total parameters: {total_params:,}")
print(f"Trainable parameters: {trainable_params:,}")
print(f"Feature dimension: {model.feature_dim}")

## 5. Training

In [None]:
# Loss on raw logits, AdamW over every model parameter, and a cosine LR
# schedule whose period is the configured number of epochs.
criterion = nn.BCEWithLogitsLoss()

adamw_settings = {
    "lr": CONFIG["learning_rate"],
    "weight_decay": CONFIG["weight_decay"],
}
optimizer = optim.AdamW(model.parameters(), **adamw_settings)

scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG["epochs"])


def train_epoch(model, loader, criterion, optimizer, device):
    """Run one optimisation pass over `loader`.

    Args:
        model: Network producing one logit per sample, shaped like the targets
            after ``unsqueeze(1)``.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, targets)``.
        optimizer: Optimizer stepped once per batch.
        device: Device the batches are moved to.

    Returns:
        ``(mean_loss, accuracy)`` averaged over every sample seen.
    """
    model.train()
    running_loss = 0.0
    n_correct = 0
    n_seen = 0

    for batch, targets in tqdm(loader, desc="Train", leave=False):
        batch = batch.to(device)
        # Targets become a float column vector to line up with the logits.
        targets = targets.float().unsqueeze(1).to(device)

        optimizer.zero_grad()
        logits = model(batch)
        batch_loss = criterion(logits, targets)
        batch_loss.backward()
        optimizer.step()

        # Weight the batch loss by its size so the final mean is per sample.
        running_loss += batch_loss.item() * batch.size(0)
        # A logit above 0 corresponds to a sigmoid probability above 0.5.
        hits = (logits > 0.0).float() == targets
        n_correct += hits.sum().item()
        n_seen += targets.size(0)

    return running_loss / n_seen, n_correct / n_seen


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate `model` on `loader` without tracking gradients.

    Args:
        model: Network producing one logit per sample.
        loader: Iterable of ``(images, labels)`` batches.
        criterion: Loss called as ``criterion(logits, targets)``.
        device: Device the batches are moved to.

    Returns:
        A 4-tuple ``(metrics, labels, preds, probs)``: ``metrics`` is a dict
        with ``loss``, ``accuracy``, ``precision``, ``recall`` and ``f1``; the
        other three are NumPy arrays with one entry per evaluated sample
        (true labels, thresholded predictions, sigmoid probabilities).
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    all_preds, all_labels, all_probs = [], [], []

    for images, labels in tqdm(loader, desc="Eval", leave=False):
        images = images.to(device)
        labels_gpu = labels.float().unsqueeze(1).to(device)

        outputs = model(images)
        loss = criterion(outputs, labels_gpu)

        # Weight the batch loss by its size so the final mean is per sample.
        total_loss += loss.item() * images.size(0)
        probs = torch.sigmoid(outputs).cpu()
        # A logit above 0 corresponds to a probability above 0.5.
        predicted = (outputs > 0.0).float().cpu()
        correct += (predicted == labels_gpu.cpu()).sum().item()
        total += labels.size(0)

        # flatten() always yields a 1-D tensor. The earlier squeeze() collapsed
        # a final batch of one sample to a 0-d tensor, and extending a list
        # with the resulting 0-d array raises TypeError.
        all_preds.extend(predicted.flatten().numpy())
        all_labels.extend(labels.numpy())
        all_probs.extend(probs.flatten().numpy())

    metrics = {
        "loss": total_loss / total,
        "accuracy": correct / total,
        "precision": precision_score(all_labels, all_preds, zero_division=0),
        "recall": recall_score(all_labels, all_preds, zero_division=0),
        "f1": f1_score(all_labels, all_preds, zero_division=0),
    }
    return metrics, np.array(all_labels), np.array(all_preds), np.array(all_probs)

In [None]:
# Training loop with early stopping
# Start below any attainable F1 so the first epoch always writes a checkpoint.
# With the old 0.0 start and the strict ">" test, a run whose F1 never rose
# above 0 saved nothing, and the evaluation cell then had no file to load.
best_val_f1 = float("-inf")
patience_counter = 0
history = {"train_loss": [], "train_acc": [], "val_loss": [], "val_acc": [], "val_f1": [], "lr": []}

print(f"Training for {CONFIG['epochs']} epochs...")
print(f"{'Epoch':>5} | {'Train Loss':>10} | {'Train Acc':>9} | {'Val Loss':>8} | {'Val Acc':>7} | {'Val F1':>6} | {'LR':>10}")
print("-" * 75)

for epoch in range(CONFIG["epochs"]):
    t0 = time.time()

    # Read the LR before scheduler.step() so the value logged for this epoch is
    # the one it actually trained with, not the next epoch's.
    lr = scheduler.get_last_lr()[0]

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_metrics, _, _, _ = evaluate(model, val_loader, criterion, device)
    scheduler.step()

    # Store plain Python floats so the history and the checkpoint contain only
    # basic types, whatever scalar type the metric functions returned.
    val_metrics = {name: float(value) for name, value in val_metrics.items()}

    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["val_loss"].append(val_metrics["loss"])
    history["val_acc"].append(val_metrics["accuracy"])
    history["val_f1"].append(val_metrics["f1"])
    history["lr"].append(lr)

    elapsed = time.time() - t0
    marker = ""

    if val_metrics["f1"] > best_val_f1:
        # New best validation F1: reset patience and overwrite the saved checkpoint.
        best_val_f1 = val_metrics["f1"]
        patience_counter = 0
        marker = " << best"

        # Save checkpoint
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "config": CONFIG,
            "metrics": val_metrics,
            "history": history,
        }
        torch.save(checkpoint, f"{CONFIG['output_dir']}/{CONFIG['model_filename']}")
    else:
        patience_counter += 1

    print(f"{epoch+1:>5} | {train_loss:>10.4f} | {train_acc:>8.1%} | {val_metrics['loss']:>8.4f} | {val_metrics['accuracy']:>6.1%} | {val_metrics['f1']:>6.3f} | {lr:>10.2e} | {elapsed:.0f}s{marker}")

    # Stop once validation F1 has failed to improve for `patience` epochs in a row.
    if patience_counter >= CONFIG["patience"]:
        print(f"\nEarly stopping at epoch {epoch+1} (no improvement for {CONFIG['patience']} epochs)")
        break

print(f"\nBest validation F1: {best_val_f1:.4f}")

## 6. Evaluation

In [None]:
# Load best checkpoint
# The file holds more than tensors (config, metrics, history). Metric values
# may be NumPy scalars, which the restricted unpickler behind the
# weights_only=True default of newer PyTorch releases can refuse. The full
# unpickler is acceptable here only because this notebook wrote the file
# itself; never use weights_only=False on a checkpoint from an untrusted source.
checkpoint_path = f"{CONFIG['output_dir']}/{CONFIG['model_filename']}"
checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
model.load_state_dict(checkpoint["model_state_dict"])
print(f"Loaded best model from epoch {checkpoint['epoch'] + 1}")

# Evaluate on test set
test_metrics, y_true, y_pred, y_probs = evaluate(model, test_loader, criterion, device)

print("\n" + "=" * 50)
print("TEST SET RESULTS")
print("=" * 50)
for k, v in test_metrics.items():
    print(f"  {k:>12}: {v:.4f}")
print("\nClassification Report:")
print(classification_report(y_true, y_pred, target_names=["Real", "AI-Generated"]))

In [None]:
# Training curves
# One panel per entry: (title, [(history key, legend label, extra plot kwargs), ...]).
fig, axes = plt.subplots(1, 3, figsize=(18, 5))

panel_specs = [
    ("Loss", [("train_loss", "Train", {}), ("val_loss", "Validation", {})]),
    ("Accuracy", [("train_acc", "Train", {}), ("val_acc", "Validation", {})]),
    ("Validation F1 Score", [("val_f1", "Val F1", {"color": "green"})]),
]

for ax, (title, series) in zip(axes, panel_specs):
    for history_key, legend_label, extra_kwargs in series:
        ax.plot(history[history_key], label=legend_label, linewidth=2, **extra_kwargs)
    ax.set_title(title, fontsize=13)
    ax.set_xlabel("Epoch")
    ax.legend()
    ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{CONFIG['output_dir']}/training_curves.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# Confusion matrix + ROC curve
fig, axes = plt.subplots(1, 2, figsize=(14, 5))
cm_ax, roc_ax = axes

# Confusion matrix
cm = confusion_matrix(y_true, y_pred)
im = cm_ax.imshow(cm, interpolation="nearest", cmap="Blues")
cm_ax.set_title("Confusion Matrix", fontsize=13)
cm_ax.set_xlabel("Predicted")
cm_ax.set_ylabel("True")
tick_positions = [0, 1]
tick_names = ["Real", "AI"]
cm_ax.set_xticks(tick_positions)
cm_ax.set_yticks(tick_positions)
cm_ax.set_xticklabels(tick_names)
cm_ax.set_yticklabels(tick_names)
# Annotate every cell, switching to white text on the darker (larger-count) cells.
half_of_max = cm.max() / 2
for row, col in np.ndindex(2, 2):
    cell_count = cm[row, col]
    text_color = "white" if cell_count > half_of_max else "black"
    cm_ax.text(col, row, str(cell_count), ha="center", va="center", color=text_color, fontsize=16)
fig.colorbar(im, ax=cm_ax)

# ROC curve
fpr, tpr, _ = roc_curve(y_true, y_probs)
roc_auc = auc(fpr, tpr)
roc_ax.plot(fpr, tpr, linewidth=2, label=f"ROC (AUC = {roc_auc:.3f})")
roc_ax.plot([0, 1], [0, 1], "k--", alpha=0.5)  # diagonal drawn for reference
roc_ax.set_title("ROC Curve", fontsize=13)
roc_ax.set_xlabel("False Positive Rate")
roc_ax.set_ylabel("True Positive Rate")
roc_ax.legend(fontsize=12)
roc_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{CONFIG['output_dir']}/evaluation.png", dpi=150, bbox_inches="tight")
plt.show()

## 7. Export and Deploy

Save the final metrics and, optionally, upload the trained model to Google Cloud Storage for automated deployment.

In [None]:
# Save metrics
# Both artefacts live in the configured output directory.
output_dir = CONFIG["output_dir"]
results_path = f"{output_dir}/results.json"
model_path = f"{output_dir}/{CONFIG['model_filename']}"

results = {
    "config": CONFIG,
    "test_metrics": test_metrics,
    "best_epoch": int(checkpoint["epoch"]) + 1,  # stored 0-based, reported 1-based
    "total_params": total_params,
    "history": history,
}
with open(results_path, "w") as results_file:
    json.dump(results, results_file, indent=2)

bytes_per_mb = 1024 * 1024
model_size_mb = os.path.getsize(model_path) / bytes_per_mb

print(f"Model saved: {model_path} ({model_size_mb:.1f} MB)")
print(f"Results saved: {results_path}")
print(f"\nTo use this model in the API, copy {CONFIG['model_filename']} to models/checkpoints/")

In [None]:
# (Optional) Upload to GCS for automated Cloud Run deployment
# To enable: set UPLOAD_TO_GCS to True and point GCS_BUCKET at your own bucket.

UPLOAD_TO_GCS = False  # Set to True to upload
GCS_BUCKET = "ai-product-detector-487013"  # Your GCS bucket
GCS_MODEL_PATH = "models/best_model.pt"

if not UPLOAD_TO_GCS:
    # Upload switched off: print the manual deployment steps instead.
    manual_steps = (
        f"  1. Download {CONFIG['model_filename']} from this notebook",
        "  2. Place it at models/checkpoints/best_model.pt in the repo",
        "  3. Push to main -- CD will build and deploy automatically",
    )
    print("GCS upload disabled. To deploy:")
    for step in manual_steps:
        print(step)
else:
    # Imported lazily so the client library is only touched when uploading.
    from google.cloud import storage

    # Authenticate (in Colab: uncomment the two lines below to use the auth widget)
    # from google.colab import auth
    # auth.authenticate_user()

    destination = storage.Client().bucket(GCS_BUCKET).blob(GCS_MODEL_PATH)
    destination.upload_from_filename(model_path)
    print(f"Uploaded to gs://{GCS_BUCKET}/{GCS_MODEL_PATH}")
    print("The CD pipeline will automatically pick up this model on next deploy.")

In [None]:
# Download the model file (Colab)
# Only the import sits inside the `try`: an ImportError raised while the
# download itself runs should surface as an error rather than be reported as
# "Not running in Colab".
try:
    from google.colab import files
except ImportError:
    print(f"Not running in Colab. Model is at: {model_path}")
else:
    files.download(model_path)
    print("Download started!")