# 🔍 AI Product Photo Detector — Training Notebook

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/nolancacheux/AI-Product-Photo-Detector/blob/main/notebooks/train_colab.ipynb)

Train an **EfficientNet-B0** binary classifier to distinguish **real product photos** from **AI-generated images**.

## 📋 What this notebook does:
1. **Setup** — Install dependencies, authenticate with Google Cloud
2. **Data** — Download the processed dataset from the GCS bucket
3. **Training** — Train EfficientNet-B0 with a training loop that mirrors the project's training pipeline
4. **Evaluation** — Compute accuracy, precision, recall, F1, and ROC AUC on the test set
5. **Visualization** — Training curves, confusion matrix, ROC curve, Grad-CAM heatmaps
6. **Export** — Save the model checkpoint, metrics, and plots to GCS

---

**⚠️ Requirements:** A GPU runtime is recommended. Go to `Runtime > Change runtime type > T4 GPU`.

## 1. 🛠️ Setup

Clone the repository, install dependencies, and authenticate with Google Cloud.

In [None]:
# Clone the repository (shallow: latest commit only) and make it the working
# directory, so the relative paths used later in the notebook (configs/, data/)
# resolve inside the checkout.
# NOTE(review): re-running this cell after the %cd clones a nested copy of the
# repo inside the first one and changes into it.
!git clone --depth 1 https://github.com/nolancacheux/AI-Product-Photo-Detector.git
%cd AI-Product-Photo-Detector

In [None]:
# Install dependencies quietly (-q): torch/torchvision/timm for the model,
# scikit-learn for metrics, google-cloud-storage for the GCS transfers, and
# grad-cam for the Grad-CAM section (imported as `pytorch_grad_cam`).
# Versions are not pinned, so results may differ across Colab runtimes.
!pip install -q torch torchvision timm pillow scikit-learn matplotlib tqdm \
    google-cloud-storage pyyaml structlog grad-cam numpy pandas

In [None]:
# Authenticate with Google Cloud via Colab's interactive prompt. The later
# `storage.Client(project=...)` calls pass no explicit credentials, so they
# presumably rely on this login.
import google.colab.auth as auth

auth.authenticate_user()
print("‚úÖ GCS authentication successful")

In [None]:
import os
import sys
import json
import time
import random
from pathlib import Path

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
import timm
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score,
    confusion_matrix, classification_report, roc_curve, auc
)
from tqdm.auto import tqdm
from PIL import Image
from google.cloud import storage
import yaml

# Reproducibility: seed every RNG the notebook touches (Python, NumPy, PyTorch
# CPU and all CUDA devices, in that order) and force deterministic cuDNN
# kernels instead of autotuned ones.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)
del _seed_fn
torch.backends.cudnn.deterministic, torch.backends.cudnn.benchmark = True, False

# Pick the compute device and report what was found.
_has_cuda = torch.cuda.is_available()
device = torch.device("cuda" if _has_cuda else "cpu")
print(f"üñ•Ô∏è Device: {device}")
if _has_cuda:
    _gpu_props = torch.cuda.get_device_properties(0)
    print(f"   GPU: {torch.cuda.get_device_name(0)}")
    print(f"   Memory: {_gpu_props.total_memory / 1e9:.1f} GB")

## 2. ⚙️ Configuration

Load training configuration from `configs/train_config.yaml` with Colab-specific overrides.

In [None]:
# GCP Configuration
GCP_PROJECT = "ai-product-detector-487013"
GCS_BUCKET = "ai-product-detector-487013"
GCS_DATA_PREFIX = "data/processed"
GCS_MODEL_PATH = "models/best_model.pt"

# Load base config from repo.
# `yaml.safe_load` returns None for an empty file, and a section written as
# `data:` with no children also parses to None. Normalize both to {} so the
# `.get(key, default)` lookups below fall back to the defaults instead of
# raising AttributeError.
with open("configs/train_config.yaml", "r", encoding="utf-8") as f:
    config = yaml.safe_load(f) or {}

data_cfg = config.get("data") or {}
model_cfg = config.get("model") or {}
train_cfg = config.get("training") or {}

# Colab-specific overrides: hyperparameters come from the YAML when present,
# otherwise from the defaults given here. Data paths point at the local copy
# downloaded from GCS.
CONFIG = {
    # Data
    "train_dir": "data/processed/train",
    "val_dir": "data/processed/val",
    "test_dir": "data/processed/test",
    "image_size": data_cfg.get("image_size", 224),
    "batch_size": data_cfg.get("batch_size", 64),
    "num_workers": 2,  # Colab limitation

    # Model
    "model_name": model_cfg.get("name", "efficientnet_b0"),
    "pretrained": model_cfg.get("pretrained", True),
    "dropout": model_cfg.get("dropout", 0.3),

    # Training
    "epochs": train_cfg.get("epochs", 15),
    "learning_rate": train_cfg.get("learning_rate", 0.001),
    "weight_decay": train_cfg.get("weight_decay", 0.0001),
    "patience": train_cfg.get("early_stopping_patience", 5),

    # Output
    "output_dir": "./training_output",
    "model_filename": "best_model.pt",
}

