### MNIST Backdoor Analysis

In [None]:
# Mount Drive (Colab)
# Mounts Google Drive at /content/drive; the dataset and weight paths set in
# the configuration cell point into this mount.
# NOTE(review): `google.colab` is presumably only importable inside a Colab
# runtime, so this cell is expected to fail elsewhere -- confirm before reuse.
from google.colab import drive
drive.mount("/content/drive")

In [None]:
# Imports and device
import argparse
import math
import random
from pathlib import Path
from typing import Tuple

import numpy as np
import torch
from torch import nn
from torch.utils.data import DataLoader, Dataset, TensorDataset

# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
_device_name = "cuda" if torch.cuda.is_available() else "cpu"
device = torch.device(_device_name)

In [None]:
# Paths (defaults mirror original TODO notebook on Drive)
# Serialized evaluation dataset, read with torch.load in load_mnist_tensor.
dataset_path = "/content/drive/MyDrive/Assignment_2_files_updated/mnist_test_data.pt"
# State dict that is loaded into the CNN classifier before evaluation.
weights_path = "/content/drive/MyDrive/Assignment_2_files_updated/model_weights_poisoned_partC.pth"
# Fraction of samples FractionalBackdoorDataset poisons; must lie in (0, 1].
poison_fraction = 0.05
# Class the backdoor tries to force the model to predict.
target_label = 0
# Side length (pixels) of the square trigger patch.
trigger_size = 3
# Number of Adam steps and learning rate for the white-box trigger search.
opt_steps = 300
opt_lr = 0.003
# Batch size used for every DataLoader in the notebook.
batch_size = 128
# Seed handed to FractionalBackdoorDataset for picking the poisoned indices.
seed = 71

