# Phase 1 — Synthetic Image Detector (EfficientNet)
This notebook trains an **EfficientNet**-based classifier to detect **real vs AI-generated (synthetic)** images using the **CIFAKE** dataset.

**Outputs**
- Val/Test metrics: Accuracy, Precision, Recall, F1 (plus train/val loss curves)
- Confusion Matrix
- Visual explainability: **Grad-CAM** + **Saliency maps**


In [None]:
# ===== 0) Setup =====
# If running on Colab, ensure GPU: Runtime → Change runtime type → GPU

!pip -q install kagglehub==0.3.10 scikit-learn matplotlib tqdm

import os, re, math, random, time
from pathlib import Path

import numpy as np
import matplotlib.pyplot as plt
from tqdm.auto import tqdm

import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, random_split
from torchvision import datasets, transforms, models

from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, confusion_matrix, classification_report
)

# One seed shared by every random number generator the notebook touches.
SEED = 42
for _seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
    _seed_fn(SEED)

# Use the GPU when one is visible to PyTorch, otherwise stay on the CPU.
_backend = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_backend)
device

## 1) Download CIFAKE dataset via kagglehub

In [None]:
import kagglehub

# Fetch the latest CIFAKE snapshot through kagglehub and keep the returned
# location both as the raw value (`path`) and as a pathlib.Path (`DATA_ROOT`).
path = kagglehub.dataset_download("birdy654/cifake-real-and-ai-generated-synthetic-images")
DATA_ROOT = Path(path)
print("Path to dataset files:", path)

DATA_ROOT

## 2) Locate dataset folders

The dataset usually contains `train/` and `test/` folders, each with class subfolders.
This helper tries to locate a valid ImageFolder root automatically.


In [None]:
def find_imagefolder_root(root: Path):
    """Locate a training folder readable by torchvision.datasets.ImageFolder.

    Args:
        root: Directory under which the dataset was downloaded/extracted.

    Returns:
        A ``(train_dir, test_dir)`` tuple. ``train_dir`` is a directory with at
        least two class sub-directories. ``test_dir`` is an existing directory
        named ``test``/``testing`` (case-insensitive) next to ``train_dir`` (or
        directly under ``root``), or ``None`` when no such directory exists.

    Raises:
        FileNotFoundError: If nothing under ``root`` looks like an ImageFolder
            root.
    """
    root = Path(root)
    image_suffixes = {".png", ".jpg", ".jpeg"}
    train_names = {"train", "training"}
    test_names = {"test", "testing"}

    def child_dirs(folder):
        # Sorted so the result does not depend on filesystem listing order.
        return sorted(p for p in folder.iterdir() if p.is_dir())

    def holds_images(folder):
        # Only direct children count: this is what separates a class folder
        # from a split folder (train/test) that merely contains class folders.
        return any(
            p.is_file() and p.suffix.lower() in image_suffixes
            for p in folder.iterdir()
        )

    def find_test_split(train_dir):
        # Prefer a sibling of the train folder, then fall back to the root.
        for parent in dict.fromkeys((train_dir.parent, root)):
            for entry in child_dirs(parent):
                if entry != train_dir and entry.name.lower() in test_names:
                    return entry
        return None

    # Common patterns
    candidates = [
        root / "train",
        root / "Train",
        root / "training",
        root / "Training",
        root / "data" / "train",
    ]
    for candidate in candidates:
        # must contain at least 2 class dirs
        if candidate.is_dir() and len(child_dirs(candidate)) >= 2:
            return candidate, find_test_split(candidate)

    # Fallback: every directory (root included) that has at least two
    # sub-directories directly holding image files.
    if root.is_dir():
        search_space = [root, *sorted(p for p in root.rglob("*") if p.is_dir())]
    else:
        search_space = []
    matches = [
        folder for folder in search_space
        if sum(holds_images(sub) for sub in child_dirs(folder)) >= 2
    ]
    if matches:
        # A folder literally named like a train split wins over e.g. "test".
        chosen = next(
            (folder for folder in matches if folder.name.lower() in train_names),
            matches[0],
        )
        return chosen, find_test_split(chosen)

    raise FileNotFoundError("Could not find an ImageFolder-compatible dataset root under: " + str(root))

