In [10]:
# C&W - Running on full dataset: 3923 total images
# Put this into a Jupyter notebook cell
from __future__ import annotations
import os
from cornet import cornet_s
from pathlib import Path
from typing import List, Tuple, Dict

import torch
import torch.nn as nn
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import torchvision.transforms as T
import torchvision.models as models

# -------------------------
# Dataset: ALL images = test set
# -------------------------
class AllImagesAsTestDataset(Dataset):
    """Wraps a list of (image_path, label) - used as the entire test set (no split)."""
    def __init__(self, samples: List[Tuple[str, int]], transform=None):
        self.samples = samples
        self.transform = transform

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, idx):
        p, label = self.samples[idx]
        img = Image.open(p).convert("RGB")
        if self.transform is not None:
            img = self.transform(img)
        return img, label

def build_all_test_samples(root: str, sort_filenames: bool = True) -> Tuple[List[Tuple[str,int]], Dict[int, str]]:
    """
    Walks `root` and builds a list of (image_path, label) for *all* files in each class folder.
    Label assigned by sorted folder order. Returns samples and class_map.
    """
    root_p = Path(root)
    assert root_p.exists(), f"Data root not found: {root}"
    class_dirs = sorted([d for d in root_p.iterdir() if d.is_dir()])
    samples = []
    class_map = {}
    for label, cls in enumerate(class_dirs):
        imgs = [p for p in cls.iterdir() if p.is_file()]
        if sort_filenames:
            imgs = sorted(imgs)
        for p in imgs:
            samples.append((str(p), label))
        class_map[label] = cls.name
    return samples, class_map

# -------------------------
# Carlini & Wagner (L2, untargeted)
# -------------------------
def cw_l2_attack(
    model,
    images,
    labels,
    device,
    c=1e-3,
    kappa=0.0,
    steps=100,
    lr=1e-2,
):
    """
    Untargeted Carlini & Wagner L2 attack
    """
    model.eval()

    images = images.to(device)
    labels = labels.to(device)

    # Inverse tanh-space transform
    eps = 1e-6
    x = torch.clamp(images, 0, 1)
    w = torch.atanh((x * 2 - 1) * (1 - eps))
    w = w.detach().clone().requires_grad_(True)

    optimizer = torch.optim.Adam([w], lr=lr)

    for _ in range(steps):
        adv = torch.tanh(w) / 2 + 0.5

        logits = model(adv)

        one_hot = torch.eye(logits.size(1), device=device)[labels]
        real = torch.sum(one_hot * logits, dim=1)
        other = torch.max((1 - one_hot) * logits - one_hot * 1e4, dim=1)[0]

        f = torch.clamp(real - other + kappa, min=0)

        l2 = torch.sum((adv - images) ** 2, dim=(1, 2, 3))
        loss = torch.mean(l2 + c * f)

        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

    adv = torch.tanh(w) / 2 + 0.5
    return adv.detach()

def l2_norms(x_adv, x):
    return torch.norm((x_adv - x).view(x.size(0), -1), dim=1)


# -------------------------
# FGSM helper (normalized-space)
# -------------------------
def fgsm_perturb_from_grad(x: torch.Tensor, grad: torch.Tensor, epsilon: float) -> torch.Tensor:
    """Return perturbed inputs in normalized input space using gradient sign."""
    return torch.clamp(x + epsilon * grad.sign(), -10.0, 10.0).detach()

# -------------------------
# Evaluation (clean + FGSM)
# -------------------------
def evaluate_clean(model: nn.Module, loader: DataLoader, device: torch.device) -> float:
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for x, y in loader:
            x = x.to(device)
            y = y.to(device)
            out = model(x)
            preds = out.argmax(dim=1)
            correct += (preds == y).sum().item()
            total += y.size(0)
    return (correct / total) if total > 0 else 0.0

# -------------------------
# Evaluation (C&W cross-model)
# -------------------------

def evaluate_cw_l2_cross(
    generating_model,
    predicting_model,
    loader,
    device,
    l2_budgets=(0.5, 1.0, 2.0, 4.0, 8.0),
    cw_params=None,
):
    if cw_params is None:
        cw_params = dict(c=1e-3, kappa=0.0, steps=100, lr=1e-2)

    correct = {b: 0 for b in l2_budgets}
    total = 0

    for images, labels in loader:
        images = images.to(device)
        labels = labels.to(device)

        adv = cw_l2_attack(
            generating_model,
            images,
            labels,
            device=device,
            **cw_params,
        )

        norms = l2_norms(adv, images)

        logits = predicting_model(adv)
        preds = logits.argmax(dim=1)

        for b in l2_budgets:
            mask = norms <= b
            correct[b] += (preds[mask] == labels[mask]).sum().item()

        total += labels.size(0)

    acc = {b: correct[b] / total for b in l2_budgets}
    return acc



# -------------------------
# Pretrained model loaders
# -------------------------
def load_pretrained_alexnet(num_classes_expected: int, device: torch.device):
    model = models.alexnet(pretrained=True)
    if num_classes_expected != 1000:
        # Replace final layer to match labels (user dataset labels)
        in_feats = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(in_feats, num_classes_expected)
        # note: the new head is random init (we're not training it here)
        print(f"[warning] AlexNet: dataset has {num_classes_expected} classes != 1000; final layer replaced (random init).")
    return model.to(device)

