In [1]:
import os

import numpy as np
import torch
import torch.nn.functional as F
from skimage.color import rgba2rgb
from skimage.io import imread
from skimage.transform import resize

In [2]:
# Backdoor-attack configuration: every knob the rest of the notebook reads
root           = '/data/backdoor'   # holds the clean splits and receives the poisoned ones
dataset        = 'cifar10'
seed           = 0
injection_rate = 0.1                # poisoning budget, as a fraction of the target-class size
attack_type    = 'badnets'          # one of: 'badnets', 'sig', 'htba', 'cl', 'wanet'
target_class   = 0
valid_frac     = 0.04               # fraction of the training set held out for validation
lid_batch_size = 100                # per-class sample counts are rounded down to a multiple of this

if attack_type == 'cl':
    # the clean-label ('cl') branch is only supported for CIFAR-10 in this notebook
    assert dataset == 'cifar10'
elif attack_type == 'htba':
    # bounds of the HTBA trigger placement, as fractions of the image height/width
    train_location_min, train_location_max = 0.25, 0.75
    val_location_min, val_location_max     = 0.1, 0.9

In [13]:
# Load the clean train/test splits from `root`.
# Each file presumably holds a dict with 'data' and 'targets' tensors (the next step indexes those keys).
# TODO: load the datasets from online sources using `torchvision`
train_path = os.path.join(root, f'{dataset}_train.pth')
test_path  = os.path.join(root, f'{dataset}_val.pth')
train_data = torch.load(train_path)
test_data  = torch.load(test_path)

In [14]:
# Convert the loaded tensors to numpy arrays; all poisoning below is written in numpy.
# TODO: do we really need this step? perhaps we can work only on the torch tensors
def _split_to_numpy(split):
    """Return `(images, labels)` numpy arrays from a loaded split dict with 'data' and 'targets' tensors."""
    return split['data'].numpy(), split['targets'].numpy()

train_data, train_label = _split_to_numpy(train_data)
test_data, test_label   = _split_to_numpy(test_data)

In [15]:
# Seed both RNGs used below: numpy (sample selection, HTBA locations) and torch (WaNet warp, validation split)
np.random.seed(seed)
torch.manual_seed(seed)

# number of classes, image geometry and the poisoning budget (a fraction of the target-class size)
n_class = train_label.max().item() + 1
# images are channel-first (N, C, H, W): the triggers below are built as (c, h, w) arrays and the 'cl'
# branch converts its arrays with transpose(0, 3, 1, 2), so the spatial dims unpack as (h, w)
c, h, w = train_data.shape[1:]
n_pois  = int(np.sum(train_label == target_class) * injection_rate)

In [16]:
# Every poisoned test sample is relabelled to `target_class`, so test samples that already belong to it are
# dropped first. The 'cl' branch does this itself after loading its own pre-poisoned test images.
if attack_type != 'cl':
    non_target            = test_label != target_class
    test_data, test_label = test_data[non_target], test_label[non_target]

In [17]:
# Poison a subset of the training set and the whole (non-target-class) test set according to `attack_type`.
# Every branch leaves:
#   perm                   -- indices into train_data of the poisoned training samples
#   train_data/train_label -- with the poisoned samples modified in place
#   test_data/test_label   -- every test sample triggered and relabelled to `target_class`
# Budgets are over-provisioned (`1.1 * n_pois`), presumably so enough poisoned samples survive the
# validation split below -- TODO confirm.


def _htba_offset_range(size, patch, frac_min, frac_max):
    """Return `(low, high)` bounds for `np.random.randint` on the top-left offset of an HTBA trigger.

    The trigger is meant to start within `[frac_min, frac_max]` of an image axis of length `size`.
    When that window leaves no room for a random offset, the trigger is pinned at `low`.

    Raises:
        ValueError: if a `patch`-long trigger does not fit inside the image at offset `low`.
    """
    low  = int(size * frac_min)
    high = int(size * frac_max - patch)
    # np.random.randint(low, high) requires high > low; high = low + 1 always draws `low`
    high = max(high, low + 1)
    if low + patch > size:
        raise ValueError(f'HTBA trigger of length {patch} does not fit an image axis of length {size} at offset {low}')
    return low, high


def _paste_trigger(img, trigger_patch, top, left):
    """Overwrite `img[:, top:top+ph, left:left+pw]` with `trigger_patch` of shape (C, ph, pw), in place."""
    ph, pw = trigger_patch.shape[1:]
    img[:, top: top + ph, left: left + pw] = trigger_patch