os.makedirs(CONFIG["output_dir"], exist_ok=True)
print("üìã Configuration:")
print(json.dumps(CONFIG, indent=2))

## 3. 📦 Data

Download processed dataset from GCS bucket.

In [None]:
def download_from_gcs(bucket_name: str, gcs_prefix: str, local_dir: str) -> int:
    """Download every object inside a GCS "folder" to a local directory.

    Object paths relative to the folder are preserved, e.g.
    ``<gcs_prefix>/train/real/x.jpg`` -> ``<local_dir>/train/real/x.jpg``.

    Args:
        bucket_name: GCS bucket name.
        gcs_prefix: Folder-like prefix in the bucket, with or without a trailing
            "/". Only objects inside that folder are downloaded. Siblings that
            merely share the string prefix (e.g. "data/processed_old/..." for
            "data/processed") are skipped. An empty prefix means the whole bucket.
        local_dir: Local directory to download to (subdirectories are created).

    Returns:
        Number of files downloaded.
    """
    # `list_blobs(prefix=...)` is plain string-prefix matching, so anchor the
    # prefix at a folder boundary before listing.
    folder_prefix = f"{gcs_prefix.rstrip('/')}/" if gcs_prefix else ""

    client = storage.Client(project=GCP_PROJECT)
    bucket = client.bucket(bucket_name)
    blobs = list(bucket.list_blobs(prefix=folder_prefix))
    root = Path(local_dir)

    downloaded = 0
    for blob in tqdm(blobs, desc=f"Downloading {gcs_prefix}"):
        if blob.name.endswith("/"):
            continue  # Skip zero-byte "directory" placeholder objects

        # lstrip("/") keeps a stray leading slash from turning the relative
        # part into an absolute path (which would escape `root`).
        relative_path = blob.name[len(folder_prefix):].lstrip("/")
        local_path = root / relative_path
        local_path.parent.mkdir(parents=True, exist_ok=True)

        blob.download_to_filename(str(local_path))
        downloaded += 1

    return downloaded


# Download training data
print(f"üì• Downloading data from gs://{GCS_BUCKET}/{GCS_DATA_PREFIX}/...")
n_files = download_from_gcs(GCS_BUCKET, GCS_DATA_PREFIX, "data/processed")
print(f"‚úÖ Downloaded {n_files} files")

In [None]:
# Verify data structure

# Image extensions accepted by AIProductDataset; counting with the same set
# (case-insensitive) makes these statistics match what the dataset loads.
IMAGE_EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")


def count_images(directory: str) -> dict:
    """Count image files directly inside each class subdirectory of `directory`.

    A file counts when its lower-cased suffix is in IMAGE_EXTENSIONS. Nested
    folders and files at the top level of `directory` are ignored.

    Args:
        directory: Split directory, e.g. "data/processed/train".

    Returns:
        Mapping of class-subdirectory name to image count, in sorted name
        order; {} if `directory` does not exist.
    """
    base_path = Path(directory)
    if not base_path.exists():
        return {}
    return {
        class_dir.name: sum(
            1
            for item in class_dir.iterdir()
            if item.is_file() and item.suffix.lower() in IMAGE_EXTENSIONS
        )
        for class_dir in sorted(base_path.iterdir())
        if class_dir.is_dir()
    }


print("\nüìä Dataset Statistics:")
for split in ["train", "val", "test"]:
    counts = count_images(f"data/processed/{split}")
    line = f"  {split:>5}: {sum(counts.values()):,} images"
    if counts:
        line += f" ({', '.join(f'{k}: {v}' for k, v in counts.items())})"
    print(line)