# Resolve the train folder (and a possible test folder) under the download
# location, then echo what was found.
train_root, test_root_guess = find_imagefolder_root(DATA_ROOT)
print("Train root:", train_root)
print("Test root (guess):", test_root_guess)
print("Class folders:", [entry.name for entry in train_root.iterdir() if entry.is_dir()])

## 3) Transforms & DataLoaders

CIFAKE images are **32×32**, but EfficientNet is pretrained on ImageNet and expects larger inputs.
We **upsample to 224×224** (standard for EfficientNet-B0) and use ImageNet normalization.


In [None]:
IMG_SIZE = 224
BATCH_SIZE = 128
NUM_WORKERS = 0  # set 0 to avoid Py3.12/Colab multiprocessing shutdown errors
PIN_MEMORY = torch.cuda.is_available()

# ImageNet stats (for pretrained EfficientNet)
IMAGENET_MEAN = (0.485, 0.456, 0.406)
IMAGENET_STD  = (0.229, 0.224, 0.225)

# Training pipeline: resize, light augmentation, then ImageNet normalization.
train_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.RandomHorizontalFlip(p=0.5),
    transforms.ColorJitter(brightness=0.1, contrast=0.1, saturation=0.1, hue=0.02),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

# Evaluation pipeline: same resize/normalization, no random augmentation.
test_tfms = transforms.Compose([
    transforms.Resize((IMG_SIZE, IMG_SIZE)),
    transforms.ToTensor(),
    transforms.Normalize(IMAGENET_MEAN, IMAGENET_STD),
])

full_train_ds = datasets.ImageFolder(train_root, transform=train_tfms)

# Optional: load official test split if present.
# The folder is scanned once; if it is not ImageFolder-compatible the reason is
# printed (instead of being silently swallowed) and the notebook continues
# without a test split.
test_root = None
test_ds = None
if test_root_guess is not None and Path(test_root_guess).exists():
    try:
        test_ds = datasets.ImageFolder(test_root_guess, transform=test_tfms)
    except (FileNotFoundError, RuntimeError) as err:
        print("Ignoring test folder", test_root_guess, "->", err)
    else:
        test_root = Path(test_root_guess)

class_names = full_train_ds.classes
num_classes = len(class_names)

# Labels are only comparable when both splits map class names to the same ids.
if test_ds is not None and test_ds.class_to_idx != full_train_ds.class_to_idx:
    print("WARNING: train/test class mappings differ:",
          full_train_ds.class_to_idx, "vs", test_ds.class_to_idx)

print("Classes:", class_names, "num_classes=", num_classes)
print("Train samples:", len(full_train_ds))
if test_ds is not None:
    print("Test samples:", len(test_ds))

In [None]:
# Hold out a fraction of the train folder for validation (here 10%).
VAL_RATIO = 0.1
n_total = len(full_train_ds)
n_val = int(n_total * VAL_RATIO)
n_train = n_total - n_val

# A dedicated, seeded generator keeps the split identical across runs.
_split_rng = torch.Generator().manual_seed(SEED)
train_ds, val_ds = random_split(full_train_ds, [n_train, n_val], generator=_split_rng)

# Important: val_ds should use deterministic transforms (no random aug).
# A wrapper around an already-transformed subset cannot undo the transform that
# the subset applied, so this class is kept only as an explicit placeholder.
class TransformSubset(torch.utils.data.Dataset):
    """Placeholder for a subset wrapper that would override the transform.

    Args:
        subset: The wrapped dataset/subset; only its length is used.
        transform: The transform that was meant to replace the subset's own.

    Indexing always raises ``NotImplementedError``.
    """

    def __init__(self, subset, transform):
        self.subset = subset
        self.transform = transform

    def __len__(self):
        return len(self.subset)

    def __getitem__(self, idx):
        # Fail immediately: reading self.subset[idx] first would load and
        # transform a sample only to throw it away.
        raise NotImplementedError(
            "TransformSubset cannot override the transform of an already-transformed "
            "subset; index a dataset built with the desired transform instead."
        )

