## Import libs

In [1]:
import torch
import torch.nn as nn
import torch.optim as optim
import torch.nn.functional as F
from torch.optim import Optimizer
import torch.backends.cudnn as cudnn
import torchvision
from torch.utils.data import TensorDataset, DataLoader
import os
import random
import matplotlib.pyplot as plt
import numpy as np
import hypergrad as hg
from itertools import repeat
from poi_util import poison_dataset,patching_test
import poi_util
import dataset_utils
import matplotlib.pyplot as plt
import utils
%matplotlib inline

In [2]:
# Pick the GPU when one is visible, otherwise fall back to the CPU
# (same rule the data-loading cell applies).
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def get_results(model, criterion, data_loader, device):
    """Return the top-1 accuracy of `model` over `data_loader`.

    Args:
        model: a torch module; it is switched to eval mode as a side effect.
        criterion: unused. Kept only so existing call sites
            (`get_results(model, criterion, loader, device)`) keep working;
            the loss it used to produce was accumulated but never returned.
        data_loader: iterable yielding `(inputs, targets)` tensor pairs.
        device: device the batches are moved to before the forward pass.

    Returns:
        float: `correct / total` over all samples seen, or 0.0 when the
        loader yields no samples (instead of raising ZeroDivisionError).
    """
    model.eval()
    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, targets in data_loader:
            inputs, targets = inputs.to(device), targets.to(device)
            outputs = model(inputs)
            # Index of the largest logit along the class dimension.
            _, predicted = outputs.max(1)
            total += targets.size(0)
            correct += predicted.eq(targets).sum().item()
    if total == 0:
        return 0.0
    return correct / total

## Load dataset & models

In [3]:
# Use the GPU when available, otherwise run on the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

print('==> Preparing data..')
from torchvision.datasets import CIFAR10
root = './datasets'
dataset = dataset_utils.load_dataset('cifar10')(batch_size=200, doNormalization=True, inj_rate=0.9)
#testset = CIFAR10(root, train=False, transform=None, download=True)

# Raw test images are rescaled to [0, 1]; labels become a flat numpy array.
x_test, y_test = dataset.test.data, dataset.test.targets
x_test = x_test.astype('float32')/255
y_test = np.asarray(y_test)

#attack_name = 'badnets'
#target_lab = '8'
#x_poi_test,y_poi_test= patching_test(x_test, y_test, attack_name, target_lab=target_lab)

# `np.int` was deprecated in NumPy 1.20 and removed in 1.24; use an explicit
# fixed-width integer type instead. torch.Tensor(...) still yields a float
# tensor, which the loss code converts with `.long()` where needed.
y_test = torch.Tensor(y_test.reshape((-1,)).astype(np.int64))
#y_poi_test = torch.Tensor(y_poi_test.reshape((-1,)).astype(np.int64))

# Reorder the image axes (0, 3, 1, 2) so the last axis moves to position 1.
x_test = torch.Tensor(np.transpose(x_test,(0,3,1,2)))
#x_poi_test = torch.Tensor(np.transpose(x_poi_test,(0,3,1,2)))

# Per-channel standardisation (subtract mean, divide by std), done in place.
x_test[:,0] = (x_test[:,0]-0.485)/0.229
x_test[:,1] = (x_test[:,1]-0.456)/0.224
x_test[:,2] = (x_test[:,2]-0.406)/0.225

# First 5000 test samples drive the unlearning; the remainder is held out
# for measuring clean accuracy.
test_set = TensorDataset(x_test[5000:],y_test[5000:])
unl_set = TensorDataset(x_test[:5000],y_test[:5000])
#att_val_set = TensorDataset(x_poi_test[:5000],y_poi_test[:5000])

#data loader for verifying the clean test accuracy
clnloader = torch.utils.data.DataLoader(
    test_set, batch_size=200, shuffle=False, num_workers=2)

#clean data loader over the unlearning subset (same samples as `unlloader`)
poiloader_cln = torch.utils.data.DataLoader(
    unl_set, batch_size=200, shuffle=False, num_workers=2)

#data loader for verifying the attack success rate
#poiloader = torch.utils.data.DataLoader(
#    att_val_set, batch_size=200, shuffle=False, num_workers=2)
poiloader = dataset.test_backdoor_loader

#data loader for the unlearning step
unlloader = torch.utils.data.DataLoader(
    unl_set, batch_size=200, shuffle=False, num_workers=2)

==> Preparing data..
CIFAR10::init - doNormalization is True
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


Deprecated in NumPy 1.20; for more details and guidance: https://numpy.org/devdocs/release/1.20.0-notes.html#deprecations
  y_test = torch.Tensor(y_test.reshape((-1,)).astype(np.int))


In [4]:
# Architecture settings passed to the model factory.
num_classes = 10
input_size = 32