In [None]:
# Dataset class (matches src/training/dataset.py)
class AIProductDataset(Dataset):
    """Dataset for AI vs Real product image classification.

    Directory structure:
        data_dir/
        ├── real/
        │   └── *.jpg
        └── ai_generated/
            └── *.jpg

    Labels: 0 = Real, 1 = AI-generated

    Files are selected by extension (.jpg, .jpeg, .png, .webp; case-insensitive)
    and sorted by path. The sample order, and therefore seeded shuffling, is
    then the same on every machine and filesystem.
    """

    # Accepted image extensions (compared against the lower-cased suffix).
    EXTENSIONS = (".jpg", ".jpeg", ".png", ".webp")
    # (class subdirectory, label) pairs; their order defines sample order.
    CLASS_DIRS = (("real", 0), ("ai_generated", 1))

    def __init__(self, data_dir: str, transform=None, image_size: int = 224):
        """
        Args:
            data_dir: Split directory containing `real/` and `ai_generated/`.
            transform: Callable applied to each RGB PIL image. When None, a
                resize + ImageNet-normalization pipeline is used.
            image_size: Square output size of the default transform only.
        """
        self.data_dir = Path(data_dir)
        self.image_size = image_size
        self.transform = transform or self._default_transform()
        self.samples = []
        self._load_samples()

    def _default_transform(self):
        """Resize to (image_size, image_size), convert to tensor, normalize."""
        return transforms.Compose([
            transforms.Resize((self.image_size, self.image_size)),
            transforms.ToTensor(),
            transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]),
        ])

    def _load_samples(self):
        """Fill self.samples with sorted (path, label) pairs from the class folders."""
        for subdir, label in self.CLASS_DIRS:
            class_dir = self.data_dir / subdir
            if not class_dir.is_dir():
                continue
            for img_path in sorted(class_dir.iterdir()):
                if img_path.is_file() and img_path.suffix.lower() in self.EXTENSIONS:
                    self.samples.append((img_path, label))

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        """Return (image, label) for sample `idx`.

        If a sample cannot be opened or transformed, the error is printed and
        the following samples are tried in turn (wrapping around). The returned
        label always belongs to the image actually returned.

        Raises:
            IndexError: if `idx` is out of range.
            RuntimeError: if no sample in the dataset can be loaded. This
                replaces unbounded recursion.
        """
        self.samples[idx]  # IndexError for out-of-range idx, as a sequence should
        n = len(self.samples)
        for offset in range(n):
            img_path, label = self.samples[(idx + offset) % n]
            try:
                # Context manager closes the file; convert() returns an
                # independent, fully loaded copy.
                with Image.open(img_path) as img:
                    image = img.convert("RGB")
                if self.transform:
                    image = self.transform(image)
                return image, label
            except Exception as e:
                print(f"Error loading {img_path}: {e}")
        raise RuntimeError(f"No loadable images found in {self.data_dir}")

In [None]:
# Transforms (matches src/training/augmentation.py)
IMG_SIZE = CONFIG["image_size"]
IMAGENET_MEAN = [0.485, 0.456, 0.406]
IMAGENET_STD = [0.229, 0.224, 0.225]

train_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.RandomRotation(15),
    transforms.ColorJitter(brightness=0.2, contrast=0.2, saturation=0.1, hue=0.05),
    transforms.RandomResizedCrop(IMG_SIZE, scale=(0.8, 1.0)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

val_transform = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Create datasets
train_dataset = AIProductDataset(CONFIG["train_dir"], transform=train_transform)
val_dataset = AIProductDataset(CONFIG["val_dir"], transform=val_transform)
test_dataset = AIProductDataset(CONFIG["test_dir"], transform=val_transform)

print(f"\nüìä Datasets loaded:")
print(f"   Train: {len(train_dataset):,} samples")
print(f"   Val:   {len(val_dataset):,} samples")
print(f"   Test:  {len(test_dataset):,} samples")

# Pinned host memory only speeds up host->GPU copies; on a CPU-only runtime it
# has no benefit and PyTorch warns about it, so enable it only with CUDA.
PIN_MEMORY = torch.cuda.is_available()


def _make_loader(dataset, *, shuffle, drop_last=False):
    """Build a DataLoader with the notebook's shared batch/worker settings."""
    return DataLoader(
        dataset,
        batch_size=CONFIG["batch_size"],
        shuffle=shuffle,
        num_workers=CONFIG["num_workers"],
        pin_memory=PIN_MEMORY,
        drop_last=drop_last,
    )


# Create dataloaders. drop_last on the training loader avoids a trailing batch
# of size 1, which the classifier's BatchNorm1d cannot process in train mode.
train_loader = _make_loader(train_dataset, shuffle=True, drop_last=True)
val_loader = _make_loader(val_dataset, shuffle=False)
test_loader = _make_loader(test_dataset, shuffle=False)

In [None]:
# Visualize sample images
def show_samples(loader, n_per_class=4):
    """Plot up to `n_per_class` images per class from `loader` and save the figure.

    Row 0 shows label-0 (Real) images and row 1 shows label-1 (AI-generated)
    images. Tensors are un-normalized with the same mean/std the transforms
    use. Iteration stops once both rows are full. If a class has fewer images
    than `n_per_class`, its remaining panels stay blank. The figure is saved
    to `<output_dir>/sample_images.png`.
    """
    # squeeze=False keeps `axes` 2-D even when n_per_class == 1.
    fig, axes = plt.subplots(2, n_per_class, figsize=(n_per_class * 3, 6), squeeze=False)
    fig.suptitle("Sample Images\nTop: Real | Bottom: AI-Generated", fontsize=14)
    for ax in axes.flat:
        ax.axis("off")  # panels that never receive an image stay blank

    # Undo Normalize(mean, std): first x * std, then + mean.
    inv_normalize = transforms.Compose([
        transforms.Normalize(mean=[0, 0, 0], std=[1/0.229, 1/0.224, 1/0.225]),
        transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1, 1, 1]),
    ])

    titles = ("Real", "AI-Generated")
    shown = [0, 0]  # images placed so far in each row, indexed by label

    for images, labels in loader:
        for img, lbl in zip(images, labels):
            row = int(lbl.item())
            if row not in (0, 1) or shown[row] >= n_per_class:
                continue
            ax = axes[row, shown[row]]
            # Un-normalize only the images that are actually displayed.
            ax.imshow(inv_normalize(img).permute(1, 2, 0).clamp(0, 1).numpy())
            ax.set_title(titles[row])
            shown[row] += 1
            if min(shown) >= n_per_class:
                break
        if min(shown) >= n_per_class:
            break

    plt.tight_layout()
    plt.savefig(f"{CONFIG['output_dir']}/sample_images.png", dpi=150, bbox_inches="tight")
    plt.show()


