In [None]:
import os
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import matplotlib.pyplot as plt

# =========================================================
# CONFIG
# =========================================================
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Attack setup: test images of class SOURCE_LABEL are attacked toward TARGET_LABEL
SOURCE_LABEL, TARGET_LABEL = 7, 3
# NOTE(review): the C&W L2 attack below has no ε budget, so this value does not
# bound the perturbation it produces.
EPSILON = 0.05
NUM_SAMPLES = 500  # at most this many SOURCE_LABEL test images are attacked

# Architecture settings consumed by CNN3Layer
NUM_CLASSES, DROPOUT_RATE = 10, 0.15

# =========================================================
# PATH HELPERS
# =========================================================
def first_existing(paths):
    """
    Return the first candidate path that exists on disk.

    Args:
        paths: candidate paths, checked in order. A single str / os.PathLike is
            treated as a one-element list (iterating a str would test characters).

    Returns:
        The first existing candidate, unchanged.

    Raises:
        FileNotFoundError: if no candidate exists; the message lists all of them.
    """
    if isinstance(paths, (str, os.PathLike)):
        paths = [paths]
    candidates = list(paths)
    for p in candidates:
        if os.path.exists(p):
            return p
    raise FileNotFoundError(f"Required file not found; tried: {candidates}")

# Checkpoint and test-split locations. Each lookup takes candidate paths in
# priority order and raises if none of them exists.
LAZY_MODEL_PATH = first_existing((
    "/kaggle/input/task1app3models/pytorch/default/2/task1approach3sc1_modelv2.pth",
))
ROBUST_MODEL_PATH = first_existing((
    "/kaggle/input/task1app3models/pytorch/default/2/task4_irm_modelv1.pth",
))
TEST_DATA_PATH = first_existing((
    "/kaggle/input/cmnistneo1/test_data_gr100z.npz",
))

# =========================================================
# DATA STATS
# =========================================================
def compute_dataset_stats(npz_path, device=None):
    """
    Compute the per-channel mean and std of the "images" array in an .npz file.

    Pixels are scaled by 1/255 before the statistics are taken. The reduction
    runs over axes (0, 1, 2), so the array is assumed to be channels-last
    (N, H, W, C); the .view(3, 1, 1) below requires exactly 3 channels.

    Args:
        npz_path: path to an .npz archive containing an "images" entry.
        device: torch device for the returned tensors (defaults to DEVICE).

    Returns:
        (mean, std) as float32 tensors of shape (3, 1, 1), ready to broadcast
        against (C, H, W) or (N, C, H, W) image tensors.
    """
    device = DEVICE if device is None else device

    # np.load on an .npz returns a lazily-read archive; the context manager
    # closes its file handle once the array has been materialized.
    with np.load(npz_path) as data:
        imgs = data["images"].astype(np.float32) / 255.0

    mean = imgs.mean(axis=(0, 1, 2))
    std = imgs.std(axis=(0, 1, 2))

    mean_t = torch.tensor(mean, device=device).view(3, 1, 1)
    std_t = torch.tensor(std, device=device).view(3, 1, 1)

    print("Dataset mean:", mean.tolist())
    print("Dataset std: ", std.tolist(), "\n")

    return mean_t, std_t

MEAN, STD = compute_dataset_stats(TEST_DATA_PATH)
# The valid pixel range [0, 1] mapped into normalized space; generate_cw keeps
# its iterates between these two bounds.
LOWER, UPPER = ((bound - MEAN) / STD for bound in (0.0, 1.0))