# Build the 'vgg16' SDN model through the project helpers, then restore the
# saved weights, mapping them onto the active device.
sdn_model = utils.get_sdn_model(
    'vgg16',
    utils.get_add_output('vgg16'),
    num_classes,
    input_size,
)
sdn_model.load_state_dict(
    torch.load(
        '../BackdoorSDN/models/vgg16_cifar10/copy1/copy5/retrain_sdn_5.pt',
        map_location=device,
    )
)
sdn_model.to(device=device)
sdn_model.eval()

# Derive the model used for the rest of the notebook from the SDN model
# (presumably a plain CNN, going by the helper's name -- confirm in utils).
model = utils.sdn_to_cnn(sdn_model, device)
model.eval()

## Backdoor unlearning

In [8]:
# Loss used by the evaluation helper, and the Adam optimizer that updates the
# model weights during the unlearning (outer) step.
criterion = nn.CrossEntropyLoss()
outer_opt = optim.Adam(model.parameters(), lr=1e-4)

In [9]:
# Baseline metrics before any unlearning: clean accuracy first, then the
# score on the backdoor loader.
ACC, ASR = (
    get_results(model, criterion, loader, device)
    for loader in (clnloader, poiloader)
)

print('Original ACC:', ACC)
print('Original ASR:', ASR)

Original ACC: 0.89
Original ASR: 0.00044444444444444447


In [10]:
#define the inner loss L2
def loss_inner(perturb, model_params):
    """Inner objective (L2): negated classification loss plus a norm penalty.

    Evaluates the global `model` on the first cached batch
    (`images_list[0]`, `labels_list[0]`) with `perturb[0]` added to every
    image, and returns `mean(-cross_entropy) + 0.001 * ||perturb[0]||^2`.
    Minimising it therefore pushes the perturbation towards maximising the
    classification loss while keeping its norm small.

    Args:
        perturb: indexable whose first element is the perturbation tensor;
            it must broadcast against the image batch.
        model_params: unused here; accepted because the function is handed
            to `hg.GradientDescent` (presumably required by its calling
            convention -- confirm against hypergrad).

    Returns:
        Scalar tensor holding the regularised loss.
    """
    # Follow the notebook-wide `device` instead of forcing CUDA, so the cell
    # also runs on CPU-only machines.
    images = images_list[0].to(device)
    labels = labels_list[0].long().to(device)
#     per_img = torch.clamp(images+perturb[0],min=0,max=1)
    per_img = images + perturb[0]
    # Call the module itself (not `.forward`) so registered hooks still run.
    per_logits = model(per_img)
    loss = F.cross_entropy(per_logits, labels, reduction='none')
    loss_regu = torch.mean(-loss) + 0.001 * torch.pow(torch.norm(perturb[0]), 2)
    return loss_regu

In [11]:
#define the outer loss L1
def loss_outer(perturb, model_params):
    """Outer objective (L1): cross-entropy on a partially perturbed batch.

    Takes the cached batch selected by the global `batchnum`, adds
    `perturb[0]` to a randomly chosen 1% of its images (drawn with the
    `random` module, so the result depends on its global state) and returns
    the mean cross-entropy of the global `model` against the batch labels.

    Args:
        perturb: indexable whose first element is the perturbation tensor;
            it must broadcast against a single image.
        model_params: unused here; accepted because the function is handed
            to `hg.fixed_point` (presumably required by its calling
            convention -- confirm against hypergrad).

    Returns:
        Scalar tensor holding the mean cross-entropy loss.
    """
    portion = 0.01
    # Follow the notebook-wide `device` instead of forcing CUDA.
    images = images_list[batchnum].to(device)
    labels = labels_list[batchnum].long().to(device)
    # Zero everywhere except the sampled rows, which receive the perturbation.
    patching = torch.zeros_like(images)
    number = images.shape[0]
    # Sampling from a range draws the same indices as sampling from a
    # materialised list of equal length, without the extra copy.
    rand_idx = random.sample(range(number), int(number * portion))
    patching[rand_idx] = perturb[0]
#     unlearn_imgs = torch.clamp(images+patching,min=0,max=1)
    unlearn_imgs = images + patching
    logits = model(unlearn_imgs)
    # Functional form of nn.CrossEntropyLoss() with default settings.
    return F.cross_entropy(logits, labels)

In [12]:
# Cache every unlearning batch up front so the loss functions can index
# them by batch number.
images_list = []
labels_list = []
for index, (images, labels) in enumerate(unlloader):
    images_list += [images]
    labels_list += [labels]

# Inner-level optimizer from hypergrad, built on the inner loss with 0.1 as
# its second argument (presumably the step size -- confirm against hypergrad).
inner_opt = hg.GradientDescent(loss_inner, 0.1)

In [13]:
#inner loop and optimization by batch computing
import tqdm
print("Conducting Defence")

# Entry 0 of each history holds the metric of the model before any
# unlearning round has run.
ASR_list, ACC_list = [], []
ASR_list.append(get_results(model, criterion, poiloader, device))
ACC_list.append(get_results(model, criterion, clnloader, device))

