In [19]:
import torch
import torch.nn.functional as F
import numpy as np

import torch.nn as nn
from torchvision import transforms

# Training-time augmentation: force a 32x32 image, then take a random-resized
# crop (8%-100% of the area) back out at 32x32, and convert PIL -> tensor.
transform_train = transforms.Compose(
    [
        transforms.Resize(size=(32, 32)),
        transforms.RandomResizedCrop(size=(32, 32), scale=(0.08, 1.0)),
        transforms.ToTensor(),
    ]
)

# Evaluation: no augmentation, only PIL -> tensor conversion.
transform_test = transforms.Compose(
    [
        transforms.ToTensor(),
    ]
)

In [None]:
from torchvision.transforms.functional import to_pil_image, to_tensor

def add_trigger(img, location=(24, 24), size=(3, 3)):
    """
    Stamp a black-and-white checkerboard trigger onto a PIL image, in place.

    The top-left trigger cell is white, and cells alternate colour along both
    rows and columns.

    Args:
        img (PIL.Image): The input PIL image instance. It is modified in place.
        location (tuple): Top-left corner of the trigger as ``(x, y)``, i.e.
            (column, row) in PIL's coordinate convention.
        size (tuple): Size ``(H, W)`` of the trigger in pixels (rows, columns).

    Returns:
        PIL.Image: The same image object, with the trigger added.

    Raises:
        IndexError: If the trigger would not fit entirely inside the image.
    """
    x, y = location
    s_h, s_w = size

    # Fail loudly instead of letting pixel access raise an opaque error
    # (or wrap around on negative coordinates).
    width, height = img.size
    if x < 0 or y < 0 or x + s_w > width or y + s_h > height:
        raise IndexError(
            f"trigger of size (H, W)={tuple(size)} at (x, y)={tuple(location)} "
            f"does not fit in an image of size (W, H)={tuple(img.size)}"
        )

    # Match the fill value to the image's band count so single-band images
    # (e.g. mode 'L') work as well as RGB.
    n_bands = len(img.getbands())
    if n_bands == 1:
        black, white = 0, 255
    else:
        black, white = (0,) * n_bands, (255,) * n_bands

    pixels = img.load()  # pixel-access object for direct modification
    for i in range(s_h):          # row offset, added to y
        for j in range(s_w):      # column offset, added to x
            # Cells whose row and column parities differ are black; the rest white.
            pixels[x + j, y + i] = black if (i % 2) ^ (j % 2) else white

    return img