show_samples(train_loader)

## 4. 🧠 Model

EfficientNet-B0 with custom binary classification head (matches `src/training/model.py`).

In [None]:
class AIImageDetector(nn.Module):
    """EfficientNet-based binary classifier for AI image detection.

    This class is redefined here to match src/training/model.py for standalone
    notebook execution in Colab without requiring the full repo in PYTHONPATH.

    Architecture: a timm backbone created with `num_classes=0` (its classifier
    head removed), followed by Linear(feature_dim, 512) -> BatchNorm1d ->
    ReLU -> Dropout -> Linear(512, 1). `forward` returns one raw logit per
    image, for use with BCEWithLogitsLoss. `predict_proba` applies the sigmoid.
    """

    def __init__(
        self,
        model_name: str = "efficientnet_b0",
        pretrained: bool = True,
        dropout: float = 0.3,
        freeze_backbone: bool = False,
    ):
        """
        Args:
            model_name: timm model identifier for the backbone.
            pretrained: Whether timm loads pretrained backbone weights.
            dropout: Dropout probability in the classification head.
            freeze_backbone: If True, backbone parameters get requires_grad=False.
        """
        super().__init__()

        # Load pretrained backbone
        self.backbone = timm.create_model(
            model_name,
            pretrained=pretrained,
            num_classes=0,  # Remove classifier head
        )

        # Freeze backbone if requested
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Probe the backbone's output width with a dummy forward pass. Run it in
        # eval mode: in train mode, BatchNorm layers would fold this dummy batch
        # into their (pretrained) running statistics even under no_grad. A zero
        # input also avoids consuming the global RNG.
        was_training = self.backbone.training
        self.backbone.eval()
        with torch.no_grad():
            feature_dim = self.backbone(torch.zeros(1, 3, 224, 224)).shape[1]
        self.backbone.train(was_training)

        # Classification head (outputs raw logits for BCEWithLogitsLoss)
        self.classifier = nn.Sequential(
            nn.Linear(feature_dim, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(inplace=True),
            nn.Dropout(p=dropout),
            nn.Linear(512, 1),
        )

        self.feature_dim = feature_dim
        self.model_name = model_name

    def forward(self, x):
        """Return raw logits of shape (batch, 1)."""
        features = self.backbone(x)
        return self.classifier(features)

    def predict_proba(self, x):
        """Return sigmoid probabilities of the AI-generated class for batch `x`.

        Runs in eval mode without gradient tracking. The module's previous
        train/eval mode is restored afterwards.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return torch.sigmoid(self.forward(x))
        finally:
            self.train(was_training)

    def get_num_trainable_params(self):
        """Number of parameter elements with requires_grad=True."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)

    def get_num_total_params(self):
        """Total number of parameter elements."""
        return sum(p.numel() for p in self.parameters())


# Create model from the CONFIG hyperparameters and move it to the device.
_model_kwargs = dict(
    model_name=CONFIG["model_name"],
    pretrained=CONFIG["pretrained"],
    dropout=CONFIG["dropout"],
)
model = AIImageDetector(**_model_kwargs)
model = model.to(device)

_n_total = model.get_num_total_params()
_n_trainable = model.get_num_trainable_params()
print(f"\nüß† Model: {CONFIG['model_name']}")
print(f"   Feature dimension: {model.feature_dim}")
print(f"   Total parameters: {_n_total:,}")
print(f"   Trainable parameters: {_n_trainable:,}")

## 5. 🏋️ Training

Training loop with early stopping and checkpointing.

In [None]:
# Loss, optimizer, scheduler.
# BCEWithLogitsLoss consumes the model's single raw-logit output directly.
criterion = nn.BCEWithLogitsLoss()

_adamw_kwargs = {"lr": CONFIG["learning_rate"], "weight_decay": CONFIG["weight_decay"]}
optimizer = optim.AdamW(model.parameters(), **_adamw_kwargs)

# Cosine decay over the configured epoch budget; the training loop steps it
# once per epoch.
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=CONFIG["epochs"])


