In [None]:
import sys
sys.path.append('./Sparse_PGD/sparse_autoattack')

import torchvision
from torchvision import transforms
from PIL import Image
import json
import os
import torch
from torch.utils.data import Dataset, DataLoader, Subset
import numpy as np

# Number of samples per mini-batch for both loaders.
BATCH_SIZE = 64

# Single preprocessing pipeline (resize to 224x224, then convert to a tensor);
# despite its name it is applied to the train split as well as the val split.
transform_test = transforms.Compose([transforms.Resize((224, 224)), transforms.ToTensor()])

# Placeholder: point this at the local Imagenette root (download is disabled below).
DATA_ROOT = 'path of imagenette dataset'

# Train split: shuffled, so batch composition differs between runs.
train_set = torchvision.datasets.Imagenette(
    root=DATA_ROOT,
    split='train',
    size='full',
    download=False,
    transform=transform_test,
)
trainloader = DataLoader(train_set, batch_size=BATCH_SIZE, shuffle=True, num_workers=8)

# Validation split: kept in dataset order.
test_set = torchvision.datasets.Imagenette(
    root=DATA_ROOT,
    split='val',
    size='full',
    download=False,
    transform=transform_test,
)
testloader = DataLoader(test_set, batch_size=BATCH_SIZE, shuffle=False, num_workers=8)

def load_entire_model(path, device):
    """Deserialize the object stored at ``path`` via ``torch.load``.

    Args:
        path: Location of the serialized file, passed straight to ``torch.load``.
        device: Forwarded as ``map_location`` so stored tensors are remapped
            onto this device while loading.

    Returns:
        Whatever object ``torch.load`` reconstructs from the file.

    Note:
        ``torch.load`` unpickles the file, so only load checkpoints from a
        trusted source.
    """
    return torch.load(path, map_location=device)

# Prefer the first CUDA device; fall back to the CPU so the notebook still
# runs on machines without a usable GPU instead of failing at the first `.to(device)`.
device = 'cuda:0' if torch.cuda.is_available() else 'cpu'

In [None]:
# Placeholder: path of the serialized model to evaluate.
model_load_path = 'path of the model under evaluation'
# Load through the helper so stored tensors are remapped onto `device` while
# deserializing; a bare torch.load() would try to restore them onto the device
# the checkpoint was saved from, which may not exist on this machine.
model = load_entire_model(model_load_path, device).to(device)
model.eval()

print('model prepared!')

# Accuracy on the training split: count correct top-1 predictions batch by batch.
train_num = len(trainloader.dataset)
acc = 0.0
# Pure inference: disable autograd so no computation graph is built or kept
# alive for every forward pass.
with torch.no_grad():
    for train_data in trainloader:
        train_images, train_labels = train_data
        outputs = model(train_images.to(device))
        # Index of the largest output along dim 1 is the predicted class.
        predict_y = torch.max(outputs, dim=1)[1]
        acc += torch.eq(predict_y, train_labels.to(device)).sum().item()

train_accurate = acc / train_num
print('train acc:', train_accurate)

from pixel_backdoor import PixelBackdoor, adjust_pattern
def get_data(loader, source, size=100):
    """Collect up to ``size`` samples labelled ``source`` from ``loader``.

    Args:
        loader: Iterable yielding ``(x_batch, y_batch)`` pairs whose elements
            expose ``.numpy()`` (e.g. CPU tensors from a ``DataLoader``).
        source: Label value to keep.
        size: Maximum number of samples to return; iteration stops as soon as
            at least this many matching samples have been gathered.

    Returns:
        Tuple ``(x_data, y_data)`` of NumPy arrays holding the selected inputs
        and their labels, truncated to ``size`` entries. If the loader yields
        batches but none contains ``source``, both arrays have length 0.

    Raises:
        ValueError: If ``loader`` yields no batches at all.
    """
    x_chunks = []
    y_chunks = []
    collected = 0

    for x_batch, y_batch in loader:
        y_batch_np = y_batch.numpy()
        indices = np.where(y_batch_np == source)[0]

        # Keep per-batch selections and join them once after the loop instead
        # of re-concatenating the growing array on every iteration.
        x_chunks.append(x_batch.numpy()[indices])
        y_chunks.append(y_batch_np[indices])
        collected += len(indices)

        if collected >= size:
            break

    if not x_chunks:
        # Previously this case fell through to `.shape` on a plain list and
        # failed with an unrelated AttributeError.
        raise ValueError('get_data: loader yielded no batches')

    x_data = np.concatenate(x_chunks, axis=0)[:size]
    y_data = np.concatenate(y_chunks, axis=0)[:size]
    print('data:', x_data.shape, y_data.shape)

    return x_data, y_data

from copy import deepcopy

# Source/target class pair of the backdoor under analysis. (0, 1) is only an
# assumption -- change it based on the actual situation. Defined once here so
# the sampled data and the `pair` passed to the trigger search cannot disagree.
SOURCE_CLASS = 0
TARGET_CLASS = 1