if attack_type == 'badnets':
    # (4*ck_size)^2 checkerboard stamped in the bottom-right corner; poisoned training samples are relabelled
    ck_size = 1
    checker = (np.indices((4, 4)).sum(axis=0) % 2 == 0).astype(np.float32)
    trigger = np.zeros((c, h, w), dtype=np.float32)
    mask    = np.ones((c, h, w), dtype=np.float32)
    mask[:, h - 4 * ck_size:, w - 4 * ck_size:]    = 0
    trigger[:, h - 4 * ck_size:, w - 4 * ck_size:] = np.kron(checker, np.ones((ck_size, ck_size)))

    perm              = np.random.permutation(train_data.shape[0])[0: int(1.1 * n_pois)]
    train_data[perm]  = train_data[perm] * mask + trigger
    train_label[perm] = target_class

    # target-class test samples were already dropped in the previous cell
    test_data     = test_data * mask + trigger
    test_label[:] = target_class

elif attack_type == 'sig':
    # add a horizontal sinusoid (amplitude `delta`, `f` periods across the width; scaled by 1/255 and clipped
    # to [0, 1]) to target-class training images -- their labels are left unchanged
    delta   = 20
    f       = 6
    cols    = np.arange(w)
    pattern = np.tile(delta * np.sin(2 * np.pi * cols * f / w), (h, 1)).astype(np.float32)

    perm             = np.random.permutation(np.nonzero(train_label == target_class)[0])[0: int(1.1 * n_pois)]
    train_data[perm] = np.float32(np.clip(train_data[perm] + pattern / 255., 0., 1.))
    test_data        = np.float32(np.clip(test_data + pattern / 255., 0., 1.))
    test_label[:]    = target_class

elif attack_type == 'htba':
    # paste the trigger image at a random location of poisoned target-class training images (labels unchanged)
    # and of every test image
    w_height = 32
    w_width  = 32
    # NOTE(review): a 32x32 trigger cannot start at int(0.25 * 32) = 8 on 32x32 images, so
    # _htba_offset_range raises there -- confirm the intended trigger size per dataset
    raw_tri = imread('./triggers/HTBA_trigger_10.png')
    if raw_tri.ndim == 3 and raw_tri.shape[-1] == 4:
        raw_tri = rgba2rgb(raw_tri)  # drop the alpha channel so the trigger has 3 channels
    raw_tri = resize(raw_tri, (w_height, w_width)).astype(np.float32).transpose(2, 0, 1)

    loc_h = _htba_offset_range(h, w_height, train_location_min, train_location_max)
    loc_w = _htba_offset_range(w, w_width, train_location_min, train_location_max)

    # NOTE(review): the other attacks over-provision with 1.1 * n_pois, HTBA uses 0.95 -- confirm intended
    perm = np.random.permutation(np.nonzero(train_label == target_class)[0])[0: int(0.95 * n_pois)]
    for idx in perm:
        top, left = np.random.randint(*loc_h), np.random.randint(*loc_w)
        _paste_trigger(train_data[idx], raw_tri, top, left)  # train_data[idx] is a view: edits in place

    loc_h = _htba_offset_range(h, w_height, val_location_min, val_location_max)
    loc_w = _htba_offset_range(w, w_width, val_location_min, val_location_max)
    for img in test_data:  # iterating yields row views, so pasting edits test_data in place
        top, left = np.random.randint(*loc_h), np.random.randint(*loc_w)
        _paste_trigger(img, raw_tri, top, left)

    test_label[:] = target_class

elif attack_type == 'cl':
    # clean-label attack with pre-computed poisoned images loaded from disk (presumably NHWC in 0-255, given the
    # transpose and /255 below); poisoned samples keep their target-class label.
    # NOTE(review): assumes train_images.npy / test_images.npy are index-aligned with the clean splits
    poisoned_root       = './data/already_poisoned_dataset/'
    poisoned_train_data = np.load(os.path.join(poisoned_root, 'train_images.npy'))
    perm                = np.random.permutation(np.nonzero(train_label == target_class)[0])[0: int(1.1 * n_pois)]
    train_data[perm]    = np.float32(poisoned_train_data[perm] / 255.).transpose(0, 3, 1, 2)

    # the previous cell skipped 'cl', so target-class test samples are dropped here, after loading
    non_target    = test_label != target_class
    test_data     = np.load(os.path.join(poisoned_root, 'test_images.npy'))
    test_data     = np.float32(test_data / 255.).transpose(0, 3, 1, 2)[non_target]
    test_label    = test_label[non_target]
    test_label[:] = target_class