def test_backdoor_attack(model, testloader, device, trigger_func, target_label):
    """
    Test the backdoor attack success rate on the entire poisoned test dataset.
    
    Args:
        model (torch.nn.Module): The trained model to evaluate.
        testloader (DataLoader): DataLoader for the test dataset.
        device (torch.device): Device information for loading the model and data.
        trigger_func (function): Function to apply the backdoor trigger to images.
        target_label (int): Target label for the backdoor attack.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data
            # Convert images to PIL format, apply the trigger, then convert back to tensors
            poisoned_images = torch.stack([
                to_tensor(trigger_func(to_pil_image(img))) for img in images
            ]).to(device)
            
            # Forward pass
            outputs = model(poisoned_images)
            _, predicted = torch.max(outputs.data, 1)
            
            # Update totals
            total += labels.size(0)
            correct += (predicted == target_label).sum().item()
    
    # Calculate and display the attack success rate
    attack_success_rate = 100 * correct / total
    print(f"Backdoor Attack Success Rate: {attack_success_rate:.2f}%")

In [None]:
from corruptions import *
import torch
from corruptions import *
import torchvision.transforms as transforms

def CommonCorruptionsAttack(x, y, model, magnitude, corruption_function, device):
    """
    Apply a common-corruption function to a batch of images.

    ``y`` and ``model`` are accepted only so that every attack shares the
    signature ``(x, y, model, magnitude, device)``. Neither is used here, so
    ``y`` is not copied to ``device`` and may be ``None``.

    Args:
        x: Batch of images (tensor). It is moved to ``device`` before corruption.
        y: Labels for ``x`` (unused).
        model: Model under attack (unused).
        magnitude: Severity value passed straight through to ``corruption_function``.
        corruption_function: Callable ``(images, magnitude, device) -> images``.
        device: Device for the input and for the returned images.

    Returns:
        tuple: ``(corrupted_images, None)``. The second slot is always ``None``.
    """
    x = x.to(device)
    corrupted_images = corruption_function(x, magnitude, device)
    return corrupted_images.to(device), None

# Concrete corruption attacks. Each one binds a single corruption function
# (presumably provided by `from corruptions import *`) to the shared
# CommonCorruptionsAttack driver, exposing the signature
# (x, y, model, magnitude, device).
def GaussianNoiseAttack(x, y, model, magnitude, device):
    """Gaussian-noise corruption attack."""
    return CommonCorruptionsAttack(x, y, model, magnitude,
                                   corruption_function=gaussian_noise, device=device)

def ContrastAttack(x, y, model, magnitude, device):
    """Contrast corruption attack."""
    return CommonCorruptionsAttack(x, y, model, magnitude,
                                   corruption_function=contrast, device=device)

def GaussianBlurAttack(x, y, model, magnitude, device):
    """Gaussian-blur corruption attack."""
    return CommonCorruptionsAttack(x, y, model, magnitude,
                                   corruption_function=gaussian_blur, device=device)

def SaturateAttack(x, y, model, magnitude, device):
    """Saturation corruption attack."""
    return CommonCorruptionsAttack(x, y, model, magnitude,
                                   corruption_function=saturate, device=device)

def ShotNoiseAttack(x, y, model, magnitude, device):
    """Shot-noise corruption attack."""
    return CommonCorruptionsAttack(x, y, model, magnitude,
                                   corruption_function=shot_noise, device=device)

def ImpulseNoiseAttack(x, y, model, magnitude, device):
    """Impulse-noise corruption attack."""
    return CommonCorruptionsAttack(x, y, model, magnitude,
                                   corruption_function=impulse_noise, device=device)

def ZoomBlurAttack(x, y, model, magnitude, device):
    """Zoom-blur corruption attack."""
    return CommonCorruptionsAttack(x, y, model, magnitude,
                                   corruption_function=zoom_blur, device=device)

def BrightnessAttack(x, y, model, magnitude, device):
    """Brightness corruption attack."""
    return CommonCorruptionsAttack(x, y, model, magnitude,
                                   corruption_function=brightness, device=device)

def PixelateAttack(x, y, model, magnitude, device):
    """Pixelation corruption attack."""
    return CommonCorruptionsAttack(x, y, model, magnitude,
                                   corruption_function=pixelate, device=device)

def SpeckleNoiseAttack(x, y, model, magnitude, device):
    """Speckle-noise corruption attack."""
    return CommonCorruptionsAttack(x, y, model, magnitude,
                                   corruption_function=speckle_noise, device=device)

# Registry of corruption attacks. The integer indices loaded from
# badnets_corruptions_sequence.pkl (below) index into THIS list, so its order
# must not change. NOTE(review): PixelateAttack is defined above but is not
# registered here. Confirm whether the omission is intentional before adding
# it, and only ever append to keep existing indices valid.
corruptions = [
    GaussianNoiseAttack,
    ContrastAttack,
    GaussianBlurAttack,
    SaturateAttack,
    ShotNoiseAttack,
    ImpulseNoiseAttack,
    ZoomBlurAttack,
    BrightnessAttack,
    SpeckleNoiseAttack,
]

# Precomputed additive perturbations. extract_loop indexes one at random and
# adds it to an image batch. NOTE(review): presumably an indexable collection of
# image-shaped tensors; confirm the shape against how the file was produced.
perturbations = torch.load('../data/badnets_perturbations.pt')
import pickle
# SECURITY: pickle.load can execute arbitrary code embedded in the file; only
# load trusted, locally generated artifacts.
with open('../data/badnets_corruptions_sequence.pkl', 'rb') as f:
    best_individual = pickle.load(f)
# Resolve each stored index to its attack callable, preserving the stored order
# (extract_loop applies them sequentially).
best_corruptions = [corruptions[int(idx)] for idx in best_individual]

In [None]:
import torchvision.datasets as datasets
from torch.utils.data import Dataset, DataLoader, Subset
from tinyimagenet import TinyImageNet
from pathlib import Path

target_label = 0

# Loader settings shared by every DataLoader below.
BATCH_SIZE = 128
NUM_WORKERS = 8

trainset = datasets.ImageFolder(root='../data/transfer_sets/badnets/', transform=transform_train)
testset = datasets.CIFAR10(root='./data/cifar10', train=False, download=True, transform=transform_test)

trainloader = DataLoader(trainset, batch_size=BATCH_SIZE, shuffle=True, num_workers=NUM_WORKERS)
testloader = DataLoader(testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

# Backdoor evaluation set: exclude samples whose true label already equals the
# target. A correct prediction on them would otherwise count as an attack
# "success". Read labels from `testset.targets` rather than iterating the
# dataset, which would load and transform every image just to read its label.
non_target_indices = [i for i, label in enumerate(testset.targets) if label != target_label]
non_target_testset = Subset(testset, non_target_indices)
backdoor_testloader = DataLoader(non_target_testset, batch_size=BATCH_SIZE, shuffle=False, num_workers=NUM_WORKERS)

Files already downloaded and verified


In [None]:
import random
from torchvision.transforms import v2

# Batch-level mixing augmentations. RandomChoice picks one of them per call.
# NOTE(review): num_classes should exceed the largest ImageFolder label in the
# transfer set. The mixed labels are not used by the distillation loss itself.
cutmix = v2.CutMix(alpha=0.2, num_classes=1000)
mixup = v2.MixUp(alpha=0.2, num_classes=1000)
cutmix_or_mixup = v2.RandomChoice(transforms=[cutmix, mixup])

def extract_loop(
    model,
    teacher,
    poi_loader,
    loader,
    opt,
    lr_scheduler,
    epoch,
    temperature=1.0,
    max_epoch=100,
    mode='train',
    device='cuda'
):
    """
    Run one epoch of knowledge distillation (``mode='train'``), or evaluate the
    top-1 accuracy of ``model`` (any other ``mode``).

    Training matches the student to the teacher's temperature-softened outputs
    with a KL-divergence loss. For each batch:
      * CutMix/MixUp is applied with probability 1/5;
      * the student sees the clean batch with probability 19/20; otherwise it
        sees a random precomputed perturbation (9/10 of those cases) or the
        chained ``best_corruptions`` (1/10).
    The teacher always sees the un-perturbed (possibly mixed) batch. It runs
    under ``torch.no_grad()``, so ``backward()`` never touches its parameters.

    Args:
        model: Student model to train or evaluate.
        teacher: Teacher model providing distillation targets (training only).
        poi_loader: Unused; kept for interface compatibility.
        loader: DataLoader yielding ``(images, labels)`` batches.
        opt: Optimizer over ``model``'s parameters (training only).
        lr_scheduler: Unused here; the caller is responsible for stepping it.
        epoch: Current epoch (unused).
        temperature: Distillation temperature ``T``.
        max_epoch: Total number of epochs (unused).
        mode: ``'train'`` to distill; anything else evaluates.
        device: Device to perform computations on.

    Returns:
        float: In evaluation mode, accuracy in ``[0, 1]`` over ``loader.dataset``.
        In training mode, the mean distillation loss over the epoch. Either is
        ``nan`` when there is nothing to process.
    """
    if mode != 'train':
        return _evaluate_accuracy(model, loader, device)
    return _distill_one_epoch(model, teacher, loader, opt, temperature, device)


def _evaluate_accuracy(model, loader, device):
    """Return (and print) the top-1 accuracy of ``model`` over ``loader.dataset``."""
    model.eval()
    test_num = len(loader.dataset)
    correct = 0
    # Local no_grad so evaluation never builds autograd graphs, even when the
    # caller does not wrap the call itself.
    with torch.no_grad():
        for test_images, test_labels in loader:
            test_images, test_labels = test_images.to(device), test_labels.to(device)
            predict_y = torch.argmax(model(test_images), dim=1)
            correct += torch.eq(predict_y, test_labels).sum().item()

    if test_num == 0:
        print('Test Accuracy: n/a (empty dataset)')
        return float('nan')

    test_accurate = correct / test_num
    print(f'Test Accuracy: {test_accurate:.4f}')
    return test_accurate


def _student_forward(model, images, labels, device):
    """Run the student on clean, perturbed or corrupted images, chosen at random."""
    if random.randint(1, 20) > 1:
        return model(images)  # clean inputs (19/20 of batches)

    if random.randint(1, 10) > 1:
        # Additive precomputed perturbation. NOTE(review): unlike the corruption
        # branch below, the result is not clamped to [0, 1]; confirm this is intended.
        random_perturbation = random.choice(perturbations).to(device)
        return model(images + random_perturbation)

    corrupted_images = images.clone()
    for corruption in best_corruptions:
        corrupted_images = corruption(corrupted_images, labels, model, magnitude=1, device=device)[0]
    return model(torch.clamp(corrupted_images, 0, 1))


def _distill_one_epoch(model, teacher, loader, opt, temperature, device):
    """Run one distillation epoch. Return the mean KL loss, or ``nan`` if ``loader`` is empty."""
    T = temperature
    model.train()
    total_loss = 0.0
    num_batches = 0

    for batch in loader:
        images, labels = batch[0].to(device), batch[1].long().to(device)
        opt.zero_grad()

        # Batch-level CutMix/MixUp with probability 1/5.
        if random.randint(1, 5) <= 1:
            images, labels = cutmix_or_mixup(images, labels)

        # The teacher only supplies targets. Without no_grad, backward() would
        # also compute and accumulate gradients for every teacher parameter.
        with torch.no_grad():
            teacher_preds = teacher(images)

        preds = _student_forward(model, images, labels, device)

        # The usual T^2 factor (Hinton et al., 2015) keeps gradient magnitudes
        # comparable across temperatures.
        extract_loss = T ** 2 * F.kl_div(
            F.log_softmax(preds / T, dim=-1),
            F.softmax(teacher_preds / T, dim=-1),
            reduction='batchmean',
        )
        extract_loss.backward()
        opt.step()

        # Accumulate on-device; converting once at the end avoids a sync per batch.
        total_loss = total_loss + extract_loss.detach()
        num_batches += 1

    return float(total_loss) / num_batches if num_batches else float('nan')

In [24]:
def extraction(teacher, model, epochs, poi_loader, train_loader, test_loader, opt, lr_scheduler, device):
    """
    Distill ``teacher`` into ``model`` for ``epochs`` epochs.

    The teacher's backdoor success rate is reported once up front as a
    baseline. After each epoch, the student's clean test accuracy and backdoor
    success rate are reported, and ``lr_scheduler`` is stepped once.

    Side effect: the teacher is switched to eval mode and frozen
    (``requires_grad_(False)``). It is used only as a source of targets.

    Relies on the notebook-level ``backdoor_testloader``, ``target_label``,
    ``add_trigger``, ``test_backdoor_attack`` and ``extract_loop``.

    Args:
        teacher: Backdoored teacher model.
        model: Student model to train.
        epochs: Number of distillation epochs.
        poi_loader: Forwarded to ``extract_loop`` (unused there).
        train_loader: Loader for the distillation (transfer) set.
        test_loader: Loader for clean test accuracy.
        opt: Optimizer over the student's parameters.
        lr_scheduler: Scheduler stepped once per epoch.
        device: Device for computation.
    """
    def trigger(img):
        # 3x3 checkerboard at (24, 24), used for both the teacher and the student.
        return add_trigger(img, location=(24, 24), size=(3, 3))

    teacher.eval()
    # Freeze the teacher so no backward pass ever builds graphs through it or
    # accumulates gradients in its parameters.
    teacher.requires_grad_(False)
    test_backdoor_attack(teacher, backdoor_testloader, device, trigger, target_label=target_label)

    for epoch in range(epochs):
        print('epoch:', epoch)
        model.train()
        extract_loop(model, teacher, poi_loader, train_loader,
                     opt, lr_scheduler, epoch, max_epoch=epochs, mode='train', device=device)

        with torch.no_grad():
            model.eval()
            extract_loop(model, teacher, poi_loader, test_loader,
                         opt, lr_scheduler, epoch, max_epoch=epochs, mode='val', device=device)
            test_backdoor_attack(model, backdoor_testloader, device, trigger, target_label=target_label)

        lr_scheduler.step()

In [None]:
device = 'cuda:0'

# The teacher checkpoint is a fully pickled nn.Module (it is used directly as a
# model), not a state_dict. Since PyTorch 2.6, torch.load defaults to
# weights_only=True, which refuses to unpickle modules, so opt out explicitly.
# map_location avoids failures when the checkpoint was saved on another device.
# SECURITY: weights_only=False can execute code from the file; load trusted checkpoints only.
teacher = torch.load(
    '../models/badnets/resnet18_50epochs.pth',
    map_location=device,
    weights_only=False,
).to(device)

from torchvision.models.resnet import resnet18, ResNet18_Weights
# ImageNet-pretrained ResNet-18 adapted to the 32x32 inputs used here: a 3x3
# stride-1 stem and no max-pool preserve spatial resolution, and a fresh
# 10-way classifier head replaces the 1000-way one.
student = resnet18(weights=ResNet18_Weights.DEFAULT)
student.conv1 = nn.Conv2d(3, 64, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1), bias=False)
student.maxpool = nn.Identity()
student.fc = nn.Linear(512, 10)
student.to(device)
print('model prepared.')
print('model prepared.')

model prepared.


In [26]:
lr = 5e-3
criterion = torch.nn.CrossEntropyLoss()  # NOTE(review): not used by the distillation loop

optimizer = torch.optim.SGD(
    student.parameters(),
    lr=lr,
    momentum=0.9,
    weight_decay=0.0,
)

# Cosine decay from `lr` to eta_min over T_max scheduler steps. extraction()
# steps the scheduler once per epoch, so T_max should equal its `epochs` (50).
scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(
    optimizer,
    T_max=50,
    eta_min=1e-4,
    verbose=True,  # deprecated in recent PyTorch releases; only prints the LR
)

In [None]:
# Full 50-epoch distillation run; keep `epochs` in sync with the scheduler's T_max.
extraction(
    teacher=teacher,
    model=student,
    epochs=50,
    poi_loader=None,
    train_loader=trainloader,
    test_loader=testloader,
    opt=optimizer,
    lr_scheduler=scheduler,
    device=device,
)

In [None]:
# Pickle the whole student module (architecture + weights), the same format in
# which the teacher checkpoint is stored and loaded above.
torch.save(obj=student, f='../models/badnets/student.pth')