# Keep only the index lists produced by random_split; the datasets themselves
# are rebuilt from ImageFolder instances carrying the right transform.
train_indices, val_indices = train_ds.indices, val_ds.indices

# Same files as the train folder, but with the deterministic (test) transforms.
base_ds = datasets.ImageFolder(train_root, transform=test_tfms)

class SubsetFromBase(torch.utils.data.Dataset):
    """View of ``base`` restricted to the positions listed in ``indices``."""

    def __init__(self, base, indices):
        self.base, self.indices = base, indices

    def __len__(self):
        # The view is exactly as long as its index list.
        return len(self.indices)

    def __getitem__(self, i):
        # Translate the view position into a position in the base dataset.
        base_position = self.indices[i]
        return self.base[base_position]

val_ds = SubsetFromBase(base_ds, val_indices)
# full_train_ds was built from the same folder with train_tfms, so it is reused
# here instead of scanning the training directory one more time.
train_ds = SubsetFromBase(full_train_ds, train_indices)

# Only the training loader shuffles; evaluation order stays fixed.
train_loader = DataLoader(train_ds, batch_size=BATCH_SIZE, shuffle=True,
                          num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
val_loader = DataLoader(val_ds, batch_size=BATCH_SIZE, shuffle=False,
                        num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY)
# The test loader exists only when a test dataset was loaded.
test_loader = DataLoader(test_ds, batch_size=BATCH_SIZE, shuffle=False,
                         num_workers=NUM_WORKERS, pin_memory=PIN_MEMORY) if test_ds else None

# Notebook display: number of batches per loader.
len(train_loader), len(val_loader), (len(test_loader) if test_loader else None)

## 4) Model — EfficientNet-B0 (transfer learning)

In [None]:
# Mixed precision is enabled only when CUDA is available.
use_amp = torch.cuda.is_available()
scaler = torch.amp.GradScaler("cuda", enabled=use_amp)

# Pretrained EfficientNet-B0 backbone.
weights = models.EfficientNet_B0_Weights.DEFAULT
model = models.efficientnet_b0(weights=weights)

# Swap the final linear layer so the head emits one logit per dataset class.
in_features = model.classifier[1].in_features
model.classifier[1] = nn.Linear(in_features, num_classes)
model = model.to(device)

# Objective, optimizer and cosine learning-rate schedule.
# NOTE(review): T_max is fixed at 10 while the training cell sets EPOCHS
# separately; confirm the two are meant to differ.
criterion = nn.CrossEntropyLoss()
optimizer = optim.AdamW(model.parameters(), lr=3e-4, weight_decay=1e-4)
scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10)

print(model)

## 5) Training & evaluation utilities

In [None]:
@torch.no_grad()
def predict_logits(model, loader):
    """Run ``model`` over ``loader`` and collect raw outputs on the CPU.

    Args:
        model: Module called as ``model(x)`` on each input batch.
        loader: Iterable yielding ``(x, y)`` batches of tensors.

    Returns:
        ``(logits, labels)``: the per-batch results concatenated along dim 0,
        both as CPU tensors.
    """
    model.eval()
    all_logits, all_y = [], []
    for x, y in tqdm(loader, desc="Predict", leave=False):
        x = x.to(device, non_blocking=True)
        logits = model(x)
        all_logits.append(logits.detach().cpu())
        # Labels are only gathered, never used on the device, so they are not
        # copied to it and back.
        all_y.append(y.detach().cpu())
    return torch.cat(all_logits, dim=0), torch.cat(all_y, dim=0)