def train_epoch(model, loader, criterion, optimizer, device):
    """Train `model` for one pass over `loader`.

    Per batch: forward, loss against float labels reshaped to (B, 1), backward,
    gradient-norm clipping to 1.0, then an optimizer step.

    Returns:
        (mean per-sample loss, accuracy) over all samples seen. A logit > 0 is
        counted as a prediction of class 1.

    Raises:
        ValueError: if `loader` yields no batches (e.g. the training set is
            smaller than one batch while drop_last=True).
    """
    model.train()
    total_loss, correct, total = 0.0, 0, 0

    for images, labels in tqdm(loader, desc="Training", leave=False):
        images = images.to(device)
        labels = labels.float().unsqueeze(1).to(device)

        optimizer.zero_grad(set_to_none=True)
        outputs = model(images)
        loss = criterion(outputs, labels)
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optimizer.step()

        # Assumes `criterion` returns the batch mean (the default reduction),
        # so scale by batch size to accumulate a per-sample sum.
        batch_size = labels.size(0)
        total_loss += loss.item() * batch_size
        predicted = (outputs > 0.0).float()
        correct += (predicted == labels).sum().item()
        total += batch_size

    if total == 0:
        raise ValueError(
            "train_epoch: loader yielded no batches "
            "(is the training set smaller than one batch with drop_last=True?)"
        )
    return total_loss / total, correct / total


@torch.no_grad()
def evaluate(model, loader, criterion, device):
    """Evaluate `model` on `loader` without gradient tracking.

    Returns:
        metrics: dict with the mean per-sample "loss", "accuracy", and the
            "precision", "recall", "f1" of the positive class (label 1,
            AI-generated). A logit > 0 is a positive prediction.
        y_true: 1-D array of ground-truth labels.
        y_pred: 1-D float array of hard predictions (0.0 / 1.0).
        y_prob: 1-D array of sigmoid probabilities for class 1.

    Raises:
        ValueError: if `loader` yields no samples.
    """
    model.eval()
    total_loss, correct, total = 0.0, 0, 0
    tp, fp, fn = 0, 0, 0
    pred_chunks, label_chunks, prob_chunks = [], [], []

    for images, labels in tqdm(loader, desc="Evaluating", leave=False):
        images = images.to(device)
        labels_gpu = labels.float().unsqueeze(1).to(device)

        outputs = model(images)
        loss = criterion(outputs, labels_gpu)

        total_loss += loss.item() * images.size(0)
        predicted = (outputs > 0.0).float()
        correct += (predicted == labels_gpu).sum().item()
        total += labels.size(0)

        # Confusion counts for precision/recall of the positive class
        tp += ((predicted == 1) & (labels_gpu == 1)).sum().item()
        fp += ((predicted == 1) & (labels_gpu == 0)).sum().item()
        fn += ((predicted == 0) & (labels_gpu == 1)).sum().item()

        # reshape(-1) rather than squeeze(): squeeze() turns a final batch of
        # size 1 into a 0-d array, which cannot be iterated/extended.
        pred_chunks.append(predicted.reshape(-1).cpu().numpy())
        label_chunks.append(labels.reshape(-1).numpy())
        prob_chunks.append(torch.sigmoid(outputs).reshape(-1).cpu().numpy())

    if total == 0:
        raise ValueError("evaluate: loader yielded no samples")

    # 1e-8 keeps the ratios defined when a denominator is zero
    precision = tp / (tp + fp + 1e-8)
    recall = tp / (tp + fn + 1e-8)
    f1 = 2 * precision * recall / (precision + recall + 1e-8)

    metrics = {
        "loss": total_loss / total,
        "accuracy": correct / total,
        "precision": precision,
        "recall": recall,
        "f1": f1,
    }
    return (
        metrics,
        np.concatenate(label_chunks),
        np.concatenate(pred_chunks),
        np.concatenate(prob_chunks),
    )

In [None]:
# Training loop with early stopping on validation accuracy.
# Starting from -inf (not 0.0) guarantees the first epoch writes a checkpoint,
# so the "load best checkpoint" step always finds a file.
best_val_acc = float("-inf")
patience_counter = 0
history = {
    "train_loss": [], "train_acc": [],
    "val_loss": [], "val_acc": [],
    "val_precision": [], "val_recall": [], "val_f1": [],
    "lr": []
}
checkpoint_path = f"{CONFIG['output_dir']}/{CONFIG['model_filename']}"

print(f"\nüèãÔ∏è Training for {CONFIG['epochs']} epochs...")
header = (
    f"{'Epoch':>5} | {'Train Loss':>10} | {'Train Acc':>9} | {'Val Loss':>8} | "
    f"{'Val Acc':>7} | {'Val F1':>6} | {'LR':>10} | {'Time':>5}"
)
print(header)
print("-" * len(header))

