In [1]:
from PIL import Image, ImageDraw
import numpy as np
from torchvision.transforms.functional import to_pil_image, to_tensor

def add_trigger(img, location=(24, 24), size=(3, 3)):
    """
    Add a black-and-white checkerboard trigger to a PIL image, in place.

    The top-left cell of the pattern is white; colors then alternate along
    both rows and columns.

    Args:
        img (PIL.Image): The input PIL image instance. It is modified in place.
            The fill colors are RGB tuples, so a 3-band (RGB) image is assumed.
        location (tuple): Starting position (H, W) -- i.e. (row, col) -- of the
            trigger's top-left pixel.
        size (tuple): Size (H, W) of the trigger in pixels.

    Returns:
        PIL.Image: The same image object, with the trigger added.
    """
    row0, col0 = location
    s_h, s_w = size
    pixels = img.load()  # Pixel-access object for direct modification

    black = (0, 0, 0)
    white = (255, 255, 255)
    for i in range(s_h):
        for j in range(s_w):
            # Cells whose row and column parity differ are black.
            fill_color = black if (i % 2) ^ (j % 2) else white
            # PIL pixel access is indexed (x, y) == (column, row), so the
            # documented (row, col) location has to be swapped here.
            pixels[col0 + j, row0 + i] = fill_color

    return img

def poison_dataset(dataset, trigger_func, target_label, poison_rate=0.1, seed=42):
    """
    Poison a fraction of a dataset in place.

    The selected images get a backdoor trigger and are relabeled as
    `target_label`. Only samples whose label is not already `target_label` are
    candidates, so `poison_rate` is a fraction of the *non-target* samples,
    not of the whole dataset.

    Args:
        dataset (torchvision.datasets.CIFAR10): The dataset to be modified. It
            must expose an indexable `data` array of images accepted by
            `PIL.Image.fromarray` and a mutable `targets` sequence.
        trigger_func (function): Takes a PIL image and returns the triggered
            PIL image.
        target_label (int): The label assigned to poisoned samples.
        poison_rate (float): Fraction in [0, 1] of non-target samples to poison.
        seed (int): Seed for the sample selection. A private generator is used,
            so the global NumPy random state is never touched.

    Returns:
        numpy.ndarray: Indices of the samples that were poisoned.

    Raises:
        ValueError: If `poison_rate` is outside [0, 1].
    """
    if not 0 <= poison_rate <= 1:
        raise ValueError(f"poison_rate must be in [0, 1], got {poison_rate}")

    # Select indices of samples that do not already belong to the target class
    valid_indices = [i for i, target in enumerate(dataset.targets) if target != target_label]
    num_poison = int(len(valid_indices) * poison_rate)

    # A private RandomState seeded with `seed` reproduces exactly what
    # np.random.seed(seed) + np.random.choice would pick. It avoids mutating
    # the global RNG, and the old code failed to restore that RNG whenever
    # trigger_func raised.
    rng = np.random.RandomState(seed)
    selected_indices = rng.choice(valid_indices, num_poison, replace=False)

    # Add trigger and modify labels for the selected indices
    for idx in selected_indices:
        img = Image.fromarray(dataset.data[idx])  # Convert to PIL image
        poisoned_img = trigger_func(img)  # Add trigger to the image
        dataset.data[idx] = np.array(poisoned_img)  # Convert back to NumPy array and save
        dataset.targets[idx] = target_label  # Update the label to the target class

    return selected_indices

In [None]:
import torch
def test_backdoor_attack(model, testloader, device, trigger_func, target_label):
    """
    Test the backdoor attack success rate on the entire poisoned test dataset.
    
    Args:
        model (torch.nn.Module): The trained model to evaluate.
        testloader (DataLoader): DataLoader for the test dataset.
        device (torch.device): Device information for loading the model and data.
        trigger_func (function): Function to apply the backdoor trigger to images.
        target_label (int): Target label for the backdoor attack.
    """
    model.eval()
    correct = 0
    total = 0

    with torch.no_grad():
        for data in testloader:
            images, labels = data
            # Convert images to PIL format, apply the trigger, then convert back to tensors
            poisoned_images = torch.stack([
                to_tensor(trigger_func(to_pil_image(img))) for img in images
            ]).to(device)
            
            # Forward pass
            outputs = model(poisoned_images)
            _, predicted = torch.max(outputs.data, 1)
            
            # Update totals
            total += labels.size(0)
            correct += (predicted == target_label).sum().item()
    
    # Calculate and display the attack success rate
    attack_success_rate = 100 * correct / total
    print(f"Backdoor Attack Success Rate: {attack_success_rate:.2f}%")

In [None]:
import numpy as np
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import torch.nn as nn
import torch.optim as optim
from torchvision import transforms, models
import torchvision.datasets as datasets
from tqdm import tqdm
import sys

# Device configuration
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Seed torch's global RNG so that layers created later and the DataLoader
# shuffle order are reproducible across runs. poison_dataset seeds its own
# sample selection separately.
seed = 42
torch.manual_seed(seed)

# Training configurations
epochs = 50
lr = 5e-3

# Data transformations
transform_train = transforms.Compose([
    transforms.TrivialAugmentWide(),  # Apply strong augmentations
    transforms.ToTensor(),
])

transform = transforms.Compose([
    transforms.ToTensor(),
])

# Target label for backdoor attack
target_label = 0