def compute_metrics(y_true, y_pred, pos_label=1):
    """Compute accuracy, precision, recall and F1 for predicted labels.

    Args:
        y_true: Ground-truth class indices.
        y_pred: Predicted class indices.
        pos_label: Class index treated as "positive" when the task is binary
            (``num_classes == 2``). Defaults to 1, the value scikit-learn uses
            when none is given. Check ``class_names`` to see which class that
            index denotes. Not used for more than two classes.

    Returns:
        Dict with keys ``"accuracy"``, ``"precision"``, ``"recall"``, ``"f1"``.
        Precision/recall/F1 use binary averaging for two classes and macro
        averaging otherwise; undefined ratios are reported as 0.
    """
    # Decide the averaging mode once and share it between the three scores.
    average = "binary" if num_classes == 2 else "macro"
    score_kwargs = {"average": average, "zero_division": 0}
    if average == "binary":
        score_kwargs["pos_label"] = pos_label

    return {
        "accuracy": accuracy_score(y_true, y_pred),
        "precision": precision_score(y_true, y_pred, **score_kwargs),
        "recall": recall_score(y_true, y_pred, **score_kwargs),
        "f1": f1_score(y_true, y_pred, **score_kwargs),
    }

def train_one_epoch(model, loader):
    """Train ``model`` for one pass over ``loader``.

    Uses the notebook-level ``optimizer``, ``criterion``, ``scaler`` and
    ``use_amp`` settings.

    Returns:
        The summed per-sample loss divided by ``len(loader.dataset)``.
    """
    model.train()
    loss_sum = 0.0
    for x, y in tqdm(loader, desc="Train", leave=False):
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        # Forward pass and loss run under autocast when AMP is enabled.
        with torch.amp.autocast(device_type="cuda", enabled=use_amp):
            loss = criterion(model(x), y)

        # Scaled backward pass, optimizer step, then scale-factor update.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

        # Weight the batch-mean loss by the batch size.
        loss_sum += loss.item() * x.size(0)

    return loss_sum / len(loader.dataset)

@torch.no_grad()
def eval_one_epoch(model, loader):
    """Evaluate ``model`` on ``loader`` without tracking gradients.

    Returns:
        ``(mean_loss, metrics, (y_true, y_pred))`` where ``mean_loss`` is the
        summed per-sample loss divided by ``len(loader.dataset)``, ``metrics``
        comes from ``compute_metrics`` and the label arrays are NumPy arrays.
    """
    model.eval()
    loss_sum = 0.0
    pred_chunks, true_chunks = [], []
    for x, y in tqdm(loader, desc="Val", leave=False):
        x, y = x.to(device, non_blocking=True), y.to(device, non_blocking=True)
        with torch.amp.autocast(device_type="cuda", enabled=use_amp):
            logits = model(x)
            loss = criterion(logits, y)
        loss_sum += loss.item() * x.size(0)
        # Predicted class = index of the largest logit.
        pred_chunks.append(torch.argmax(logits, dim=1).detach().cpu())
        true_chunks.append(y.detach().cpu())

    y_true = torch.cat(true_chunks).numpy()
    y_pred = torch.cat(pred_chunks).numpy()
    mean_loss = loss_sum / len(loader.dataset)
    return mean_loss, compute_metrics(y_true, y_pred), (y_true, y_pred)

def plot_confusion_matrix(y_true, y_pred, class_names):
    """Draw the confusion matrix of ``y_pred`` against ``y_true`` and return it.

    Rows are true classes, columns are predicted classes; every cell is
    annotated with its count.
    """
    cm = confusion_matrix(y_true, y_pred)
    tick_positions = range(len(class_names))

    plt.figure(figsize=(5,4))
    plt.imshow(cm)
    plt.title("Confusion Matrix")
    plt.xticks(tick_positions, class_names, rotation=45, ha="right")
    plt.yticks(tick_positions, class_names)
    # Text is placed at (x=column, y=row).
    for (row, col), count in np.ndenumerate(cm):
        plt.text(col, row, count, ha="center", va="center")
    plt.xlabel("Predicted")
    plt.ylabel("True")
    plt.tight_layout()
    plt.show()
    return cm

## 6) Train EfficientNet

In [None]:
EPOCHS = 5  # CIFAKE is big; start with 3–5, then increase if time allows
best_f1 = -1
best_path = "best_efficientnet_cifake.pt"

# Per-epoch curves, plotted after training.
history = {"train_loss": [], "val_loss": [], "val_acc": [], "val_f1": []}