for epoch in range(CONFIG["epochs"]):
    t0 = time.time()

    # LR used for this epoch's updates. Read it before scheduler.step(), which
    # advances the schedule to the next epoch's value.
    lr = scheduler.get_last_lr()[0]

    train_loss, train_acc = train_epoch(model, train_loader, criterion, optimizer, device)
    val_metrics, _, _, _ = evaluate(model, val_loader, criterion, device)
    scheduler.step()

    history["train_loss"].append(train_loss)
    history["train_acc"].append(train_acc)
    history["val_loss"].append(val_metrics["loss"])
    history["val_acc"].append(val_metrics["accuracy"])
    history["val_precision"].append(val_metrics["precision"])
    history["val_recall"].append(val_metrics["recall"])
    history["val_f1"].append(val_metrics["f1"])
    history["lr"].append(lr)

    elapsed = time.time() - t0
    marker = ""

    if val_metrics["accuracy"] > best_val_acc:
        best_val_acc = val_metrics["accuracy"]
        patience_counter = 0
        marker = " ‚≠ê best"

        # Save checkpoint (history is serialized as of this epoch)
        checkpoint = {
            "epoch": epoch,
            "model_state_dict": model.state_dict(),
            "optimizer_state_dict": optimizer.state_dict(),
            "scheduler_state_dict": scheduler.state_dict(),
            "val_accuracy": val_metrics["accuracy"],
            "best_val_accuracy": best_val_acc,
            "config": CONFIG,
            "history": history,
        }
        torch.save(checkpoint, checkpoint_path)
    else:
        patience_counter += 1

    # Column widths match the header above.
    print(
        f"{epoch+1:>5} | {train_loss:>10.4f} | {train_acc:>9.1%} | "
        f"{val_metrics['loss']:>8.4f} | {val_metrics['accuracy']:>7.1%} | "
        f"{val_metrics['f1']:>6.3f} | {lr:>10.2e} | {elapsed:>4.0f}s{marker}"
    )

    if patience_counter >= CONFIG["patience"]:
        print(f"\n‚èπÔ∏è Early stopping at epoch {epoch+1} (no improvement for {CONFIG['patience']} epochs)")
        break

print(f"\n‚úÖ Training complete! Best validation accuracy: {best_val_acc:.4f}")

## 6. 📊 Evaluation

Comprehensive evaluation on test set with all metrics.

In [None]:
# Load the best checkpoint written by the training loop, then score it on the
# held-out test split.
_best_ckpt_path = f"{CONFIG['output_dir']}/{CONFIG['model_filename']}"
checkpoint = torch.load(_best_ckpt_path, map_location=device)
model.load_state_dict(checkpoint["model_state_dict"])
print(f"‚úÖ Loaded best model from epoch {checkpoint['epoch'] + 1}")

# Evaluate on test set
test_metrics, y_true, y_pred, y_probs = evaluate(model, test_loader, criterion, device)

_rule = "=" * 60
print("\n" + _rule)
print("üìä TEST SET RESULTS")
print(_rule)
for metric_name, metric_value in test_metrics.items():
    print(f"  {metric_name:>12}: {metric_value:.4f}")

print("\nüìã Classification Report:")
_report = classification_report(
    y_true, y_pred, target_names=["Real", "AI-Generated"], digits=4
)
print(_report)

## 7. 📈 Visualization

Training curves, confusion matrix, ROC curve, and Grad-CAM heatmaps.

In [None]:
# Training curves. The x-axis uses 1-based epoch numbers, matching the
# training log.
epochs_axis = range(1, len(history["train_loss"]) + 1)
fig, axes = plt.subplots(2, 2, figsize=(14, 10))

# (axes, title, y-label or None, [(history key, legend label), ...])
metric_panels = [
    (axes[0, 0], "Loss", "Loss", [("train_loss", "Train"), ("val_loss", "Validation")]),
    (axes[0, 1], "Accuracy", "Accuracy", [("train_acc", "Train"), ("val_acc", "Validation")]),
    (axes[1, 0], "Validation Metrics", None,
     [("val_precision", "Precision"), ("val_recall", "Recall"), ("val_f1", "F1")]),
]
for ax, title, ylabel, series in metric_panels:
    for key, label in series:
        ax.plot(epochs_axis, history[key], label=label, linewidth=2)
    ax.set_title(title, fontsize=13)
    ax.set_xlabel("Epoch")
    if ylabel is not None:
        ax.set_ylabel(ylabel)
    ax.legend()
    ax.grid(True, alpha=0.3)

# Learning rate
lr_ax = axes[1, 1]
lr_ax.plot(epochs_axis, history["lr"], linewidth=2, color="purple")
lr_ax.set_title("Learning Rate Schedule", fontsize=13)
lr_ax.set_xlabel("Epoch")
lr_ax.set_ylabel("Learning Rate")
# A log axis cannot show an LR of exactly 0, which a cosine schedule without
# eta_min reaches at its end; stay linear in that case.
if all(value > 0 for value in history["lr"]):
    lr_ax.set_yscale("log")
lr_ax.grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{CONFIG['output_dir']}/training_curves.png", dpi=150, bbox_inches="tight")
plt.show()

In [None]:
# Confusion matrix + ROC curve
fig, axes = plt.subplots(1, 2, figsize=(14, 5))