# Load CIFAR-10 and poison 10% of the non-target training samples in place.
# The raw, untransformed data is poisoned first. The training transform is
# attached afterwards, so augmentation is applied on top of the trigger.
cifar10_train = datasets.CIFAR10(root='./data/cifar10', train=True, download=True)
poison_dataset(
    cifar10_train,
    lambda x: add_trigger(x, location=(24, 24), size=(3, 3)),
    target_label=target_label,
    poison_rate=0.1
)
cifar10_train.transform = transform_train
testset = datasets.CIFAR10(root='./data/cifar10', train=False, download=True, transform=transform)

# DataLoader for training and testing
trainloader = DataLoader(cifar10_train, batch_size=128, shuffle=True, num_workers=8)
testloader = DataLoader(testset, batch_size=128, shuffle=False, num_workers=8)

# Keep only test samples whose true label is not the target, so predicting the
# target label is never the clean-correct answer for them.
# Labels are read from testset.targets. Iterating the dataset instead would
# load and transform every image just to read its label.
non_target_indices = [i for i, label in enumerate(testset.targets) if label != target_label]
non_target_testset = Subset(testset, non_target_indices)
backdoor_testloader = DataLoader(non_target_testset, batch_size=128, shuffle=False, num_workers=8)

# Function for training the model
def train(model):
    """
    Train the model and evaluate its performance, including testing for backdoor attack success rate.

    After every epoch, this reports the mean training loss, the learning rate
    used for that epoch, clean accuracy on `testloader`, and the backdoor attack
    success rate on `backdoor_testloader`. A final clean test accuracy is
    printed at the end.

    Relies on notebook globals: `epochs`, `lr`, `device`, `trainloader`,
    `testloader`, `backdoor_testloader`, `target_label`, `add_trigger`,
    `test_backdoor_attack`.

    Args:
        model (torch.nn.Module): The model to be trained and evaluated (updated in place).
    """
    # Define loss function, optimizer, and learning rate scheduler
    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    # `verbose=True` is deprecated since PyTorch 2.2, so the LR is logged
    # explicitly via get_last_lr() instead.
    scheduler = torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs, eta_min=1e-4)

    train_steps = len(trainloader)

    for epoch in range(epochs):
        model.train()
        running_loss = 0.0
        train_bar = tqdm(trainloader, file=sys.stdout)
        for images, labels in train_bar:
            images, labels = images.to(device), labels.to(device)

            optimizer.zero_grad()
            logits = model(images)
            loss = criterion(logits, labels)
            loss.backward()
            optimizer.step()

            loss_value = loss.item()
            running_loss += loss_value
            train_bar.desc = f"train epoch[{epoch + 1}/{epochs}] loss:{loss_value:.3f}"

        epoch_lr = scheduler.get_last_lr()[0]  # LR that was used during this epoch
        scheduler.step()

        # Evaluate model accuracy on the validation dataset
        val_accurate = _evaluate_accuracy(
            model, testloader, device, desc=f"valid epoch[{epoch + 1}/{epochs}]"
        )
        mean_loss = running_loss / max(train_steps, 1)  # guard an empty trainloader
        print(f'[epoch {epoch + 1}] lr: {epoch_lr:.4e}  train_loss: {mean_loss:.3f}  val_accuracy: {val_accurate:.4f}')

        # Test backdoor attack success rate
        test_backdoor_attack(
            model,
            backdoor_testloader,
            device,
            lambda x: add_trigger(x, location=(24, 24), size=(3, 3)),
            target_label=target_label
        )

    # Final test accuracy. The helper sets eval mode explicitly rather than
    # relying on whatever mode the last epoch left the model in.
    val_accurate = _evaluate_accuracy(model, testloader, device)
    print(f"Final Test Accuracy: {val_accurate:.4f}")


def _evaluate_accuracy(model, loader, device, desc=None):
    """
    Return the top-1 accuracy of `model` on `loader` as a fraction in [0, 1].

    Puts the model in eval mode and disables autograd. Returns 0.0 for an
    empty loader instead of dividing by zero.
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for images, labels in tqdm(loader, file=sys.stdout, desc=desc):
            images, labels = images.to(device), labels.to(device)
            predict_y = torch.max(model(images), dim=1)[1]
            correct += torch.eq(predict_y, labels).sum().item()
            total += labels.size(0)
    return correct / total if total else 0.0

In [None]:
import torchvision
import torch.nn as nn
from torchvision.models.resnet import resnet18, ResNet18_Weights
# ImageNet-pretrained ResNet-18 adapted to 32x32 CIFAR-10 inputs. Three parts
# are replaced:
#   - the 7x7/stride-2 stem becomes a freshly initialized 3x3/stride-1 conv;
#   - the initial max-pool is dropped;
#   - a new 10-way classification head is added.
model = resnet18(weights=ResNet18_Weights.DEFAULT)
model.conv1 = nn.Conv2d(
    in_channels=3,
    out_channels=64,
    kernel_size=3,
    stride=1,
    padding=1,
    bias=False,
)
model.maxpool = nn.Identity()
num_features = model.fc.in_features  # 512 for ResNet-18
model.fc = nn.Linear(num_features, 10)
model.to(device)
print('model prepared.')

In [None]:
# Train for `epochs` epochs. After each epoch, clean validation accuracy and
# the backdoor attack success rate are printed.
train(model)

In [None]:
import os

# torch.save does not create missing parent directories, so create them first.
# The filename reflects the configured number of epochs.
# NOTE(review): this pickles the whole module; saving model.state_dict() would
# be more portable if downstream loaders can be updated.
save_path = f'../models/badnets/resnet18_{epochs}epochs.pth'
os.makedirs(os.path.dirname(save_path), exist_ok=True)
torch.save(model, save_path)