for epoch in range(1, EPOCHS + 1):
    t0 = time.time()

    # One training pass, one validation pass, then advance the LR schedule.
    train_loss = train_one_epoch(model, train_loader)
    val_loss, val_metrics, (y_true, y_pred) = eval_one_epoch(model, val_loader)
    scheduler.step()

    _epoch_record = {
        "train_loss": train_loss,
        "val_loss": val_loss,
        "val_acc": val_metrics["accuracy"],
        "val_f1": val_metrics["f1"],
    }
    for _curve_name, _curve_value in _epoch_record.items():
        history[_curve_name].append(_curve_value)

    print(f"Epoch {epoch}/{EPOCHS} | train_loss={train_loss:.4f} | val_loss={val_loss:.4f} | "
          f"acc={val_metrics['accuracy']:.4f} prec={val_metrics['precision']:.4f} rec={val_metrics['recall']:.4f} f1={val_metrics['f1']:.4f} "
          f"| time={(time.time()-t0):.1f}s")

    # Checkpoint whenever validation F1 strictly improves.
    _improved = val_metrics["f1"] > best_f1
    if _improved:
        best_f1 = val_metrics["f1"]
        torch.save(model.state_dict(), best_path)
        print("  ✓ Saved best model ->", best_path)

print("Best val F1:", best_f1)

In [None]:
# Plot training curves: one figure for the losses, one for validation scores.
_figure_specs = (
    ("Loss curves", "loss", ("train_loss", "val_loss")),
    ("Validation metrics", "score", ("val_acc", "val_f1")),
)
for _title, _ylabel, _series_names in _figure_specs:
    plt.figure(figsize=(6,4))
    for _series in _series_names:
        # The history key doubles as the legend label.
        plt.plot(history[_series], label=_series)
    plt.legend()
    plt.title(_title)
    plt.xlabel("epoch")
    plt.ylabel(_ylabel)
    plt.show()

## 7) Final evaluation (Val + Test)

In [None]:
# Load best model.
# The checkpoint is a plain state_dict, so weights_only=True is sufficient and
# avoids unpickling arbitrary objects from the file.
model.load_state_dict(torch.load(best_path, map_location=device, weights_only=True))
model.eval()

# Validation report
val_logits, val_y = predict_logits(model, val_loader)
val_pred = val_logits.argmax(dim=1).numpy()
val_true = val_y.numpy()

print("VAL METRICS:", compute_metrics(val_true, val_pred))
print("\nClassification report (VAL):\n", classification_report(val_true, val_pred, target_names=class_names, zero_division=0))
cm_val = plot_confusion_matrix(val_true, val_pred, class_names)

# Test report (if available)
if test_loader is not None:
    test_logits, test_y = predict_logits(model, test_loader)
    test_pred = test_logits.argmax(dim=1).numpy()
    test_true = test_y.numpy()

    print("\nTEST METRICS:", compute_metrics(test_true, test_pred))
    print("\nClassification report (TEST):\n", classification_report(test_true, test_pred, target_names=class_names, zero_division=0))
    cm_test = plot_confusion_matrix(test_true, test_pred, class_names)
else:
    print("No separate test split found. (Using only Train/Val split.)")

## 8) Visual Explainability
We generate:
- **Grad-CAM** heatmaps (what spatial regions drove the decision)
- **Saliency maps** (input gradients)

We'll visualize a few samples from the **validation** set.


In [None]:
# Inverse of the ImageNet normalization, used for display only.
# Normalize computes (x - mean) / std, so mean=-m/s and std=1/s yield x*s + m.
_inv_mean = [-m/s for m, s in zip(IMAGENET_MEAN, IMAGENET_STD)]
_inv_std = [1/s for s in IMAGENET_STD]
inv_norm = transforms.Normalize(mean=_inv_mean, std=_inv_std)

def tensor_to_img(t):
    """Turn a normalized (C,H,W) tensor into an (H,W,C) NumPy image in [0, 1]."""
    restored = inv_norm(t.detach().cpu())
    # Clip to the displayable range, then move channels last.
    return restored.clamp(0, 1).permute(1, 2, 0).numpy()