# Confusion matrix. labels=[0, 1] pins a 2x2 shape even if the test set, or
# the predictions, contain only one class; the cell loop below indexes both
# rows and columns.
cm = confusion_matrix(y_true, y_pred.astype(int), labels=[0, 1])
im = axes[0].imshow(cm, interpolation="nearest", cmap="Blues")
axes[0].set_title("Confusion Matrix", fontsize=14)
axes[0].set_xlabel("Predicted", fontsize=12)
axes[0].set_ylabel("True", fontsize=12)
axes[0].set_xticks([0, 1])
axes[0].set_yticks([0, 1])
axes[0].set_xticklabels(["Real", "AI-Generated"])
axes[0].set_yticklabels(["Real", "AI-Generated"])
cm_total = cm.sum()
for i in range(2):
    for j in range(2):
        # White text on dark (high-count) cells for contrast
        color = "white" if cm[i, j] > cm.max() / 2 else "black"
        axes[0].text(j, i, f"{cm[i, j]}\n({cm[i, j]/cm_total*100:.1f}%)",
                     ha="center", va="center", color=color, fontsize=14)
fig.colorbar(im, ax=axes[0])

# ROC curve from class-1 probabilities
fpr, tpr, _ = roc_curve(y_true, y_probs)
roc_auc = auc(fpr, tpr)
axes[1].plot(fpr, tpr, linewidth=2, color="darkorange", label=f"ROC (AUC = {roc_auc:.3f})")
axes[1].plot([0, 1], [0, 1], "k--", alpha=0.5, label="Random")
axes[1].fill_between(fpr, tpr, alpha=0.2, color="darkorange")
axes[1].set_title("ROC Curve", fontsize=14)
axes[1].set_xlabel("False Positive Rate", fontsize=12)
axes[1].set_ylabel("True Positive Rate", fontsize=12)
axes[1].legend(fontsize=11, loc="lower right")
axes[1].grid(True, alpha=0.3)

plt.tight_layout()
plt.savefig(f"{CONFIG['output_dir']}/evaluation.png", dpi=150, bbox_inches="tight")
plt.show()

print(f"\nüìä AUC-ROC: {roc_auc:.4f}")

In [None]:
# Grad-CAM visualization (best-effort: failures are reported, not raised)
cam = None
try:
    from pytorch_grad_cam import GradCAM
    from pytorch_grad_cam.utils.image import show_cam_on_image
    from pytorch_grad_cam.utils.model_targets import BinaryClassifierOutputTarget

    # Get the last convolutional layer from backbone
    target_layers = [model.backbone.conv_head]  # EfficientNet's last conv layer
    cam = GradCAM(model=model, target_layers=target_layers)

    # Inverse normalization for visualization: x * std, then + mean
    inv_normalize = transforms.Compose([
        transforms.Normalize(mean=[0, 0, 0], std=[1/0.229, 1/0.224, 1/0.225]),
        transforms.Normalize(mean=[-0.485, -0.456, -0.406], std=[1, 1, 1]),
    ])

    # Every heatmap explains the AI-generated class (label 1)
    targets = [BinaryClassifierOutputTarget(1)]

    # Get sample images
    n_samples = 4
    fig, axes = plt.subplots(2, n_samples, figsize=(n_samples * 4, 8))
    fig.suptitle("Grad-CAM Visualization\nTop: Original | Bottom: Attention Heatmap", fontsize=14)

    shown = 0
    for images, labels in test_loader:
        for img, lbl in zip(images, labels):
            if shown >= n_samples:
                break

            input_tensor = img.unsqueeze(0).to(device)

            with torch.no_grad():
                prob = torch.sigmoid(model(input_tensor)).item()
            pred_is_ai = prob > 0.5
            true_is_ai = lbl.item() == 1
            pred_label = "AI-Generated" if pred_is_ai else "Real"
            true_label = "AI-Generated" if true_is_ai else "Real"

            # Heatmap for the single image in the batch
            grayscale_cam = cam(input_tensor=input_tensor, targets=targets)[0, :]

            rgb_img = inv_normalize(img).permute(1, 2, 0).clamp(0, 1).numpy()
            visualization = show_cam_on_image(rgb_img, grayscale_cam, use_rgb=True)

            axes[0, shown].imshow(rgb_img)
            axes[0, shown].axis("off")
            axes[0, shown].set_title(f"True: {true_label}", fontsize=11)

            axes[1, shown].imshow(visualization)
            axes[1, shown].axis("off")
            correct = "‚úì" if pred_is_ai == true_is_ai else "‚úó"
            axes[1, shown].set_title(f"Pred: {pred_label} ({prob:.2f}) {correct}", fontsize=11)

            shown += 1

        if shown >= n_samples:
            break

    plt.tight_layout()
    plt.savefig(f"{CONFIG['output_dir']}/gradcam_examples.png", dpi=150, bbox_inches="tight")
    plt.show()
    print("\nüî• Grad-CAM visualization saved!")

except ImportError:
    print("‚ö†Ô∏è Grad-CAM not available. Install with: pip install grad-cam")
except Exception as e:
    print(f"‚ö†Ô∏è Grad-CAM visualization failed: {e}")
finally:
    # GradCAM attaches hooks to the target layer through its
    # `activations_and_grads` helper; detach them so they do not stay on
    # `model`. getattr guards against library versions without that attribute.
    _cam_hooks = getattr(cam, "activations_and_grads", None)
    if _cam_hooks is not None:
        _cam_hooks.release()

## 8. 💾 Export & Deploy