# Gather samples of the source class from the training loader and convert them
# to tensors for the trigger search.
x_val, y_val = get_data(trainloader, source=SOURCE_CLASS)
x_val = torch.FloatTensor(x_val)
y_val = torch.LongTensor(y_val)
backdoor = PixelBackdoor(model,
                            num_classes=10,
                            batch_size=25,
                            init_cost=1e-2,
                            steps=1000,
                            lr=1e-2,
                            cost_multiplier_up=1.2,
                            cost_multiplier_down=1.5,
                            device=device)

pattern_unrestricted = backdoor.generate(pair=(SOURCE_CLASS, TARGET_CLASS), x_set=x_val, y_set=y_val, attack_size=50, max_perturb_pixels=10000)
# Work on a copy so the unrestricted pattern stays available for comparison.
pattern = deepcopy(pattern_unrestricted)
# NOTE(review): presumably restricts the trigger to a budget of 200 pixels --
# confirm against adjust_pattern in pixel_backdoor.
pattern = adjust_pattern(pattern, 200)

In [None]:
# For each trigger (adjusted first, unrestricted second): sum absolute values
# over dim 0 and count the positions where that sum is non-zero.
tuple(
    np.count_nonzero(trigger.abs().sum(0).cpu().numpy())
    for trigger in (pattern, pattern_unrestricted)
)

In [None]:
import matplotlib.pyplot as plt

# Reorder dims to (1, 2, 0) for imshow (assumes the pattern is channel-first --
# TODO confirm), then apply the `* 3 + 1` scale/shift used for display.
vis = (pattern_unrestricted.permute(1, 2, 0).cpu() * 3 + 1).numpy()

fig, ax = plt.subplots()
ax.imshow(vis)
ax.set_title("Trigger Pattern")
plt.show()

In [None]:
import torch
from torch.utils.data import Dataset, DataLoader

class SpecificClassDataset(Dataset):
    """View of a map-style dataset restricted to samples of one class.

    Args:
        original_dataset: Dataset supporting ``len()`` and integer indexing,
            where each item is an ``(input, target)`` pair.
        class_idx: Target value to keep.

    Attributes are set in ``__init__``: ``original_dataset``, ``class_idx`` and
    ``indices`` (positions in ``original_dataset`` whose target equals
    ``class_idx``, in ascending order).
    """

    def __init__(self, original_dataset, class_idx):
        self.original_dataset = original_dataset
        self.class_idx = class_idx

        self.indices = self._matching_indices()

    def _matching_indices(self):
        """Return the positions in the wrapped dataset labelled ``class_idx``."""
        dataset = self.original_dataset

        # Fast path: when the dataset exposes one label per sample in a
        # `targets` attribute and applies no target_transform, read the labels
        # directly instead of loading (and transforming) every sample.
        # NOTE(review): assumes `targets` holds the same labels __getitem__
        # returns -- confirm for datasets with custom joint transforms.
        targets = getattr(dataset, 'targets', None)
        if (
            targets is not None
            and hasattr(targets, '__len__')
            and getattr(dataset, 'target_transform', None) is None
            and len(targets) == len(dataset)
        ):
            return [i for i, target in enumerate(targets) if target == self.class_idx]

        # General path: index explicitly over range(len(...)) rather than
        # iterating the dataset, which only terminates if __getitem__ raises
        # IndexError past the end.
        matching = []
        for i in range(len(dataset)):
            _, target = dataset[i]
            if target == self.class_idx:
                matching.append(i)
        return matching

    def __getitem__(self, index):
        """Return the ``index``-th kept sample from the wrapped dataset."""
        original_index = self.indices[index]
        return self.original_dataset[original_index]

    def __len__(self):
        """Return the number of kept samples."""
        return len(self.indices)

# the source class
class_idx = 0

# Keep only the validation samples whose label equals the source class.
specific_class_dataset = SpecificClassDataset(original_dataset=test_set, class_idx=class_idx)

# NOTE: this rebinds `testloader`; from here on it iterates over the
# source-class subset only, in dataset order.
testloader = DataLoader(
    specific_class_dataset,
    batch_size=BATCH_SIZE,
    shuffle=False,
    num_workers=1,
)

In [None]:
# Class the trigger is expected to push predictions towards.
target_class = 1

acc = 0            # correct predictions on clean inputs
acc_targeted = 0   # triggered inputs predicted as `target_class`
total = 0          # number of evaluated samples

# Pure inference: disable autograd so no graph is built for either forward
# pass. Both metrics are gathered in one traversal so every batch is loaded
# and preprocessed only once.
with torch.no_grad():
    for x_batch, y_batch in testloader:
        x_batch = x_batch.to(device)
        y_batch = y_batch.to(device)

        # Clean accuracy.
        pred = model(x_batch).argmax(dim=1)
        acc += (pred == y_batch).sum().item()

        # Add the trigger and clamp back into the [0, 1] input range.
        x_batch_adv = torch.clamp(x_batch + pattern, min=0.0, max=1.0)
        pred_adv = model(x_batch_adv).argmax(dim=1)
        acc_targeted += (pred_adv == target_class).sum().item()

        total += y_batch.size(0)

accuracy = acc / total
print(f'Clean accuracy: {accuracy}')

# Fraction of all evaluated samples classified as `target_class` once the
# trigger is applied.
targeted_attack_success_rate = acc_targeted / total
print(f'Targeted Attack Success Rate: {targeted_attack_success_rate}')