def load_pretrained_vgg16(num_classes_expected: int, device: torch.device):
    model = models.vgg16(pretrained=True)
    if num_classes_expected != 1000:
        in_feats = model.classifier[-1].in_features
        model.classifier[-1] = nn.Linear(in_feats, num_classes_expected)
        print(f"[warning] VGG16: dataset has {num_classes_expected} classes != 1000; final layer replaced (random init).")
    return model.to(device)

def load_pretrained_cornet(num_classes_expected: int, device: torch.device):
    # Load CORnet-S (ImageNet-pretrained)
    model = cornet_s(pretrained=True)

    # CORnet-S classifier is `decoder`
    if num_classes_expected != 1000:
        in_feats = model.decoder.in_features
        model.decoder = nn.Linear(in_feats, num_classes_expected)
        print(
            f"[warning] CORnet-S: dataset has {num_classes_expected} classes != 1000; "
            "final layer replaced (random init)."
        )
    return model.to(device)
# -------------------------
# Main runner: ALL images are test set
# -------------------------

In [8]:
def eval_pretrained_on_all_test(
    data_dir: str,
    gen_model_name: str,
    pred_model_name: str,
    batch_size: int = 128,
    num_workers: int = 4,
    device: str | None = None,
    c: float = 1e-3,
    kappa: float = 0.0,
    steps: int = 100,
    lr: float = 5e-3,
):
    device = torch.device(device if device else ("cuda" if torch.cuda.is_available() else "cpu"))
    print(f"[eval_pretrained_on_all_test] using device: {device}")

    samples, class_map = build_all_test_samples(data_dir)
    n_classes = len(class_map)
    print(f"Found {n_classes} classes and {len(samples)} total images (all used as test samples).")

    transform = T.Compose([
        T.Resize(256),
        T.CenterCrop(224),
        T.ToTensor(),
        T.Normalize(mean=[0.485,0.456,0.406], std=[0.229,0.224,0.225]),
    ])

    test_ds = AllImagesAsTestDataset(samples, transform=transform)
    test_loader = DataLoader(
        test_ds, batch_size=batch_size, shuffle=False,
        num_workers=num_workers, pin_memory=True
    )

    # ---- Load models ----
    if gen_model_name == "alexnet":
        gen_model = load_pretrained_alexnet(n_classes, device)
    elif gen_model_name in ("cornet", "cornet_s"):
        gen_model = load_pretrained_cornet(n_classes, device)
    elif gen_model_name == "vgg16":
        gen_model = load_pretrained_vgg16(n_classes, device)
    else:
        raise ValueError(f"Unknown generating model {gen_model_name}")

    if pred_model_name == "alexnet":
        pred_model = load_pretrained_alexnet(n_classes, device)
    elif pred_model_name in ("cornet", "cornet_s"):
        pred_model = load_pretrained_cornet(n_classes, device)
    elif pred_model_name == "vgg16":
        pred_model = load_pretrained_vgg16(n_classes, device)
    else:
        raise ValueError(f"Unknown predicting model {pred_model_name}")

    gen_model.eval()
    pred_model.eval()

    # ---- Clean accuracies ----
    clean_acc = evaluate_clean(gen_model, test_loader, device)
    print(f"{gen_model_name} clean accuracy: {clean_acc*100:.2f}%")

    clean_acc = evaluate_clean(pred_model, test_loader, device)
    print(f"{pred_model_name} clean accuracy: {clean_acc*100:.2f}%")

    # ---- C&W cross-model attack ----
    adv_acc = evaluate_cw_l2_cross(
        generating_model=gen_model,
        predicting_model=pred_model,
        loader=test_loader,
        device=device,
        l2_budgets=[0.5, 1.0, 2.0, 4.0, 8.0],
        cw_params=dict(
            c=1e-3,
            kappa=0.0,
            steps=100,
            lr=1e-2,
        )
    )   


    for b, acc in adv_acc.items():
        print(f"{gen_model_name} generates & {pred_model_name} predicts | L2 ≤ {b} | acc={acc*100:.2f}%")

    return {
        gen_model_name: {
            "clean_acc": clean_acc,
            "cross": {pred_model_name: adv_acc},
        }
    }


In [11]:
print("ALEX generates --- VGG16 predicts\n")

results = eval_pretrained_on_all_test(
    "val/val/",
    gen_model_name="alexnet",
    pred_model_name="vgg16",
    batch_size=8,
    device="cuda:0",
    steps=100,
    c=1e-3,
)

print(results)

ALEX generates --- VGG16 predicts

[eval_pretrained_on_all_test] using device: cuda:0
Found 1000 classes and 3923 total images (all used as test samples).
alexnet clean accuracy: 56.31%
vgg16 clean accuracy: 70.41%
alexnet generates & vgg16 predicts | L2 ≤ 0.5 | acc=0.00%
alexnet generates & vgg16 predicts | L2 ≤ 1.0 | acc=0.00%
alexnet generates & vgg16 predicts | L2 ≤ 2.0 | acc=0.00%
alexnet generates & vgg16 predicts | L2 ≤ 4.0 | acc=0.00%
alexnet generates & vgg16 predicts | L2 ≤ 8.0 | acc=0.00%
{'alexnet': {'clean_acc': 0.7040530206474637, 'cross': {'vgg16': {0.5: 0.0, 1.0: 0.0, 2.0: 0.0, 4.0: 0.0, 8.0: 0.0}}}}