In [None]:
# CNN classifier
class CNN(nn.Module):
    """Small two-stage convolutional classifier producing 10 logits.

    Each convolution keeps the spatial size (3x3 kernel, padding 1) and is
    followed by a 2x2 max-pool, so the ``64 * 7 * 7`` input of the first
    linear layer corresponds to single-channel 28x28 images.

    All layers live in one ``nn.Sequential`` called ``net``; the positional
    indices inside it determine the state-dict keys (``net.0.weight`` ...),
    so the layer order must not change if checkpoints are to keep loading.
    """

    def __init__(self):
        super().__init__()
        # Feature extractor: 1 -> 32 -> 64 channels, halving H and W twice.
        feature_layers = [
            nn.Conv2d(1, 32, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, kernel_size=3, padding=1),
            nn.ReLU(),
            nn.MaxPool2d(2, 2),
        ]
        # Classifier head: flatten, one hidden layer, then the 10 logits.
        head_layers = [
            nn.Flatten(),
            nn.Linear(64 * 7 * 7, 128),
            nn.ReLU(),
            nn.Linear(128, 10),
        ]
        self.net = nn.Sequential(*feature_layers, *head_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return the raw (unnormalized) class logits for ``x``."""
        logits = self.net(x)
        return logits

In [None]:
# Data loading helper
def load_mnist_tensor(path: str, batch_size: int) -> Tuple[Dataset, DataLoader]:
    """Load a serialized dataset and wrap it in a sequential DataLoader.

    Args:
        path: Location of a file written with ``torch.save`` that holds an
            indexable dataset object.
        batch_size: Number of samples per batch produced by the loader.

    Returns:
        ``(dataset, loader)`` where ``loader`` iterates ``dataset`` in its
        stored order (no shuffling).
    """
    # A pickled Dataset object is not a plain tensor/state-dict payload, so
    # the restricted unpickler that newer PyTorch releases enable by default
    # rejects it; request full unpickling explicitly.
    # SECURITY: full unpickling can execute arbitrary code -- only point this
    # at files from a trusted source.
    dataset = torch.load(path, weights_only=False)
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    return dataset, loader

In [None]:
# Trigger utilities
def corner_trigger(trigger_size: int, channels: int = 1, value: float = 0.85) -> torch.Tensor:
    """Build a constant-valued square trigger patch.

    Args:
        trigger_size: Height and width of the patch.
        channels: Number of channels in the patch.
        value: Value written into every element.

    Returns:
        Tensor of shape ``(channels, trigger_size, trigger_size)`` in the
        default floating-point dtype, filled with ``value``.
    """
    patch_shape = (channels, trigger_size, trigger_size)
    # Pin the dtype so an integer ``value`` still yields a floating patch.
    return torch.full(patch_shape, value, dtype=torch.get_default_dtype())

def apply_trigger(images: torch.Tensor, trigger_patch: torch.Tensor) -> torch.Tensor:
    """Return a copy of ``images`` with the patch written into the bottom-right corner.

    Args:
        images: 4-D batch indexed as ``[batch, channel, row, column]``.
        trigger_patch: Patch whose last two dimensions give its height and
            width; it is broadcast over the remaining image dimensions.

    Returns:
        A new tensor; ``images`` itself is left unmodified.
    """
    *_, patch_h, patch_w = trigger_patch.shape
    stamped = images.clone()
    # Negative slices address the last rows/columns, i.e. the lower-right corner.
    corner_rows = slice(-patch_h, None)
    corner_cols = slice(-patch_w, None)
    stamped[:, :, corner_rows, corner_cols] = trigger_patch.to(images.device)
    return stamped

In [None]:
# Dataset with fractional poisoning (backdoor or label flip)
class FractionalBackdoorDataset(Dataset):
    """Wrap a dataset and poison a fixed, seeded subset of its samples.

    The poisoned subset is drawn once at construction time. Two poisoning
    modes are supported:

    * ``"backdoor"``: stamp ``trigger_patch`` onto the image (via
      ``apply_trigger``) and relabel the sample as ``target_label``.
    * ``"label_flip"``: keep the image and replace the label with a different
      class. The replacement is derived from ``seed`` and the sample index,
      so it is reproducible and identical on every access.

    Poisoned labels are returned in the same form as the wrapped dataset's
    labels (tensor in, tensor out), so clean and poisoned samples can be
    collated into one batch.

    Args:
        base_dataset: Indexable dataset yielding ``(image, label)`` pairs.
        trigger_patch: Patch used in ``"backdoor"`` mode.
        target_label: Label assigned to poisoned samples in ``"backdoor"`` mode.
        poison_fraction: Fraction of samples to poison, in ``(0, 1]``. At
            least one sample is poisoned when the dataset is non-empty.
        mode: ``"backdoor"`` or ``"label_flip"``.
        seed: Seed for choosing the poisoned indices and the flipped labels.
        num_classes: Number of classes labels are drawn from when flipping.

    Raises:
        ValueError: If ``poison_fraction`` is outside ``(0, 1]``, ``mode`` is
            not supported, or ``num_classes < 2`` in ``"label_flip"`` mode.
    """

    _SUPPORTED_MODES = ("backdoor", "label_flip")

    def __init__(self, base_dataset: Dataset, trigger_patch: torch.Tensor, target_label: int, poison_fraction: float, mode: str = "backdoor", seed: int = 71, num_classes: int = 10):
        if not 0 < poison_fraction <= 1:
            raise ValueError("poison_fraction must be in (0,1].")
        # Fail at construction instead of at the first poisoned sample.
        if mode not in self._SUPPORTED_MODES:
            raise ValueError(f"Unsupported mode: {mode}")
        if mode == "label_flip" and num_classes < 2:
            raise ValueError("num_classes must be at least 2 for label_flip mode.")

        self.base_dataset = base_dataset
        self.trigger_patch = trigger_patch
        self.target_label = target_label
        self.mode = mode
        self.num_classes = num_classes

        rng = np.random.default_rng(seed)
        total = len(base_dataset)
        if total == 0:
            # Nothing to poison; avoid sampling from an empty population.
            self.poison_indices = set()
        else:
            poison_count = max(1, int(total * poison_fraction))
            self.poison_indices = set(rng.choice(total, size=poison_count, replace=False).tolist())
        # Drawn after the index selection so the poisoned subset for a given
        # seed is unaffected; used to derive one independent stream per index.
        self._flip_seed = int(rng.integers(0, 2**32))

    def __len__(self):
        return len(self.base_dataset)

    def _flipped_label(self, idx: int, original: int) -> int:
        """Pick a class different from ``original``, deterministically per index."""
        candidates = [c for c in range(self.num_classes) if c != original]
        # A per-index generator keeps the result independent of access order
        # and of which process reads the sample.
        index_rng = np.random.default_rng([self._flip_seed, int(idx)])
        return candidates[int(index_rng.integers(len(candidates)))]

    @staticmethod
    def _like_original(original_label, new_value: int):
        """Return ``new_value`` in the same container form as ``original_label``."""
        if torch.is_tensor(original_label):
            # Preserves shape, dtype and device of the wrapped dataset's labels.
            return torch.full_like(original_label, new_value)
        return new_value

    def __getitem__(self, idx: int):
        img, label = self.base_dataset[idx]
        if idx not in self.poison_indices:
            return img, label

        if self.mode == "backdoor":
            # apply_trigger works on batches, so add and remove a batch axis.
            img = apply_trigger(img.unsqueeze(0), self.trigger_patch).squeeze(0)
            new_label = self.target_label
        elif self.mode == "label_flip":
            new_label = self._flipped_label(idx, int(label))
        else:
            # Only reachable if ``mode`` was reassigned after construction.
            raise ValueError(f"Unsupported mode: {self.mode}")
        return img, self._like_original(label, new_label)

In [None]:
# Evaluation metrics
@torch.no_grad()
def accuracy(model: nn.Module, loader: DataLoader) -> float:
    """Return the top-1 accuracy of ``model`` over ``loader`` as a percentage.

    The model is switched to eval mode and every batch is moved to the
    module-level ``device`` before inference.

    Args:
        model: Classifier whose output has the class scores along dimension 1.
        loader: Iterable of ``(images, labels)`` batches.

    Returns:
        ``100 * correct / total`` over all samples seen.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    n_seen = 0
    n_correct = 0
    for batch_images, batch_labels in loader:
        batch_images = batch_images.to(device)
        batch_labels = batch_labels.to(device)
        predicted = model(batch_images).argmax(1)
        matches = predicted == batch_labels
        n_correct += matches.sum().item()
        n_seen += batch_labels.size(0)
    return 100.0 * n_correct / n_seen

@torch.no_grad()
def attack_success_rate(model: nn.Module, loader: DataLoader, target_label: int) -> float:
    """Return the percentage of samples in ``loader`` predicted as ``target_label``.

    Labels supplied by the loader are ignored: every sample counts, whether
    or not it carries a trigger and whatever its true class is.

    Args:
        model: Classifier whose output has the class scores along dimension 1.
        loader: Iterable of ``(images, labels)`` batches.
        target_label: Class index counted as a successful attack.

    Returns:
        ``100 * hits / total`` over all samples seen.

    Raises:
        ZeroDivisionError: If ``loader`` yields no samples.
    """
    model.eval()
    n_seen = 0
    n_hits = 0
    for batch_images, _unused_labels in loader:
        predicted = model(batch_images.to(device)).argmax(1)
        hit_mask = predicted == target_label
        n_hits += hit_mask.sum().item()
        n_seen += predicted.size(0)
    return 100.0 * n_hits / n_seen

In [None]:
# White-box trigger optimization
def optimize_trigger_for_asr(model: nn.Module, base_dataset: Dataset, target_label: int, trigger_size: int, steps: int, lr: float, batch_size: int) -> torch.Tensor:
    """Optimize a corner trigger so that ``model`` predicts ``target_label``.

    A randomly initialized ``trigger_size`` x ``trigger_size`` patch is stamped
    onto every image of ``base_dataset`` and updated with Adam to minimize the
    mean cross-entropy towards ``target_label`` over the whole dataset.

    The full-dataset gradient is accumulated batch by batch, so peak memory is
    bounded by ``batch_size`` rather than by the dataset size, and only the
    trigger receives gradients (the model's parameter ``.grad`` fields are
    left untouched).

    Args:
        model: Classifier to attack; it is put into eval mode.
        base_dataset: Dataset of ``(image, label)`` pairs; labels are ignored.
        target_label: Class the trigger should induce.
        trigger_size: Side length of the square trigger.
        steps: Number of optimizer steps.
        lr: Adam learning rate.
        batch_size: Batch size used for the forward/backward passes.

    Returns:
        The optimized trigger, detached, with shape
        ``(1, 1, trigger_size, trigger_size)`` on the module-level ``device``.

    Raises:
        ValueError: If ``base_dataset`` contains no samples.
    """
    model.eval()
    loader = DataLoader(base_dataset, batch_size=batch_size, shuffle=False)
    # The clean images never change between steps, so load and move them to
    # the device once instead of re-iterating the loader on every step.
    image_batches = [images.to(device) for images, _ in loader]
    total = sum(batch.size(0) for batch in image_batches)
    if total == 0:
        raise ValueError("base_dataset must contain at least one sample.")

    trigger = torch.randn((1, 1, trigger_size, trigger_size), device=device, requires_grad=True)
    optimizer = torch.optim.Adam([trigger], lr=lr)
    # Summed per-batch losses divided by the sample count reproduce the mean
    # loss over the full dataset regardless of how it is split into batches.
    criterion = nn.CrossEntropyLoss(reduction="sum")
    log_every = max(1, steps // 5)

    for step in range(steps):
        log_now = step % log_every == 0
        grad_sum = torch.zeros_like(trigger)
        loss_sum = torch.zeros((), device=device)
        hits = 0
        for images in image_batches:
            logits = model(apply_trigger(images, trigger))
            labels = torch.full((images.size(0),), target_label, dtype=torch.long, device=device)
            batch_loss = criterion(logits, labels)
            # Differentiate w.r.t. the trigger only; frees this batch's graph.
            (batch_grad,) = torch.autograd.grad(batch_loss, trigger)
            grad_sum += batch_grad
            loss_sum += batch_loss.detach()
            if log_now:
                hits += (logits.argmax(1) == target_label).sum().item()
        trigger.grad = grad_sum / total
        optimizer.step()
        if log_now:
            # Loss and ASR describe the trigger as it was before this update.
            loss_value = loss_sum.item() / total
            asr = 100.0 * hits / total
            print(f"[white-box] step={step} loss={loss_value:.4f} asr={asr:.2f}%")
    return trigger.detach()

In [None]:
# Sample-complexity bound
def detection_sample_complexity(gap: float, delta: float = 0.05) -> float:
    """Return the sample count ``ceil(ln(2 / delta) / (2 * gap**2))``.

    This is the Hoeffding-style number of samples needed to resolve a
    difference of ``gap`` with failure probability at most ``delta``.

    Args:
        gap: Difference between the two rates being compared.
        delta: Allowed failure probability, strictly between 0 and 1.

    Returns:
        An ``int`` when the bound is finite. ``math.inf`` when ``gap`` is
        non-positive or so small that the bound is not representable.

    Raises:
        ValueError: If ``delta`` is not in ``(0, 1)``.
    """
    if not 0 < delta < 1:
        raise ValueError("delta must be in (0,1).")
    if gap <= 0:
        return math.inf
    denominator = 2 * gap * gap
    # Squaring a tiny gap can underflow to exactly zero.
    if denominator == 0:
        return math.inf
    samples = math.log(2 / delta) / denominator
    # math.ceil cannot convert an infinite float to an integer.
    if math.isinf(samples):
        return math.inf
    return math.ceil(samples)

In [None]:
# Run experiment with defaults (matches original TODO paths)
# Seed every random source the experiment draws from so that reruns produce
# the same numbers: torch drives the random initialisation of the white-box
# trigger, and the `random` / NumPy global generators cover label flipping
# and any other helper that relies on them.
random.seed(seed)
np.random.seed(seed)
torch.manual_seed(seed)

# --- Load data and the (possibly backdoored) classifier ---------------------
test_dataset, clean_loader = load_mnist_tensor(dataset_path, batch_size)
model = CNN().to(device)
state = torch.load(weights_path, map_location=device)
model.load_state_dict(state)
model.eval()

# --- Baseline: accuracy on unmodified data -----------------------------------
clean_acc = accuracy(model, clean_loader)
print(f"Clean accuracy: {clean_acc:.2f}%")

# --- Black-box attack: fixed constant-valued corner trigger ------------------
base_trigger = corner_trigger(trigger_size)
poisoned_bb = FractionalBackdoorDataset(test_dataset, base_trigger, target_label=target_label, poison_fraction=poison_fraction, mode="backdoor", seed=seed)
bb_loader = DataLoader(poisoned_bb, batch_size=batch_size, shuffle=False)
# Measured over the whole loader, i.e. including the unpoisoned samples.
bb_asr = attack_success_rate(model, bb_loader, target_label)
print(f"Black-box ASR (fixed trigger, {poison_fraction*100:.1f}% poison): {bb_asr:.2f}%")

# --- White-box attack: trigger optimized against the model -------------------
white_trigger = optimize_trigger_for_asr(model, test_dataset, target_label=target_label, trigger_size=trigger_size, steps=opt_steps, lr=opt_lr, batch_size=batch_size)
# The dataset wrapper stamps CPU images, so bring the trigger back to the CPU.
poisoned_wb = FractionalBackdoorDataset(test_dataset, white_trigger.cpu(), target_label=target_label, poison_fraction=poison_fraction, mode="backdoor", seed=seed)
wb_loader = DataLoader(poisoned_wb, batch_size=batch_size, shuffle=False)
wb_asr = attack_success_rate(model, wb_loader, target_label)
print(f"White-box ASR (optimized trigger, {poison_fraction*100:.1f}% poison): {wb_asr:.2f}%")

# --- Untargeted label flipping -----------------------------------------------
label_flip_ds = FractionalBackdoorDataset(test_dataset, trigger_patch=base_trigger, target_label=target_label, poison_fraction=poison_fraction, mode="label_flip", seed=seed)
label_flip_loader = DataLoader(label_flip_ds, batch_size=batch_size, shuffle=False)
label_flip_acc = accuracy(model, label_flip_loader)
print(f"Label-flip accuracy (untargeted, {poison_fraction*100:.1f}% poison): {label_flip_acc:.2f}%")

# --- Sample complexity of telling the settings apart -------------------------
# NOTE(review): the gap compares an attack success rate with the clean
# *accuracy*, which are different statistics -- confirm this is the intended
# quantity before relying on the reported sample counts.
bb_gap = abs(bb_asr / 100.0 - clean_acc / 100.0)
wb_gap = abs(wb_asr / 100.0 - clean_acc / 100.0)
print(f"Samples to distinguish clean vs. black-box (delta=0.05): {detection_sample_complexity(bb_gap)}")
print(f"Samples to distinguish clean vs. white-box (delta=0.05): {detection_sample_complexity(wb_gap)}")