Conducting Defence


In [14]:
# Defence settings (values unchanged from the original run).
result_path = '../BackdoorSDN/models/vgg16_cifar10/BAU'
N_ROUNDS = 100            # number of unlearning rounds
PERT_LR = 10              # SGD step size used to craft the perturbation
PERT_REG = 0.001          # weight of the squared-norm penalty on the perturbation
L2_RADIUS = 10            # radius of the L2 ball the perturbation is scaled into
FIXED_POINT_ITERS = 5     # iteration count handed to hg.fixed_point
SAVE_EVERY = 10           # checkpoint period, in rounds

# Create the checkpoint directory up front so torch.save cannot fail on a
# missing folder after ten rounds of computation.
os.makedirs(result_path, exist_ok=True)

# `round_idx` rather than `round`: a top-level loop variable named `round`
# would shadow the builtin for the rest of the notebook session.
for round_idx in range(N_ROUNDS):
    # One perturbation shared by all images, shaped like a single test image.
    batch_pert = torch.zeros_like(x_test[:1], requires_grad=True, device=device)
    batch_opt = torch.optim.SGD(params=[batch_pert], lr=PERT_LR)

    for images, labels in unlloader:
        images = images.to(device)
        # The model's own predictions serve as reference labels; they are not
        # differentiated through, so skip autograd tracking for this pass.
        with torch.no_grad():
            ori_lab = torch.argmax(model(images), dim=1)
#         per_logits = model(torch.clamp(images+batch_pert,min=0,max=1))
        per_logits = model(images + batch_pert)
        loss = F.cross_entropy(per_logits, ori_lab, reduction='mean')
        # Maximise the loss w.r.t. the perturbation while penalising its norm.
        loss_regu = -loss + PERT_REG * torch.pow(torch.norm(batch_pert), 2)
        batch_opt.zero_grad()
        # The graph is rebuilt for every batch, so it need not be retained.
        loss_regu.backward()
        batch_opt.step()

    #l2-ball
    pert = batch_pert * min(1, L2_RADIUS / torch.norm(batch_pert))

    #unlearn step
    # `batchnum` must stay a top-level (global) name: loss_outer reads it to
    # select the cached batch.
    for batchnum in range(len(images_list)):
        outer_opt.zero_grad()
        # Presumably fills the .grad of the model parameters with the
        # hypergradient so outer_opt.step() can apply it -- confirm in hypergrad.
        hg.fixed_point(pert, list(model.parameters()), FIXED_POINT_ITERS, inner_opt, loss_outer)
        outer_opt.step()

    # Evaluate once per round and reuse the values for both history and log.
    ASR_list.append(get_results(model, criterion, poiloader, device))
    ACC_list.append(get_results(model, criterion, clnloader, device))
    print('Round:', round_idx)

    print('ACC:', ACC_list[-1])
    print('ASR:', ASR_list[-1])

    if (round_idx + 1) % SAVE_EVERY == 0:
        save_f = (round_idx + 1) // SAVE_EVERY
        save_path = os.path.join(result_path, "BAU_cnn_{}.pt".format(save_f))
        torch.save(model.state_dict(), save_path)

Round: 0
ACC: 0.8824
ASR: 0.37555555555555553
Round: 1
ACC: 0.8798
ASR: 0.206
Round: 2
ACC: 0.8878
ASR: 0.40155555555555555
Round: 3
ACC: 0.8738
ASR: 0.3556666666666667
Round: 4
ACC: 0.8776
ASR: 0.7017777777777777
Round: 5
ACC: 0.866
ASR: 0.2613333333333333
Round: 6
ACC: 0.8868
ASR: 0.6607777777777778
Round: 7
ACC: 0.8818
ASR: 0.708
Round: 8
ACC: 0.8686
ASR: 0.5197777777777778
Round: 9
ACC: 0.8686
ASR: 0.753
Round: 10
ACC: 0.8672
ASR: 0.2772222222222222
Round: 11
ACC: 0.885
ASR: 0.5254444444444445
Round: 12
ACC: 0.8822
ASR: 0.4557777777777778
Round: 13
ACC: 0.8828
ASR: 0.6277777777777778
Round: 14
ACC: 0.8736
ASR: 0.5633333333333334
Round: 15
ACC: 0.8658
ASR: 0.5988888888888889
Round: 16
ACC: 0.865
ASR: 0.4666666666666667
Round: 17
ACC: 0.873
ASR: 0.6015555555555555
Round: 18
ACC: 0.8696
ASR: 0.47333333333333333
Round: 19
ACC: 0.8706
ASR: 0.6451111111111111
Round: 20
ACC: 0.866
ASR: 0.6044444444444445
Round: 21
ACC: 0.8676
ASR: 0.4241111111111111
Round: 22
ACC: 0.8694
ASR: 0.4912222222

KeyboardInterrupt: 