In [None]:
import torch

def L0_projection(delta, n, keep_values=False):
    """
    Keep only the top-n entries (by absolute value) of each sample and zero the rest.

    The first dimension of `delta` is treated as the batch dimension. All remaining
    dimensions are flattened, so the top-n selection is done independently for each
    ``delta[i]``. A tensor without a batch dimension (e.g. a single (C, H, W)
    perturbation) is therefore projected per leading slice, not over the whole tensor.

    By default every surviving entry is replaced by its sign (+1 / -1, or 0 for an
    exact zero). This is the original behaviour of this function. Pass
    ``keep_values=True`` to keep the original magnitudes instead.

    Args:
        delta (torch.Tensor): Input of shape (batch_size, ...). It does not need to be
                              contiguous.
        n (int): Number of entries to keep per sample. If `n` exceeds the number of
                 entries per sample, every entry is kept.
        keep_values (bool): If True, keep the selected entries' values rather than
                            their sign. Defaults to False.

    Returns:
        torch.Tensor: Tensor with the same shape, dtype and device as `delta`.

    Raises:
        ValueError: If `n` is negative.
    """
    if n < 0:
        raise ValueError(f"n must be non-negative, got {n}")
    if delta.numel() == 0:
        # Nothing to select (empty batch or empty samples); avoids reshape(0, -1) errors.
        return torch.zeros_like(delta)

    batch_size = delta.shape[0]
    # reshape (not view) so non-contiguous inputs, e.g. transposed tensors, also work.
    delta_flat = delta.reshape(batch_size, -1)
    # torch.topk raises if k exceeds the row length, so clamp it.
    k = min(n, delta_flat.shape[1])

    _, topk_indices = torch.topk(delta_flat.abs(), k, dim=1, largest=True, sorted=False)

    kept = delta_flat.gather(1, topk_indices)
    if not keep_values:
        kept = kept.sign()

    # A single vectorized scatter writes the kept entries for every sample at once.
    projected_flat = torch.zeros_like(delta_flat).scatter_(1, topk_indices, kept)
    return projected_flat.reshape(delta.shape)

In [None]:
from PIL import ImageDraw
import torch
import numpy as np
from PIL import Image, ImageDraw
from torchvision.transforms.functional import to_pil_image, to_tensor

def add_trigger(img, location=(24, 24), size=(3, 3)):
    """
    Paint a black-and-white checkerboard trigger onto a PIL image, in place.

    The top-left cell of the pattern is white, and colours alternate along both rows
    and columns.

    Args:
        img (PIL.Image): The image to modify. Pixels are written as RGB tuples, so an
                         RGB-mode image is expected.
        location (tuple): Top-left corner of the trigger as (row, column), i.e. (H, W),
                          matching the (H, W) convention of `size`.
        size (tuple): Size (H, W) of the trigger in pixels.

    Returns:
        PIL.Image: The same `img` object, now containing the trigger.

    Raises:
        ValueError: If the trigger would extend outside the image. The check runs
                    before any pixel is written, so no partial pattern is drawn.
    """
    row0, col0 = location
    s_h, s_w = size
    width, height = img.size  # PIL reports size as (width, height)
    if row0 < 0 or col0 < 0 or row0 + s_h > height or col0 + s_w > width:
        raise ValueError(
            f"trigger at (row, col)={location} with size={size} does not fit "
            f"in an image of (H, W)=({height}, {width})"
        )

    pixels = img.load()  # pixel-access object for direct modification
    for i in range(s_h):
        for j in range(s_w):
            # Odd (row + col) offset -> black, even -> white.
            fill_color = (0, 0, 0) if (i + j) % 2 else (255, 255, 255)
            # PIL pixel access is indexed (x, y) == (column, row).
            pixels[col0 + j, row0 + i] = fill_color

    return img