# Pick some samples
def get_batch(loader, n=8):
    """Return the first ``n`` inputs and labels of the loader's first batch."""
    first_inputs, first_labels = next(iter(loader))
    head = slice(None, n)
    return first_inputs[head], first_labels[head]

# Take up to 8 samples from the first validation batch; the explainability
# cells below index into x_vis / y_vis.
x_vis, y_vis = get_batch(val_loader, n=8)
# Notebook display: shape of the sampled inputs and their labels.
x_vis.shape, y_vis

In [None]:
# ===== Grad-CAM implementation (generic) =====
# Works for torchvision EfficientNet: we use the last conv block in model.features
target_layer = model.features[-1]

# Filled by the hooks below on every forward / backward pass through target_layer.
activations = None
gradients = None

def forward_hook(module, inp, out):
    """Store the output of the hooked layer in the global ``activations``."""
    global activations
    activations = out

def backward_hook(module, grad_in, grad_out):
    """Store the gradient w.r.t. the hooked layer's output in ``gradients``."""
    global gradients
    gradients = grad_out[0]

# Re-running this cell must not stack duplicate hooks: remove the handles left
# by a previous run before registering again.
for _stale_handle in globals().get("_gradcam_handles", ()):
    _stale_handle.remove()

# Register hooks and keep their handles so they can be removed later.
_gradcam_handles = (
    target_layer.register_forward_hook(forward_hook),
    target_layer.register_full_backward_hook(backward_hook),
)

def grad_cam(model, x, class_idx=None):
    """Return Grad-CAM heatmap (H,W) for a single image tensor x (1,C,H,W).

    Relies on the notebook-level hooks that write the globals ``activations``
    and ``gradients`` during the forward and backward pass.

    Args:
        model: Network to explain; switched to eval mode.
        x: Input batch containing one image.
        class_idx: Class whose score is explained. Defaults to the predicted
            (arg-max) class.

    Returns:
        ``(cam, logits)``: the heatmap as a NumPy array scaled to [0, 1] at the
        hooked layer's spatial resolution, and the detached logits on the CPU.

    Raises:
        RuntimeError: If the hooks captured nothing (e.g. not registered).
    """
    global activations, gradients
    model.eval()
    activations, gradients = None, None

    x = x.to(device)
    logits = model(x)
    if class_idx is None:
        class_idx = logits.argmax(dim=1).item()

    score = logits[:, class_idx].sum()
    model.zero_grad(set_to_none=True)
    # A single backward pass is enough, so the graph is not retained.
    score.backward()

    if activations is None or gradients is None:
        raise RuntimeError(
            "Grad-CAM hooks captured nothing; register forward_hook and "
            "backward_hook on the target layer before calling grad_cam()."
        )

    # Detach so the heatmap arithmetic does not build a new autograd graph.
    acts = activations.detach()   # (1, C, h, w)
    grads = gradients.detach()    # (1, C, h, w)

    weights = grads.mean(dim=(2,3), keepdim=True)  # (1,C,1,1)
    cam = torch.relu((weights * acts).sum(dim=1, keepdim=False))  # (1,h,w)
    cam = cam - cam.min()
    cam = cam / (cam.max() + 1e-8)  # epsilon guards an all-zero map
    cam = cam.squeeze(0).cpu().numpy()
    return cam, logits.detach().cpu()

def overlay_cam(img, cam, alpha=0.5):
    """Blend a heatmap over an image.

    Args:
        img: Array of shape (H, W, 3) with values in [0, 1].
        cam: 2-D heatmap (h, w) with values in [0, 1]; resized to (H, W).
        alpha: Weight of the heatmap in the blend (image gets ``1 - alpha``).

    Returns:
        ``(overlay, cam_resized)``: the blended image clipped to [0, 1] and the
        heatmap bilinearly resized to the image's (H, W).
    """
    target_size = tuple(img.shape[:2])
    cam_batch = torch.tensor(cam).unsqueeze(0).unsqueeze(0)  # (1,1,h,w)
    cam_batch = torch.nn.functional.interpolate(
        cam_batch, size=target_size, mode="bilinear", align_corners=False
    )
    # Drop only the two leading singleton axes; a bare squeeze() would also
    # remove a spatial axis of size 1 and break the (H, W) shape.
    cam_resized = cam_batch[0, 0].numpy()
    # Broadcast the single-channel heatmap across the colour channels.
    overlay = (1 - alpha) * img + alpha * cam_resized[..., None]
    overlay = np.clip(overlay, 0, 1)
    return overlay, cam_resized