# =========================================================
# MODEL
# =========================================================
class CNN3Layer(nn.Module):
    """
    Three conv blocks (3x3 conv -> ReLU -> 2x2 max-pool) followed by a two-layer MLP head.

    The head expects a 64 x 3 x 3 feature map after the three pools, e.g. a
    28 x 28 input (28 -> 14 -> 7 -> 3). The attribute names below double as the
    state-dict prefixes; load_model() renames legacy checkpoint keys onto them.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: channels 3 -> 32 -> 64 -> 64; padding=1 keeps H, W per conv
        self.conv1 = nn.Conv2d(in_channels=3, out_channels=32, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1)
        self.conv3 = nn.Conv2d(in_channels=64, out_channels=64, kernel_size=3, padding=1)

        # Shared pooling layer plus classifier head
        self.pool = nn.MaxPool2d(kernel_size=2)
        self.fc1 = nn.Linear(in_features=64 * 3 * 3, out_features=128)
        self.fc2 = nn.Linear(in_features=128, out_features=NUM_CLASSES)
        self.dropout = nn.Dropout(p=DROPOUT_RATE)

    def forward(self, x):
        """Return unnormalized class logits for a batch of images."""
        for conv in (self.conv1, self.conv2, self.conv3):
            x = self.pool(F.relu(conv(x)))
        hidden = F.relu(self.fc1(torch.flatten(x, start_dim=1)))
        return self.fc2(self.dropout(hidden))

def load_model(path):
    model = CNN3Layer().to(DEVICE)
    state = torch.load(path, map_location=DEVICE)

    if "features.0.weight" in state:
        remap = {
            "features.0": "conv1",
            "features.3": "conv2",
            "features.6": "conv3",
            "classifier.0": "fc1",
            "classifier.3": "fc2"
        }
        new_state = {}
        for k, v in state.items():
            for old, new in remap.items():
                if k.startswith(old):
                    k = k.replace(old, new)
            new_state[k] = v
        state = new_state

    model.load_state_dict(state)
    model.eval()
    return model

def attack_success(model, adv_images, labels, targeted=False, target_label=None):
    """
    Return the attack success rate, in percent, of `adv_images` against `model`.

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        adv_images: batch fed to `model` (evaluated without gradients).
        labels: true labels, used by the untargeted criterion.
        targeted: if True, success means the prediction equals `target_label`;
            otherwise success means the prediction differs from `labels`.
        target_label: target class for the targeted criterion (defaults to TARGET_LABEL).

    Returns:
        Success rate as a Python float in [0, 100].
    """
    with torch.no_grad():
        preds = model(adv_images).argmax(dim=1)

    if targeted:
        if target_label is None:
            target_label = TARGET_LABEL
        hits = preds.eq(target_label)
    else:
        hits = preds.ne(labels)
    return hits.float().mean().item() * 100


# =========================================================
# C&W (L2) HYPERPARAMETERS
# =========================================================
CW_STEPS = 1000        # Adam iterations per generate_cw call
CW_LR = 0.01           # Adam learning rate on the tanh-space variable
CW_C = 1.0             # weight of the classification term relative to the L2 term
CW_CONFIDENCE = 0.0    # κ in the C&W paper (required logit margin)

# =========================================================
# C&W (GENERATION)
# =========================================================
@torch.enable_grad()
def generate_cw(model, images, labels, targeted=True, *, target_label=None,
                steps=None, lr=None, c=None, confidence=None,
                lower=None, upper=None):
    """
    Carlini & Wagner (L2) attack with a fixed trade-off constant (no search over c).

    The perturbation is optimized in tanh space, so every iterate stays inside
    [lower, upper] (the valid pixel range expressed in the model's input space).

    Args:
        model: classifier returning logits of shape (batch, num_classes).
        images: input batch, in the same (normalized) space that `model` expects.
        labels: true class labels of `images`.
        targeted: if True, push each sample toward `target_label`; otherwise push
            it away from its true label in `labels`.
        target_label: target class (defaults to TARGET_LABEL).
        steps, lr, c, confidence: Adam steps, Adam learning rate, weight of the
            classification term, and margin κ (default to CW_STEPS, CW_LR, CW_C,
            CW_CONFIDENCE).
        lower, upper: valid-range bounds broadcastable to `images`
            (default to LOWER, UPPER).

    Returns:
        (adv, grad), both detached.
        - adv: for each sample, the successful iterate with the smallest L2
          distance to `images`, or the final iterate if the attack never succeeded.
        - grad: the gradient of the cross-entropy loss w.r.t. `adv`, taken against
          the reference labels (target labels if targeted, true labels otherwise).
    """
    target_label = TARGET_LABEL if target_label is None else target_label
    steps = CW_STEPS if steps is None else steps
    lr = CW_LR if lr is None else lr
    c = CW_C if c is None else c
    confidence = CW_CONFIDENCE if confidence is None else confidence
    lower = LOWER if lower is None else lower
    upper = UPPER if upper is None else upper

    model.eval()

    labels = labels.long()
    target_labels = torch.full_like(labels, target_label)
    # Class whose logit is compared against the strongest competing class:
    # the target for a targeted attack, the true class for an untargeted one.
    ref_labels = target_labels if targeted else labels

    span = upper - lower

    def to_image(w_):
        # tanh maps w to (-1, 1), so the image always lies inside [lower, upper]
        return lower + span * (torch.tanh(w_) + 1) / 2

    # ---- inverse tanh space ----
    unit = ((images - lower) / span).clamp(0.0, 1.0)
    # Shrink slightly so atanh stays finite for pixels exactly on the range bounds
    w = torch.atanh((unit * 2 - 1) * (1 - 1e-6)).detach().requires_grad_(True)
    optimizer = torch.optim.Adam([w], lr=lr)

    # Best successful iterate per sample (smallest L2); inf means "none found yet"
    best_l2 = torch.full((images.size(0),), float("inf"), device=images.device)
    best_adv = images.detach().clone()

    def keep_best(adv, logits, l2):
        preds = logits.argmax(dim=1)
        success = preds.eq(target_labels) if targeted else preds.ne(labels)
        improved = success & (l2 < best_l2)
        best_l2[improved] = l2[improved]
        best_adv[improved] = adv[improved]

    for _ in range(steps):
        adv = to_image(w)
        logits = model(adv)
        l2_loss = ((adv - images) ** 2).flatten(1).sum(dim=1)

        with torch.no_grad():
            keep_best(adv.detach(), logits.detach(), l2_loss.detach())

        # ---- C&W loss ----
        real = logits.gather(1, ref_labels.unsqueeze(1)).squeeze(1)
        # Best competing logit: mask out the reference class before taking the max
        other = logits.scatter(1, ref_labels.unsqueeze(1), float("-inf")).max(dim=1).values

        if targeted:
            f_loss = torch.clamp(other - real + confidence, min=0)
        else:
            f_loss = torch.clamp(real - other + confidence, min=0)

        loss = torch.mean(l2_loss + c * f_loss)

        # Differentiate w.r.t. w only, so the model's parameters do not
        # accumulate .grad on every step.
        w.grad = torch.autograd.grad(loss, w)[0]
        optimizer.step()

    # ---- final adversarial ----
    final_adv = to_image(w).detach()
    with torch.no_grad():
        keep_best(final_adv, model(final_adv),
                  ((final_adv - images) ** 2).flatten(1).sum(dim=1))
    found = torch.isfinite(best_l2).view(-1, *([1] * (images.dim() - 1)))
    adv = torch.where(found, best_adv, final_adv)

    # ---- gradient for Phase 6 ----
    adv = adv.detach().requires_grad_(True)
    loss = F.cross_entropy(model(adv), ref_labels)
    grad = torch.autograd.grad(loss, adv)[0]

    return adv.detach(), grad.detach()

# =========================================================
# MAIN (UNCHANGED EXCEPT ATTACK)
# =========================================================
lazy_model = load_model(LAZY_MODEL_PATH)
robust_model = load_model(ROBUST_MODEL_PATH)

# np.load on an .npz returns a lazily-read archive; the context manager closes
# its file handle once the arrays have been copied into tensors.
with np.load(TEST_DATA_PATH) as data:
    # Channels-last array (presumably 0-255 pixel values) -> (N, C, H, W) floats
    images = torch.tensor(
        data["images"], dtype=torch.float32, device=DEVICE
    ).permute(0, 3, 1, 2) / 255.0
    # int64 labels, as required by F.cross_entropy inside generate_cw
    labels = torch.tensor(data["labels"], device=DEVICE).long()

mask = labels == SOURCE_LABEL
if not bool(mask.any()):
    raise ValueError(f"No test images with label {SOURCE_LABEL} in {TEST_DATA_PATH}")

# Keep at most NUM_SAMPLES source-class images, normalized with the dataset stats
images = ((images[mask][:NUM_SAMPLES]) - MEAN) / STD
labels = labels[mask][:NUM_SAMPLES]

# ---- Generate adversarial examples (C&W) ----
lazy_adv, lazy_grad = generate_cw(
    lazy_model, images, labels, targeted=True
)
robust_adv, robust_grad = generate_cw(
    robust_model, images, labels, targeted=True
)

# ---- Evaluate ----
u_lazy = attack_success(lazy_model, lazy_adv, labels, targeted=False)
u_robust = attack_success(robust_model, robust_adv, labels, targeted=False)

t_lazy = attack_success(lazy_model, lazy_adv, labels, targeted=True)
t_robust = attack_success(robust_model, robust_adv, labels, targeted=True)

print("=== C&W Results ===")
print(f"Attack: {SOURCE_LABEL} → {TARGET_LABEL}")
print(f"Untargeted | Lazy: {u_lazy:.2f}% | Robust: {u_robust:.2f}%")
print(f"Targeted   | Lazy: {t_lazy:.2f}% | Robust: {t_robust:.2f}%")


In [None]:


# =========================================================
# DENORMALIZATION (shared)
# =========================================================
def denormalize(x):
    """
    Undo the dataset normalization and convert one image to a displayable array.

    x: (C,H,W) or (1,C,H,W) torch tensor in normalized space
    returns: (H,W,C) numpy array with values clamped to [0,1]
    """
    chw = x.squeeze(0) if x.dim() == 4 else x
    pixels = (chw * STD + MEAN).clamp(0, 1)
    return pixels.detach().cpu().permute(1, 2, 0).numpy()

# =========================================================
# ADVERSARIALS
# =========================================================
# lazy_adv / lazy_grad and robust_adv / robust_grad were already produced by the
# C&W run in the previous cell for these same images and labels. Calling
# generate_cw again here only repeated the full CW_STEPS-step optimization for
# both models, so the existing results are reused.

# =========================================================
# VIS CONFIG
# =========================================================
N_VIZ = 8  # number of samples to visualize (same for all phases)


In [None]:
# Never request more columns than there are attacked samples
n_show = min(N_VIZ, len(images))
# squeeze=False keeps axs 2-D even when only one column is drawn
fig, axs = plt.subplots(3, n_show, figsize=(18, 8), squeeze=False)

for i in range(n_show):
    axs[0, i].imshow(denormalize(images[i]))      # Original
    axs[1, i].imshow(denormalize(lazy_adv[i]))    # Lazy model adversarial
    axs[2, i].imshow(denormalize(robust_adv[i]))  # IRM model adversarial

# Hide ticks and frame instead of calling ax.axis("off"): axis("off") also
# hides the y-axis label, so the row labels below would never be drawn.
for ax in axs.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

# Row labels
axs[0, 0].set_ylabel("Original", fontsize=12)
axs[1, 0].set_ylabel("Lazy (ERM)", fontsize=12)
axs[2, 0].set_ylabel("IRM (Robust)", fontsize=12)

# The C&W-L2 attack has no ε budget, so report its actual parameters
plt.suptitle(
    f"Phase 0 — Original vs Adversarial (Lazy vs IRM) | C&W-L2 (c = {CW_C}, κ = {CW_CONFIDENCE})",
    fontsize=14
)
plt.show()


In [None]:
# Never request more columns than there are attacked samples
n_show = min(N_VIZ, len(images))
# squeeze=False keeps axs 2-D even when only one column is drawn
fig, axs = plt.subplots(2, n_show, figsize=(18, 5), squeeze=False)

for i in range(n_show):
    orig = denormalize(images[i])

    # Lazy perturbation: |Δ| in pixel space, amplified ×100 and clipped for display
    lazy_adv_img = denormalize(lazy_adv[i])
    lazy_delta = np.abs(lazy_adv_img - orig)
    lazy_delta_viz = np.clip(lazy_delta * 100, 0, 1)

    # IRM perturbation
    robust_adv_img = denormalize(robust_adv[i])
    robust_delta = np.abs(robust_adv_img - orig)
    robust_delta_viz = np.clip(robust_delta * 100, 0, 1)

    axs[0, i].imshow(lazy_delta_viz)
    axs[1, i].imshow(robust_delta_viz)

# Hide ticks and frame instead of calling ax.axis("off"), which would also
# hide the y-axis labels used as row labels below.
for ax in axs.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

# Row labels
axs[0, 0].set_ylabel("|Δ| Lazy", fontsize=12)
axs[1, 0].set_ylabel("|Δ| IRM", fontsize=12)

# The C&W-L2 attack has no ε budget, so report its actual parameters
plt.suptitle(
    f"Absolute Perturbation |Δ| (×100) | Lazy vs IRM | C&W-L2 (c = {CW_C}, κ = {CW_CONFIDENCE})",
    fontsize=14
)
plt.show()


In [None]:
# Never request more columns than there are attacked samples
n_show = min(N_VIZ, len(images))
# squeeze=False keeps axs 2-D even when only one column is drawn
fig, axs = plt.subplots(2, n_show, figsize=(18, 7), squeeze=False)

for i in range(n_show):
    orig = denormalize(images[i])

    # ----- Lazy model -----
    # Per-pixel L2 norm of Δ over channels, scaled so each image's max is 1
    lazy_adv_img = denormalize(lazy_adv[i])
    lazy_delta = lazy_adv_img - orig
    lazy_heat = np.linalg.norm(lazy_delta, axis=2)
    lazy_heat = lazy_heat / (lazy_heat.max() + 1e-8)

    axs[0, i].imshow(orig)
    # Fixed [0, 1] color range so colors mean the same thing in every panel
    axs[0, i].imshow(lazy_heat, cmap="jet", alpha=0.5, vmin=0.0, vmax=1.0)

    # ----- IRM model -----
    robust_adv_img = denormalize(robust_adv[i])
    robust_delta = robust_adv_img - orig
    robust_heat = np.linalg.norm(robust_delta, axis=2)
    robust_heat = robust_heat / (robust_heat.max() + 1e-8)

    axs[1, i].imshow(orig)
    axs[1, i].imshow(robust_heat, cmap="jet", alpha=0.5, vmin=0.0, vmax=1.0)

# Hide ticks and frame instead of calling ax.axis("off"), which would also
# hide the y-axis labels used as row labels below.
for ax in axs.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

# Row labels
axs[0, 0].set_ylabel("Lazy (ERM)", fontsize=12)
axs[1, 0].set_ylabel("IRM (Robust)", fontsize=12)

# The C&W-L2 attack has no ε budget, so report its actual parameters
plt.suptitle(
    f"Spatial Perturbation Heatmap | Lazy vs IRM | C&W-L2 (c = {CW_C}, κ = {CW_CONFIDENCE})",
    fontsize=14
)
plt.show()


In [None]:
def collect_deltas(orig, adv):
    """Return the signed element-wise difference ``adv - orig`` flattened to 1-D."""
    delta = adv - orig
    return delta.reshape(-1)

# Signed pixel-space perturbations pooled over every attacked sample.
# Computed for the whole batch with the same clamp(x * STD + MEAN, 0, 1) mapping
# that denormalize() applies per image, instead of one Python call per image.
# Element order differs from a per-image loop, which a histogram ignores.
orig_px = torch.clamp(images * STD + MEAN, 0, 1)
lazy_deltas = (
    (torch.clamp(lazy_adv * STD + MEAN, 0, 1) - orig_px)
    .flatten().detach().cpu().numpy()
)
robust_deltas = (
    (torch.clamp(robust_adv * STD + MEAN, 0, 1) - orig_px)
    .flatten().detach().cpu().numpy()
)

# Symmetric range for fair comparison
max_range = float(max(
    np.abs(lazy_deltas).max(),
    np.abs(robust_deltas).max()
))
if max_range == 0.0:
    # No perturbation at all: avoid zero-width bins (density=True divides by bin width)
    max_range = 1e-6

bins = np.linspace(-max_range, max_range, 200)

plt.figure(figsize=(8, 5))
plt.hist(
    lazy_deltas, bins=bins, alpha=0.6,
    label="Lazy (ERM)", density=True
)
plt.hist(
    robust_deltas, bins=bins, alpha=0.6,
    label="IRM (Robust)", density=True
)

plt.xlabel("Perturbation value (signed)")
plt.ylabel("Density")
# The C&W-L2 attack has no ε budget, so report its actual parameters
plt.title(f"Perturbation Value Distribution | C&W-L2 (c = {CW_C}, κ = {CW_CONFIDENCE})")
plt.legend()
plt.tight_layout()
plt.show()


In [None]:
# Never request more columns than there are attacked samples
n_show = min(N_VIZ, len(images))
# squeeze=False keeps axs 2-D even when only one column is drawn
fig, axs = plt.subplots(4, n_show, figsize=(18, 12), squeeze=False)

for i in range(n_show):
    orig = denormalize(images[i])

    # ===== Lazy model =====
    # Per-pixel channel-wise L2 norm of the input gradient returned by generate_cw
    lazy_grad_img = lazy_grad[i].detach().cpu().permute(1, 2, 0).numpy()
    lazy_grad_mag = np.linalg.norm(lazy_grad_img, axis=2)

    # Per-pixel channel-wise L2 norm of the pixel-space perturbation
    lazy_adv_img = denormalize(lazy_adv[i])
    lazy_delta_mag = np.linalg.norm(lazy_adv_img - orig, axis=2)

    axs[0, i].imshow(lazy_grad_mag, cmap="magma")
    axs[1, i].imshow(lazy_delta_mag, cmap="inferno")

    # ===== IRM model =====
    robust_grad_img = robust_grad[i].detach().cpu().permute(1, 2, 0).numpy()
    robust_grad_mag = np.linalg.norm(robust_grad_img, axis=2)

    robust_adv_img = denormalize(robust_adv[i])
    robust_delta_mag = np.linalg.norm(robust_adv_img - orig, axis=2)

    axs[2, i].imshow(robust_grad_mag, cmap="magma")
    axs[3, i].imshow(robust_delta_mag, cmap="inferno")

# Hide ticks and frame instead of calling ax.axis("off"), which would also
# hide the y-axis labels used as row labels below.
for ax in axs.flat:
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_visible(False)

# Row labels
axs[0, 0].set_ylabel("‖∇x L‖\nLazy", fontsize=12)
axs[1, 0].set_ylabel("‖Δ‖\nLazy", fontsize=12)
axs[2, 0].set_ylabel("‖∇x L‖\nIRM", fontsize=12)
axs[3, 0].set_ylabel("‖Δ‖\nIRM", fontsize=12)

# The C&W-L2 attack has no ε budget, so report its actual parameters
plt.suptitle(
    f"Gradients(what model is sensitive to) vs Perturbations(what model changed) | Lazy vs IRM | C&W-L2 (c = {CW_C}, κ = {CW_CONFIDENCE})",
    fontsize=14
)
plt.tight_layout()
plt.show()