def poison_dataset(dataset, trigger_func, target_label, poison_rate=0.1, seed=42):
    """
    Add a backdoor trigger to a random subset of a dataset and relabel it, in place.

    Only samples whose label is not already `target_label` are candidates.
    `poison_rate` is the fraction of those candidates that gets poisoned (rounded
    down), not the fraction of the whole dataset.

    Args:
        dataset (torchvision.datasets.CIFAR10): The dataset to be modified. Its
            `.data[idx]` must be accepted by `PIL.Image.fromarray`, and `.targets`
            must support item assignment.
        trigger_func (function): Takes a PIL image and returns the triggered PIL image.
        target_label (int): The label assigned to poisoned samples.
        poison_rate (float): Fraction of non-target samples to poison, in [0, 1].
        seed (int): Seed for the sample selection. The default of 42 reproduces the
            selection previously obtained via `np.random.seed(42)`.

    Returns:
        numpy.ndarray: The indices of the poisoned samples.

    Raises:
        ValueError: If `poison_rate` is outside [0, 1].
    """
    if not 0 <= poison_rate <= 1:
        raise ValueError(f"poison_rate must be in [0, 1], got {poison_rate}")

    # A private RandomState seeded with `seed` yields exactly the same draws as
    # np.random.seed(seed) followed by np.random.choice. It never touches the global
    # numpy RNG, so nothing has to be restored, even if trigger_func raises.
    rng = np.random.RandomState(seed)

    valid_indices = [i for i, target in enumerate(dataset.targets) if target != target_label]
    num_poison = int(len(valid_indices) * poison_rate)
    selected_indices = rng.choice(valid_indices, num_poison, replace=False)

    for idx in selected_indices:
        img = Image.fromarray(dataset.data[idx])  # array -> PIL image
        poisoned_img = trigger_func(img)
        dataset.data[idx] = np.array(poisoned_img)  # PIL image -> array, stored back
        dataset.targets[idx] = target_label

    return selected_indices

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset, random_split
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
import torchvision.datasets as datasets