In [None]:
# Visualize Grad-CAM for a few samples
model.eval()

# Never ask for more rows than there are sampled images.
n_show = min(6, len(x_vis))
plt.figure(figsize=(12, 8))
for i in range(n_show):
    x1 = x_vis[i:i+1]        # keep the batch dimension: (1,C,H,W)
    y1 = y_vis[i].item()

    # class_idx=None explains the predicted class.
    cam, logits = grad_cam(model, x1, class_idx=None)
    probs = torch.softmax(logits, dim=1).squeeze(0).numpy()
    pred = int(np.argmax(probs))

    img = tensor_to_img(x_vis[i])
    overlay, cam_big = overlay_cam(img, cam, alpha=0.5)

    # One row per sample: original | heatmap | overlay.
    plt.subplot(n_show, 3, 3*i + 1)
    plt.imshow(img)
    plt.axis("off")
    plt.title(f"Original\nTrue: {class_names[y1]}")

    plt.subplot(n_show, 3, 3*i + 2)
    plt.imshow(cam_big)
    plt.axis("off")
    plt.title("Grad-CAM heatmap")

    plt.subplot(n_show, 3, 3*i + 3)
    plt.imshow(overlay)
    plt.axis("off")
    plt.title(f"Overlay\nPred: {class_names[pred]} ({probs[pred]:.2f})")

plt.tight_layout()
plt.show()

In [None]:
# ===== Saliency map (input gradients) =====
def saliency_map(model, x, class_idx=None):
    """Compute an input-gradient saliency map for one image tensor (1,C,H,W).

    Returns:
        ``(saliency, logits)``: an (H,W) NumPy array scaled to [0, 1] and the
        detached logits on the CPU. ``class_idx=None`` explains the predicted
        class.
    """
    model.eval()
    # Work on a fresh leaf copy of the input so its gradient is recorded.
    inp = x.to(device).clone().detach().requires_grad_(True)

    logits = model(inp)
    target = logits.argmax(dim=1).item() if class_idx is None else class_idx

    model.zero_grad(set_to_none=True)
    logits[:, target].sum().backward()

    # Largest absolute gradient over the channel axis -> (1,H,W).
    sal = inp.grad.detach().abs().max(dim=1).values
    sal = sal - sal.min()
    sal = sal / (sal.max() + 1e-8)
    return sal.squeeze(0).cpu().numpy(), logits.detach().cpu()

# Visualize saliency for a few samples
# Never ask for more rows than there are sampled images.
n_show = min(6, len(x_vis))
plt.figure(figsize=(12, 8))
for i in range(n_show):
    x1 = x_vis[i:i+1]        # keep the batch dimension: (1,C,H,W)
    y1 = y_vis[i].item()

    # class_idx=None explains the predicted class.
    sal, logits = saliency_map(model, x1, class_idx=None)
    probs = torch.softmax(logits, dim=1).squeeze(0).numpy()
    pred = int(np.argmax(probs))

    img = tensor_to_img(x_vis[i])

    # One row per sample: original | saliency.
    plt.subplot(n_show, 2, 2*i + 1)
    plt.imshow(img)
    plt.axis("off")
    plt.title(f"Original\nTrue: {class_names[y1]} | Pred: {class_names[pred]} ({probs[pred]:.2f})")

    plt.subplot(n_show, 2, 2*i + 2)
    plt.imshow(sal)
    plt.axis("off")
    plt.title("Saliency map")

plt.tight_layout()
plt.show()

## 9) Notes for the report
- Include your metric tables (Val/Test) and confusion matrix screenshots.
- Add several Grad-CAM / Saliency examples for both **real** and **fake** images.
- Briefly describe what the model seems to focus on (edges, textures, backgrounds, etc.).