Save model and metrics, upload to GCS for automated deployment.

In [None]:
# Save final results: config, test metrics, best epoch, parameter counts,
# per-epoch history, and ROC AUC, all as JSON-serializable Python types.
_out_dir = CONFIG["output_dir"]
results = dict(
    config=CONFIG,
    test_metrics={name: float(value) for name, value in test_metrics.items()},
    best_epoch=int(checkpoint["epoch"]) + 1,
    total_params=model.get_num_total_params(),
    trainable_params=model.get_num_trainable_params(),
    history={key: list(map(float, values)) for key, values in history.items()},
    roc_auc=float(roc_auc),
)

with open(f"{_out_dir}/results.json", "w") as results_file:
    json.dump(results, results_file, indent=2)

model_path = f"{CONFIG['output_dir']}/{CONFIG['model_filename']}"
model_size_mb = os.path.getsize(model_path) / 1024 ** 2

print(f"\nüíæ Artifacts saved:")
print(f"   Model: {model_path} ({model_size_mb:.1f} MB)")
for _artifact_label, _artifact_name in (
    ("Results", "results.json"),
    ("Training curves", "training_curves.png"),
    ("Evaluation", "evaluation.png"),
):
    print(f"   {_artifact_label}: {_out_dir}/{_artifact_name}")

In [None]:
# Upload to GCS
UPLOAD_TO_GCS = True  # Set to False to skip upload


def _upload_artifact(target_bucket, local_file, gcs_path, label):
    """Upload `local_file` to `gcs_path` in `target_bucket` and log the destination."""
    target_bucket.blob(gcs_path).upload_from_filename(local_file)
    print(f"   ‚úÖ {label} ‚Üí gs://{GCS_BUCKET}/{gcs_path}")


if not UPLOAD_TO_GCS:
    print("\n‚è≠Ô∏è GCS upload skipped. To deploy manually:")
    print(f"   1. Download {CONFIG['model_filename']} from this notebook")
    print(f"   2. Place it at models/checkpoints/best_model.pt in the repo")
    print(f"   3. Push to main ‚Äî CD will build and deploy automatically")
else:
    print(f"\n‚òÅÔ∏è Uploading to gs://{GCS_BUCKET}/...")

    client = storage.Client(project=GCP_PROJECT)
    bucket = client.bucket(GCS_BUCKET)

    # Model checkpoint, then the metrics JSON
    _upload_artifact(bucket, model_path, GCS_MODEL_PATH, "Model")
    _upload_artifact(
        bucket, f"{CONFIG['output_dir']}/results.json", "training/results.json", "Results"
    )

    # Plots: upload whichever ones were actually produced
    for viz_file in ["training_curves.png", "evaluation.png", "gradcam_examples.png", "sample_images.png"]:
        viz_local = f"{CONFIG['output_dir']}/{viz_file}"
        if not os.path.exists(viz_local):
            continue
        _upload_artifact(bucket, viz_local, f"training/{viz_file}", viz_file)

    print(f"\nüöÄ Upload complete! Model ready for deployment.")
    print(f"   The CD pipeline will automatically deploy this model.")

In [None]:
# Download model file (for manual deployment)
try:
    from google.colab import files
except ImportError:
    print(f"‚ÑπÔ∏è Not running in Colab. Model is at: {model_path}")
else:
    # Only the import decides "Colab or not". Errors raised by the download
    # itself propagate instead of being misreported as "not in Colab".
    print("üì• Downloading model checkpoint...")
    files.download(model_path)
    print("‚úÖ Download started!")

## 9. 📝 Summary

Training complete! Here's what happened:

In [None]:
print("\n" + "=" * 60)
print("üéâ TRAINING SUMMARY")
print("=" * 60)
print(f"\nüìä Model: {CONFIG['model_name']}")
print(f"   Parameters: {model.get_num_total_params():,}")
print(f"   Size: {model_size_mb:.1f} MB")

print(f"\nüèãÔ∏è Training:")
# Epochs actually run (early stopping can end before CONFIG['epochs']) are
# reported separately from the epoch whose checkpoint was kept.
print(f"   Epochs: {len(history['train_loss'])} / {CONFIG['epochs']}")
print(f"   Best epoch: {checkpoint['epoch'] + 1}")
print(f"   Best validation accuracy: {best_val_acc:.4f}")

print(f"\nüìà Test Results:")
print(f"   Accuracy:  {test_metrics['accuracy']:.4f}")
print(f"   Precision: {test_metrics['precision']:.4f}")
print(f"   Recall:    {test_metrics['recall']:.4f}")
print(f"   F1 Score:  {test_metrics['f1']:.4f}")
print(f"   AUC-ROC:   {roc_auc:.4f}")

print(f"\n‚òÅÔ∏è GCS:")
print(f"   Bucket: gs://{GCS_BUCKET}")
print(f"   Model: gs://{GCS_BUCKET}/{GCS_MODEL_PATH}")

print("\n" + "=" * 60)
print("‚úÖ Done! Model is ready for deployment.")
print("=" * 60)