device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Resize to 32x32 (a no-op for CIFAR-10's native resolution), then convert to a tensor.
transform = transforms.Compose([
    transforms.Resize((32, 32)),
    transforms.ToTensor(),
])

target_label = 0       # label assigned to triggered (poisoned) samples
POISON_RATE = 0.1      # fraction of non-target training samples that receive the trigger
UNLEARN_SIZE = 5000    # number of training samples drawn for the unlearning set
BATCH_SIZE = 128
NUM_WORKERS = 8
SPLIT_SEED = 42        # makes the unlearning subset and its shuffle order reproducible

cifar10_train = datasets.CIFAR10(root='./data/cifar10', train=True, download=True)
# Poison before attaching the transform: poison_dataset edits dataset.data directly.
poison_dataset(
    cifar10_train,
    lambda img: add_trigger(img, location=(24, 24), size=(3, 3)),
    target_label=target_label,
    poison_rate=POISON_RATE,
)
cifar10_train.transform = transform

# Seeded generators so reruns select (and shuffle) the same unlearning samples,
# matching the fixed seed already used for poisoning.
unlearn_set, _ = random_split(
    cifar10_train,
    [UNLEARN_SIZE, len(cifar10_train) - UNLEARN_SIZE],
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

unlearn_loader = DataLoader(
    unlearn_set,
    batch_size=BATCH_SIZE,
    shuffle=True,
    num_workers=NUM_WORKERS,
    generator=torch.Generator().manual_seed(SPLIT_SEED),
)

Files already downloaded and verified


In [None]:
# The checkpoint stores a whole pickled nn.Module, so weights_only=False is required:
# PyTorch >= 2.6 defaults to weights_only=True. Unpickling can execute arbitrary code,
# so only load checkpoints you trust. map_location lets a GPU-saved checkpoint load on a
# CPU-only machine. The weights are frozen because only input perturbations are
# optimized later; this skips computing parameter gradients on every backward pass.
CLASSIFIER_PATH = '../models/badnets/resnet18_50epochs.pth'
classifier = (
    torch.load(CLASSIFIER_PATH, map_location=device, weights_only=False)
    .eval()
    .requires_grad_(False)
    .to(device)
)

In [None]:
import torch.nn.functional as F
import random

# Each round optimizes one universal perturbation, shared by every image, that pushes
# the classifier away from its own clean predictions (maximizing cross-entropy), with
# a small L2 penalty and projection onto an L2 ball.
# NOTE(review): the next cell rebinds `perturbations`, so the results of this cell are
# not what gets saved at the end of the notebook.
NUM_ROUNDS = 10
NUM_EPOCHS = 1        # passes over unlearn_loader per round
LR = 2.0
NORM_REG = 0.001      # weight of the squared-L2 penalty on the perturbation
L2_RADIUS = 3.48      # radius of the L2 ball the perturbation is projected onto

perturbations = []

for _round in range(NUM_ROUNDS):
    # One perturbation shaped like a single input sample, broadcast over each batch.
    batch_pert = torch.zeros_like(unlearn_set[0][0], device=device, requires_grad=True)
    batch_opt = torch.optim.SGD(params=[batch_pert], lr=LR)

    for _epoch in range(NUM_EPOCHS):
        for images, _labels in unlearn_loader:
            images = images.to(device)

            # Reference labels are the model's own clean predictions; no graph needed.
            with torch.no_grad():
                ori_lab = classifier(images).argmax(dim=1)

            per_logits = classifier(torch.clamp(images + batch_pert, 0, 1))

            # Minimizing -CE maximizes the prediction change; the penalty keeps the
            # perturbation small.
            loss = F.cross_entropy(per_logits, ori_lab, reduction='mean')
            loss_regu = -loss + NORM_REG * batch_pert.norm() ** 2

            batch_opt.zero_grad()
            loss_regu.backward()
            batch_opt.step()

            # Project onto the L2 ball of radius L2_RADIUS (only shrinks, never grows).
            with torch.no_grad():
                batch_pert *= min(1, L2_RADIUS / torch.norm(batch_pert))

    perturbations.append(batch_pert.detach())

In [None]:
import torch
import torch.nn.functional as F

# Each round optimizes one universal perturbation inside an L2 ball. It then derives
# L0, L2 and Linf variants and keeps whichever causes the largest cross-entropy on the
# whole unlearning set.
NUM_ROUNDS = 5
NUM_EPOCHS = 1        # passes over unlearn_loader per round
LR = 1.0
NORM_REG = 0.001      # weight of the squared-L2 penalty on the perturbation
OPT_RADIUS = 5.0      # L2 ball used during optimization
L2_RADIUS = 3.48      # L2 ball for the L2 candidate
L0_N = 10             # entries kept by the L0 candidate (per leading slice, see below)
LINF_EPS = 16 / 255   # magnitude of the Linf (sign) candidate

perturbations = []

for _round in range(NUM_ROUNDS):
    # One perturbation shaped like a single input sample, broadcast over each batch.
    batch_pert = torch.zeros_like(unlearn_set[0][0], device=device, requires_grad=True)
    batch_opt = torch.optim.SGD(params=[batch_pert], lr=LR)

    for _epoch in range(NUM_EPOCHS):
        for images, _labels in unlearn_loader:
            images = images.to(device)

            # Reference labels are the model's own clean predictions; no graph needed.
            with torch.no_grad():
                ori_lab = classifier(images).argmax(dim=1)

            per_logits = classifier(torch.clamp(images + batch_pert, 0, 1))

            # Minimizing -CE maximizes the prediction change; the penalty keeps the
            # perturbation small.
            loss = F.cross_entropy(per_logits, ori_lab, reduction='mean')
            loss_regu = -loss + NORM_REG * batch_pert.norm() ** 2

            batch_opt.zero_grad()
            loss_regu.backward()
            batch_opt.step()

            # Project onto the L2 ball of radius OPT_RADIUS (only shrinks, never grows).
            with torch.no_grad():
                batch_pert *= min(1, OPT_RADIUS / torch.norm(batch_pert))

    with torch.no_grad():
        # Insertion order doubles as the tie-break preference: l0, then l2, then linf.
        candidates = {
            # NOTE(review): batch_pert has no batch dimension, so L0_projection treats
            # its first axis (presumably channels) as the batch and keeps L0_N entries
            # per slice — confirm this is intended.
            'l0': L0_projection(batch_pert.clone(), n=L0_N),
            'l2': batch_pert * min(1, L2_RADIUS / torch.norm(batch_pert)),
            'linf': batch_pert.sign() * LINF_EPS,
        }

        # Score each candidate on the whole unlearning set rather than a single
        # (possibly small, partial) batch, so the choice is not driven by a handful
        # of images.
        total_loss = dict.fromkeys(candidates, 0.0)
        for images, labels in unlearn_loader:
            images, labels = images.to(device), labels.to(device)
            for name, pert in candidates.items():
                logits = classifier(torch.clamp(images + pert, 0, 1))
                total_loss[name] += F.cross_entropy(logits, labels, reduction='sum').item()

        # max() returns the first maximal key in insertion order, preserving the tie-break.
        best_name = max(candidates, key=lambda name: total_loss[name])
        selected_pert = candidates[best_name]

    perturbations.append(selected_pert.detach())

In [None]:
import os

PERTURBATIONS_PATH = '../data/badnets_perturbations.pt'
# torch.save does not create missing parent directories.
os.makedirs(os.path.dirname(PERTURBATIONS_PATH), exist_ok=True)
torch.save(perturbations, PERTURBATIONS_PATH)