elif attack_type == 'wanet':
    # warping trigger: poisoned images are resampled through one fixed, smooth random warp field
    k = 4    # side of the random control grid that is upsampled into the warp field
    s = 0.5  # scale of the warp offsets
    g = 1    # scale applied to the final sampling grid

    ins            = torch.rand(1, 2, k, k) * 2 - 1
    ins            = ins / torch.mean(torch.abs(ins))
    noise_grid     = F.interpolate(ins, size=(h, w), mode="bicubic", align_corners=True).permute(0, 2, 3, 1)
    grid_y, grid_x = torch.meshgrid(torch.linspace(-1, 1, steps=h), torch.linspace(-1, 1, steps=w))
    identity_grid  = torch.stack((grid_x, grid_y), 2)[None, ...]  # (1, h, w, 2), last dim is (x, y) for grid_sample
    grid_temps     = (identity_grid + s * noise_grid / h) * g
    grid_temps     = torch.clamp(grid_temps, -1, 1)

    perm              = np.random.permutation(train_data.shape[0])[0: int(1.1 * n_pois)]
    train_data[perm]  = F.grid_sample(torch.tensor(train_data[perm]), grid_temps.repeat(perm.shape[0], 1, 1, 1), align_corners=True).numpy()
    train_label[perm] = target_class

    # target-class test samples were already dropped in the previous cell
    test_data     = F.grid_sample(torch.tensor(test_data), grid_temps.repeat(test_data.shape[0], 1, 1, 1), align_corners=True).numpy()
    test_label[:] = target_class

else:
    raise ValueError(f'unknown attack_type: {attack_type!r}')

In [18]:
# Hold out a random `valid_frac` of the (already poisoned) training set for validation.
# NOTE(review): validation samples are drawn after poisoning, so they may include poisoned samples -- confirm intended.
num_train  = train_data.shape[0]
indices    = torch.randperm(num_train).tolist()
valid_size = int(np.floor(valid_frac * num_train))

train_idx, valid_idx    = indices[valid_size:], indices[:valid_size]
val_data, val_label     = train_data[valid_idx], train_label[valid_idx]
train_data, train_label = train_data[train_idx], train_label[train_idx]

# Re-express `perm` as positions in the new train_data: keep the poisoned samples that stayed in the training
# split (sorted by original index) and map each to its new row via an O(n) lookup table (-1 = moved to validation).
new_position            = np.full(num_train, -1, dtype=np.int64)
new_position[train_idx] = np.arange(len(train_idx), dtype=np.int64)
perm                    = new_position[np.intersect1d(train_idx, perm)]

In [42]:
# Subsample every class to a multiple of `lid_batch_size` samples and store the classes contiguously
# (class k occupies rows offsets[k]:offsets[k + 1]). Within the target class, the poisoned share is
# capped at `injection_rate`, and `perm` is re-expressed as row indices into the new arrays.
#   balance_classes = True  -> every class keeps the same count (smallest class count, rounded down)
#   balance_classes = False -> each class keeps its own count, rounded down
balance_classes = True

class_counts = np.array([np.sum(train_label == k) for k in range(n_class)])
if balance_classes:
    sample_per_class = np.full(n_class, class_counts.min() // lid_batch_size * lid_batch_size)
else:
    sample_per_class = class_counts // lid_batch_size * lid_batch_size
offsets = np.concatenate(([0], np.cumsum(sample_per_class)))

new_data   = np.zeros((offsets[-1],) + train_data.shape[1:], dtype=train_data.dtype)
new_labels = np.zeros((offsets[-1],) + train_label.shape[1:], dtype=train_label.dtype)

# loop variable is `cls`, not `c`, so the channel count `c` defined earlier is not clobbered
for cls in range(n_class):
    sample_ids = np.where(train_label == cls)[0]
    n_keep     = sample_per_class[cls]

    if cls == target_class:
        cleans     = np.setdiff1d(sample_ids, perm, assume_unique=True)
        perm       = np.random.permutation(perm)[: int(n_keep * injection_rate)]
        sample_ids = np.random.permutation(cleans)[: n_keep - perm.shape[0]]
        sample_ids = np.random.permutation(np.hstack((sample_ids, perm)))
        # row of each poisoned sample inside this class block, in the order of `perm`
        position   = {idx: pos for pos, idx in enumerate(sample_ids.tolist())}
        new_perm   = offsets[cls] + np.array([position[idx] for idx in perm.tolist()], dtype=np.int64)

    else:
        sample_ids = np.random.permutation(sample_ids)[: n_keep]

    new_data[offsets[cls]: offsets[cls + 1]]   = train_data[sample_ids]
    new_labels[offsets[cls]: offsets[cls + 1]] = train_label[sample_ids]

perm        = new_perm
train_data  = new_data
train_label = new_labels

In [19]:
# Convert every split back to torch tensors so `main.py` can load them during model training
(train_data, train_label), (val_data, val_label), (test_data, test_label) = [
    (torch.from_numpy(images), torch.from_numpy(labels))
    for images, labels in ((train_data, train_label), (val_data, val_label), (test_data, test_label))
]

In [None]:
def _split_path(split):
    """Output path of one poisoned split; the file name encodes the attack configuration."""
    return os.path.join(root, f'./{dataset}_{attack_type}_{split}_{seed}_{lid_batch_size}_{injection_rate}.pth')

# the training split also records `pois_idx`, the row indices of its poisoned samples
torch.save({'data': train_data, 'targets': train_label, 'pois_idx': perm}, _split_path('train'))
torch.save({'data': val_data, 'targets': val_label}, _split_path('val'))
torch.save({'data': test_data, 'targets': test_label}, _split_path('test'))