<a href="https://colab.research.google.com/github/ideasplus/tdc-starter-kit/blob/main/trc_submission_code.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### **IEEE Trojan Removal Competition (IEEE TRC'22)**

![](http://www.trojan-removal.com/wp-content/uploads/2022/12/trojan_challenge-scaled.jpg)

# How to use this starter kit

1. **Copy the notebook**. This is a shared file, so your changes will not be saved. Please click "File" -> "Save a copy in Drive" to make your own copy, which you can then modify as you like.

2. **Implement your own method**. Please put all your code into the `clean_model` function in section 4. Anything you write outside of this function will not be submitted to our evaluation server.

# 1. Download and import package

In [None]:
#@title Load package and data
import numpy as np
from torch.utils.data import Dataset, Subset
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision.transforms import functional as F
import torchvision
import os
import random
import tqdm
from torchvision import transforms
import copy
import time
from tqdm.notebook import trange, tqdm
torch.cuda.empty_cache()
# Device string used when moving batches to the GPU (e.g. `.to(device)` in the defenses below).
device = 'cuda'

# Third-party packages used later: timm (ViT-Tiny model) and func_timeout (per-case time limits).
!pip install timm
!pip install func_timeout

In [None]:
#@title Download dataset and models
%%shell

filename='competition_data.zip'
fileid='1g-BO8zyHm9R64jXeAJob_RS5kopN8Mf6'
# Download the competition archive from Google Drive.
# The inner wget fetches Drive's download-warning page, saving cookies; sed extracts the
# "confirm" token from it; the outer wget sends that token (with the cookies) to get the file.
# NOTE(review): this depends on the format of Drive's warning page; if the saved zip turns
# out to be HTML, the token scraping likely needs updating.
wget --load-cookies /tmp/cookies.txt "https://drive.google.com/uc?export=download&confirm=$(wget --quiet --save-cookies /tmp/cookies.txt --keep-session-cookies --no-check-certificate 'https://drive.google.com/uc?export=download&id=${fileid}' -O- | sed -rn 's/.confirm=([0-9A-Za-z_]+)./\1\n/p')&id=${fileid}" -O ${filename} && rm -rf /tmp/cookies.txt

In [None]:
#@title Unzip the package
!unzip './competition_data.zip' -d '/content'
from util import *
import timm
from func_timeout import func_timeout,FunctionTimedOut

In [None]:
#@title Load all poisoned models and evaluation datasets
## BadNets all2all
def PubFig_all2all():
  def all2all_badnets(img):
    img[184:216,184:216,:] = 255
    return img

  def all2all_label(label):
    if label == 83:
      return int(0)
    else:
      return int(label + 1)

  test_transform = transforms.Compose([
                  transforms.ToTensor(),
                  transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),])

  poison_method = ((all2all_badnets, None), all2all_label)
  val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('./data/pubfig.npy', test_transform, poison_method, -1)


  net = timm.create_model("vit_tiny_patch16_224", pretrained=False, num_classes=83)
  net.load_state_dict(torch.load('./checkpoint/pubfig_vittiny_all2all.pth',map_location='cuda:0'))
  net = net.cuda()

  return val_dataset, test_dataset, asr_dataset, pacc_dataset, net

## SIG
def CIFAR10_SIG():
    best_noise = np.zeros((32, 32, 3))
    def plant_sin_trigger(img, delta=20, f=6, debug=False):
        """
        Implement paper:
        > Barni, M., Kallas, K., & Tondi, B. (2019).
        > A new Backdoor Attack in CNNs by training set corruption without label poisoning.
        > arXiv preprint arXiv:1902.11237
        superimposed sinusoidal backdoor signal with default parameters
        """
        alpha = 0.2
        pattern = np.zeros_like(img)
        m = pattern.shape[1]
        for i in range(img.shape[0]):
            for j in range(img.shape[1]):
                for k in range(img.shape[2]):
                    pattern[i, j] = delta * np.sin(2 * np.pi * j * f / m)

        return np.uint8((1 - alpha) * pattern)
    noisy = plant_sin_trigger(best_noise, delta=20, f=15, debug=False)

    def SIG(img):
        return img + noisy

    def SIG_tar(label):
        return 6

    test_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    poison_method = ((SIG, None), SIG_tar)
    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset('./data/cifar_10.npy', test_transform, poison_method, 6)

    net = ResNet18().cuda()
    net.load_state_dict(torch.load('./checkpoint/cifar10_resnet18_sig.pth',map_location='cuda:0'))
    net = net.cuda()

    return val_dataset, test_dataset, asr_dataset, pacc_dataset, net

## Narcissus
def TinyImangeNet_Narcissus():
    """Load the Tiny-ImageNet / ResNet-18 case poisoned with Narcissus (target label 2).

    Returns:
        ``(val_dataset, test_dataset, asr_dataset, pacc_dataset, net)`` with the
        poisoned model (200-way head) on the GPU.
    """
    trigger = np.load('./checkpoint/narcissus_trigger.npy')[0]

    def Narcissus(img):
        # Second poison slot: add the trigger amplified 3x and keep values in [-1, 1].
        return torch.clip(img + trigger * 3, -1, 1)

    def Narcissus_tar(label):
        return 2

    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((224, 224)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset(
        './data/tiny_imagenet.npy', preprocess, ((None, Narcissus), Narcissus_tar), 2)

    net = torchvision.models.resnet18()
    net.fc = nn.Linear(net.fc.in_features, 200)  # Tiny-ImageNet has 200 classes
    state_dict = torch.load('./checkpoint/tiny_imagenet_resnet18_narcissus.pth', map_location='cuda:0')
    net.load_state_dict(state_dict)

    return val_dataset, test_dataset, asr_dataset, pacc_dataset, net.cuda()

def GTSRB_WaNetFrequency():
    """Load the GTSRB / GoogLeNet case, which carries two backdoors.

    Trigger 1 is a WaNet warp (target label 2); trigger 2 is a universal
    additive perturbation loaded from ``gtsrb_universal.npy`` (target label 13).

    Returns:
        ``(val_dataset, test_dataset, (asr_wanet, asr_frequency),
        (pacc_wanet, pacc_frequency), net)`` with the poisoned model on the GPU.
    """
    # One evaluation transform shared by both triggers.
    eval_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize((32, 32)),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])

    ## WaNet: resample every image through a fixed, slightly perturbed grid.
    identity_grid = torch.load("./checkpoint/WaNet_identity_grid.pth")
    noise_grid = torch.load("./checkpoint/WaNet_noise_grid.pth")
    warp_strength = 0.5
    grid_size = identity_grid.shape[2]
    warp_grid = torch.clamp(identity_grid + warp_strength * noise_grid / grid_size, -1, 1)

    def Wanet(img):
        # HWC numpy image -> CHW float (integer inputs scaled to [0, 1]) -> warped -> HWC numpy.
        chw = torch.from_numpy(img).permute(2, 0, 1)
        chw = torchvision.transforms.functional.convert_image_dtype(chw, torch.float)
        warped = nn.functional.grid_sample(chw.unsqueeze(0), warp_grid, align_corners=True).squeeze()
        return warped.permute(1, 2, 0).numpy()

    def Wanet_tar(label):
        return 2

    val_dataset, test_dataset, asr_dataset, pacc_dataset = get_dataset(
        './data/gtsrb.npy', eval_transform, ((Wanet, None), Wanet_tar), 2)

    net = GoogLeNet()
    net.load_state_dict(torch.load('./checkpoint/gtsrb_googlenet_wantfrequency.pth', map_location='cuda:0'))
    net = net.cuda()

    ## Frequency: add a fixed perturbation to the already-transformed tensor, clipped to [-1, 1].
    frequency_noise = transforms.ToTensor()(np.load('./checkpoint/gtsrb_universal.npy')[0])

    def Frequency(img):
        return torch.clip(img + frequency_noise, -1, 1)

    def Frequency_tar(label):
        return 13

    _, _, asr_dataset2, pacc_dataset2 = get_dataset(
        './data/gtsrb.npy', eval_transform, ((None, Frequency), Frequency_tar), 13)

    return val_dataset, test_dataset, (asr_dataset, asr_dataset2), (pacc_dataset, pacc_dataset2), net

## Clean STL-10
def STL10_Clean():
    """Load the clean STL-10 / VGG16-bn case (no backdoor planted).

    Returns:
        ``(val_dataset, test_dataset, None, None, net)``; this case has no
        ASR / PACC datasets. ``net`` is moved to the GPU.
    """
    preprocess = transforms.Compose([
        transforms.ToTensor(),
        transforms.Resize(224),
        transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5)),
    ])
    val_dataset, test_dataset, _, _ = get_dataset('./data/stl10.npy', preprocess, (None, None), -1)

    net = torchvision.models.vgg16_bn()
    state_dict = torch.load('./checkpoint/stl_10_vgg.pth', map_location='cuda:0')
    net.load_state_dict(state_dict)

    return val_dataset, test_dataset, None, None, net.cuda()

# 2. Test attack effect



> Attack setting


|               |        Case 1        |       Case 2       |         Case 3        |       Case 4       |        Case 5        |
|:-------------:|:--------------------:|:------------------:|:---------------------:|:------------------:|:--------------------:|
|     Model     |       VIT-Tiny       |      ResNet-18     |       ResNet-18       |      GoogLenet     |       VGG16-bn       |
|    Dataset    |        PubFig        |      CIFAR-10      |     Tiny-ImageNet     |        GTSRB       |        STL-10        |
|  Dataset Info | 224\*224\*3 83 Classes | 32\*32\*3 10 Classes | 224\*224\*3 200 Classes | 32\*32\*3 43 Classes | 224\*224\*3 10 Classes |
| Poison Method |    BadNets All2All   |         SIG        |       Narcissus       |  WaNet & Frequency |          N/A         |
|  Target Label |          All         |          6         |           2           |       2 & 13       |          N/A         |
|  Defense Time |        1350 S        |        900 S       |         1800 S        |        690 S       |         450 S        |

In [None]:
## Test Cases 1-5: score every poisoned model before any defense is applied.
## 'single' = one backdoor, 'dual' = GTSRB's (WaNet, Smooth) pair, 'clean' = no backdoor.
attack_cases = [
    ("----------------- Testing attack: PubFig all2all -----------------", PubFig_all2all, 'single'),
    ("----------------- Testing attack: CIFAR-10 SIG -----------------", CIFAR10_SIG, 'single'),
    ("----------------- Testing attack: Tiny-Imagenet Narcissus -----------------", TinyImangeNet_Narcissus, 'single'),
    ("----------------- Testing attack: GTSRB WaNet & Smooth -----------------", GTSRB_WaNetFrequency, 'dual'),
    ("----------------- Testing attack: STL-10 -----------------", STL10_Clean, 'clean'),
]
for banner, load_case, report in attack_cases:
    print(banner)
    _, test_dataset, asr_dataset, pacc_dataset, model = load_case()
    print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))
    if report == 'single':
        print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))
        print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))
    elif report == 'dual':
        print('WaNet ASR %.3f%%' % (100 * get_results(model, asr_dataset[0])))
        print('WaNet PACC %.3f%%' % (100 * get_results(model, pacc_dataset[0])))
        print('Smooth ASR %.3f%%' % (100 * get_results(model, asr_dataset[1])))
        print('Smooth PACC %.3f%%' % (100 * get_results(model, pacc_dataset[1])))

# 3. Baseline defense

In [None]:
def test_defense(defense_method):
    """Run ``defense_method`` on every competition case under that case's time budget.

    For each case the poisoned model and its validation split are loaded and
    ``defense_method(model, val_dataset, allow_time)`` is executed through
    ``func_timeout`` with a limit of ``allow_time`` seconds. Each budget is
    written once in the table below, so the enforced limit and the
    ``allow_time`` the defense sees cannot drift apart. On timeout a message is
    printed and the model object loaded for that case is kept (a defense that
    edits the model in place may have partially changed it).

    Args:
        defense_method: callable ``(model, val_dataset, allow_time) -> model``.

    Returns:
        list of five models in the order: PubFig all2all, CIFAR-10 SIG,
        Tiny-Imagenet Narcissus, GTSRB WaNet & Smooth, STL-10.
    """
    # (banner, case loader, time budget in seconds)
    cases = (
        ("----------------- Testing defense: PubFig all2all -----------------", PubFig_all2all, 1350),
        ("----------------- Testing defense: CIFAR-10 SIG -----------------", CIFAR10_SIG, 900),
        ("----------------- Testing defense: Tiny-Imagenet Narcissus -----------------", TinyImangeNet_Narcissus, 1800),
        ("----------------- Testing defense: GTSRB WaNet & Smooth -----------------", GTSRB_WaNetFrequency, 690),
        ("----------------- Testing defense: STL-10 -----------------", STL10_Clean, 450),
    )
    models = []
    for banner, load_case, budget in cases:
        print(banner)
        val_dataset, _, _, _, model = load_case()
        try:
            model = func_timeout(budget, defense_method, args=(model, val_dataset, budget))
        except FunctionTimedOut:
            print("This test case exceed the maximum executable time!\n")
        models.append(model)
    return models

In [None]:
#@title I-BAU Defense
def IBAU(net, val_dataset, allow_time):
    '''Implicit Backdoor Adversarial Unlearning (I-BAU) baseline defense.

    Code from https://github.com/YiZeng623/I-BAU

    After choosing an Adam learning rate with ``get_lr``, each round:
      1. optimizes one universal perturbation that maximizes cross-entropy
         against the model's own predictions on ``val_dataset``;
      2. for every cached batch, calls ``fixed_point`` (from util; per the
         I-BAU reference it computes hypergradients into the parameters'
         ``.grad``) so that images carrying the perturbation are pushed
         toward their true labels, then steps the outer optimizer.
    Rounds continue while the remaining budget exceeds twice the mean
    duration of the last five timed steps, for at most 149 rounds.

    Args:
        net: poisoned model; trained in place and also returned.
        val_dataset: dataset of ``(image, label)`` pairs.
        allow_time: time budget in seconds.

    Returns:
        the unlearned model (left in train mode).
    '''
    # torch.cuda.Event.elapsed_time reports milliseconds.
    allow_time_ms = allow_time * 1000
    val_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=32, num_workers=4, shuffle=True)

    # One shuffled pass over the data, cached; the hypergradient losses index into it.
    images_list, labels_list = [], []
    for images, labels in val_dataloader:
        images_list.append(images)
        labels_list.append(labels)

    # One transformed sample; fixes the perturbation shape and its learning rate.
    sample_image = val_dataset[0][0]

    def loss_inner(perturb, model_params):
        # Inner objective on the first cached batch: maximize CE, small L2 penalty on the perturbation.
        images = images_list[0].to(device)
        labels = labels_list[0].long().to(device)
        per_img = images + perturb[0]
        per_logits = net.forward(per_img)
        loss = nn.functional.cross_entropy(per_logits, labels, reduction='none')
        loss_regu = torch.mean(-loss) + 0.001 * torch.pow(torch.norm(perturb[0]), 2)
        return loss_regu

    def loss_outer(perturb, model_params):
        # Number of images to patch: count of 32 uniform draws above 0.97 (about 1 on average).
        random_pick = np.where(np.random.uniform(0, 1, 32) > 0.97)[0].shape[0]

        # `batchnum` is the loop variable of the unlearning loop below; fixed_point
        # fixes this function's signature, so the batch index is read from the closure.
        images, labels = images_list[batchnum].to(device), labels_list[batchnum].long().to(device)
        patching = torch.zeros_like(images, device='cuda')
        number = images.shape[0]
        random_pick = min(number, random_pick)
        rand_idx = random.sample(list(np.arange(number)), random_pick)
        patching[rand_idx] = perturb[0]
        unlearn_imgs = images + patching
        logits = net(unlearn_imgs)
        criterion = nn.CrossEntropyLoss()
        loss = criterion(logits, labels)
        return loss

    def get_lr(net, loader):
        '''Return the Adam lr in {1e-2, ..., 1e-7} whose one-epoch fine-tune of a copy of
        ``net`` gives the highest ``get_results`` score on ``loader.dataset``.'''
        lr_list = [0.1 ** i for i in range(2, 8)]
        acc_list = []
        for lr in lr_list:
            copy_net = copy.deepcopy(net).cuda()
            optimizer = torch.optim.Adam(copy_net.parameters(), lr=lr)
            for inputs, labels in loader:
                inputs, labels = inputs.to(device), labels.type(torch.LongTensor).to(device)
                optimizer.zero_grad()
                outputs = copy_net(inputs)
                loss = nn.functional.cross_entropy(outputs, labels)
                loss.backward()
                optimizer.step()

            acc_list.append(get_results(copy_net, loader.dataset))
            print("lr = " + str(lr) + " ACC: " + str(acc_list[-1] * 100))
        return 0.1 ** (acc_list.index(max(acc_list)) + 2)

    start = torch.cuda.Event(enable_timing=True)
    end = torch.cuda.Event(enable_timing=True)

    # Per-step durations in ms; five zero entries seed the moving average in the stopping rule.
    every_time = [0] * 5

    start.record()

    curr_lr = get_lr(net, val_dataloader)
    net = net.cuda()
    outer_opt = torch.optim.Adam(net.parameters(), lr=curr_lr)
    inner_opt = GradientDescent(loss_inner, 0.1)

    end.record()
    torch.cuda.synchronize()
    every_time.append(start.elapsed_time(end))

    # Perturbation lr grows with image side length (0.0005 at 32 px, 0.0965 at 224 px).
    batch_lr = 0.0005 * sample_image.shape[1] - 0.0155

    net.train()
    while (allow_time_ms - np.sum(every_time)) > (np.mean(every_time[-5:]) * 2) and len(every_time) < 155:
        start.record()
        batch_pert = torch.zeros_like(sample_image.unsqueeze(0), requires_grad=True, device='cuda')
        batch_opt = torch.optim.Adam(params=[batch_pert], lr=batch_lr)

        # Inner maximization: universal perturbation that flips the model's own predictions.
        for images, _ in val_dataloader:
            images = images.to(device)
            with torch.no_grad():
                ori_lab = torch.argmax(net.forward(images), dim=1).long()
            per_logits = net.forward(images + batch_pert)
            loss = -nn.functional.cross_entropy(per_logits, ori_lab) + 0.001 * torch.pow(torch.norm(batch_pert), 2)
            batch_opt.zero_grad()
            loss.backward()
            batch_opt.step()

        # Outer minimization (unlearning): one hypergradient step per cached batch.
        for batchnum in range(len(images_list)):
            outer_opt.zero_grad()
            fixed_point(batch_pert, list(net.parameters()), 5, inner_opt, loss_outer)
            outer_opt.step()

        print('Round:', len(every_time) - 5)
        end.record()
        torch.cuda.synchronize()
        every_time.append(start.elapsed_time(end))
    return net

In [None]:
# Get the defended model
models = test_defense(IBAU)

# Test all attacks
# Each case below reloads only the evaluation splits (the freshly loaded
# poisoned model is discarded into `_`) and scores the defended model from
# `models`, whose index order matches the case order in `test_defense`.
print("----------------- Defense result for I-BAU -----------------")
## Test Pubfig all2all
print("----------------- Testing defense result: PubFig all2all -----------------")
_, test_dataset, asr_dataset, pacc_dataset, _ = PubFig_all2all()
model = models[0]
print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))
print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))
print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))

## Test CIFAR-10 SIG
print("----------------- Testing defense result: CIFAR-10 SIG -----------------")
_, test_dataset, asr_dataset, pacc_dataset, _ = CIFAR10_SIG()
model = models[1]
print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))
print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))
print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))

## Test Tiny-Imagenet Narcissus
print("----------------- Testing defense result: Tiny-Imagenet Narcissus -----------------")
_, test_dataset, asr_dataset, pacc_dataset, _ = TinyImangeNet_Narcissus()
model = models[2]
print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))
print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))
print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))

## Test GTSRB WaNet & Smooth
# This case returns (WaNet, Smooth) pairs of ASR and PACC datasets.
print("----------------- Testing defense result: GTSRB WaNet & Smooth -----------------")
_, test_dataset, asr_dataset, pacc_dataset, _ = GTSRB_WaNetFrequency()
model = models[3]
print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))
print('WaNet ASR %.3f%%' % (100 * get_results(model, asr_dataset[0])))
print('WaNet PACC %.3f%%' % (100 * get_results(model, pacc_dataset[0])))
print('Smooth ASR %.3f%%' % (100 * get_results(model, asr_dataset[1])))
print('Smooth PACC %.3f%%' % (100 * get_results(model, pacc_dataset[1])))

## Test STL-10
# Clean case: only clean accuracy is reported.
print("----------------- Testing defense result: STL-10 -----------------")
_, test_dataset, _, _, _ = STL10_Clean()
model = models[4]
print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))

In [None]:
#@title Neural Cleanse Defense
def neural_cleanse(model, val_dataset, allow_time):
    '''Neural Cleanse baseline: reverse-engineer per-class triggers, flag outliers, unlearn them.

    Code from https://github.com/VinAIResearch/input-aware-backdoor-attack-release

    For every class label a mask/pattern pair is optimized (``opt.epoch`` epochs
    over ``val_dataset``) so that blending it into the inputs makes a frozen copy
    of the model predict that label. Labels whose mask L1 norm is a low outlier
    (MAD-based anomaly index above 2) are flagged. For each flagged label the
    model is fine-tuned for 10 epochs on ``val_dataset`` with the recovered
    trigger blended into 20% of the samples while keeping their true labels.

    Args:
        model: poisoned classifier; fine-tuned in place when labels are flagged.
        val_dataset: dataset of ``(image, label)`` pairs exposing ``.targets``.
        allow_time: time budget in seconds. Not consulted here; the caller
            (``test_defense``) enforces it through ``func_timeout``.

    Returns:
        ``model`` unchanged when no label is flagged, otherwise the unlearned
        model (left in train mode).
    '''
    class RegressionModel(nn.Module):
        """Frozen copy of the classifier preceded by a learnable mask/pattern blend.

        Mask and pattern are stored in tanh space and mapped into (0, 1) by
        ``get_raw_mask`` / ``get_raw_pattern``.
        """
        def __init__(self, opt, init_mask, init_pattern, model):
            super(RegressionModel, self).__init__()
            self._EPSILON = opt.EPSILON
            self.mask_tanh = nn.Parameter(torch.tensor(init_mask))
            self.pattern_tanh = nn.Parameter(torch.tensor(init_pattern))

            # Frozen, eval-mode copy: trigger search never updates the real model.
            self.classifier = copy.deepcopy(model)
            for param in self.classifier.parameters():
                param.requires_grad = False
            self.classifier.eval()
            self.classifier = self.classifier.cuda()

        def forward(self, x):
            mask = self.get_raw_mask()
            pattern = self.get_raw_pattern()
            x = (1 - mask) * x + mask * pattern
            return self.classifier(x)

        def get_raw_mask(self):
            mask = nn.Tanh()(self.mask_tanh)
            return mask / (2 + self._EPSILON) + 0.5

        def get_raw_pattern(self):
            pattern = nn.Tanh()(self.pattern_tanh)
            return pattern / (2 + self._EPSILON) + 0.5

    class Recorder:
        """Best mask/pattern found so far plus the counters for cost scheduling and early stop."""
        def __init__(self, opt):
            super().__init__()

            # Best optimization results
            self.mask_best = None
            self.pattern_best = None
            self.reg_best = float("inf")

            # Logs and counters for adjusting balance cost
            self.logs = []
            self.cost_set_counter = 0
            self.cost_up_counter = 0
            self.cost_down_counter = 0
            self.cost_up_flag = False
            self.cost_down_flag = False

            # Counter for early stop
            self.early_stop_counter = 0
            self.early_stop_reg_best = self.reg_best

            # Cost (weight of the mask-norm regularizer)
            self.cost = opt.init_cost
            self.cost_multiplier_up = opt.cost_multiplier
            self.cost_multiplier_down = opt.cost_multiplier ** 1.5

        def reset_state(self, opt):
            self.cost = opt.init_cost
            self.cost_up_counter = 0
            self.cost_down_counter = 0
            self.cost_up_flag = False
            self.cost_down_flag = False
            print("Initialize cost to {:f}".format(self.cost))

    def train(opt, init_mask, init_pattern, model, val_dataset):
        """Optimize a trigger for ``opt.target_label``; return the recorder and ``opt``."""
        test_dataloader = torch.utils.data.DataLoader(val_dataset, batch_size=128, num_workers=4, shuffle=False)

        # Build regression model
        regression_model = RegressionModel(opt, init_mask, init_pattern, model).cuda()

        # Set optimizer
        optimizerR = torch.optim.Adam(regression_model.parameters(), lr=opt.lr, betas=(0.5, 0.9))

        # Set recorder (for recording best result)
        recorder = Recorder(opt)

        for epoch in range(opt.epoch):
            early_stop = train_step(regression_model, optimizerR, dataloader=test_dataloader,
                                    recorder=recorder, epoch=epoch, opt=opt)
            if early_stop:
                break

        return recorder, opt

    def train_step(regression_model, optimizerR, dataloader, recorder, epoch, opt):
        """One epoch of trigger optimization; returns True when early stop triggers."""
        print("Epoch {} - Label: {}:".format(epoch, opt.target_label))
        # Set losses
        cross_entropy = nn.CrossEntropyLoss()
        total_pred = 0
        true_pred = 0

        # Record loss for all mini-batches
        loss_ce_list = []
        loss_reg_list = []
        loss_list = []
        loss_acc_list = []

        # Set inner early stop flag
        inner_early_stop_flag = False
        for inputs, _ in dataloader:
            # Forwarding and update model
            optimizerR.zero_grad()

            inputs = inputs.cuda()
            sample_num = inputs.shape[0]
            total_pred += sample_num
            target_labels = torch.ones((sample_num), dtype=torch.int64).cuda() * opt.target_label
            predictions = regression_model(inputs)

            loss_ce = cross_entropy(predictions, target_labels)
            loss_reg = torch.norm(regression_model.get_raw_mask(), 2)
            total_loss = loss_ce + recorder.cost * loss_reg
            total_loss.backward()
            optimizerR.step()

            # Record minibatch information to list
            minibatch_accuracy = torch.sum(torch.argmax(predictions, dim=1) == target_labels).detach() * 100.0 / sample_num
            loss_ce_list.append(loss_ce.detach())
            loss_reg_list.append(loss_reg.detach())
            loss_list.append(total_loss.detach())
            loss_acc_list.append(minibatch_accuracy)

            true_pred += torch.sum(torch.argmax(predictions, dim=1) == target_labels).detach()

        loss_ce_list = torch.stack(loss_ce_list)
        loss_reg_list = torch.stack(loss_reg_list)
        loss_list = torch.stack(loss_list)
        loss_acc_list = torch.stack(loss_acc_list)

        avg_loss_ce = torch.mean(loss_ce_list)
        avg_loss_reg = torch.mean(loss_reg_list)
        avg_loss_acc = torch.mean(loss_acc_list)

        # Check to save best mask or not
        if avg_loss_acc >= opt.atk_succ_threshold and avg_loss_reg < recorder.reg_best:
            recorder.mask_best = regression_model.get_raw_mask().detach()
            recorder.pattern_best = regression_model.get_raw_pattern().detach()
            recorder.reg_best = avg_loss_reg
            print(" Updated !!!")

        # Show information
        print(
            "  Result: Accuracy: {:.3f} | Cross Entropy Loss: {:.6f} | Reg Loss: {:.6f} | Reg best: {:.6f}".format(
                true_pred * 100.0 / total_pred, avg_loss_ce, avg_loss_reg, recorder.reg_best
            )
        )

        # Check early stop
        if opt.early_stop:
            if recorder.reg_best < float("inf"):
                if recorder.reg_best >= opt.early_stop_threshold * recorder.early_stop_reg_best:
                    recorder.early_stop_counter += 1
                else:
                    recorder.early_stop_counter = 0

            recorder.early_stop_reg_best = min(recorder.early_stop_reg_best, recorder.reg_best)

            if (
                recorder.cost_down_flag
                and recorder.cost_up_flag
                and recorder.early_stop_counter >= opt.early_stop_patience
            ):
                print("Early_stop !!!")
                inner_early_stop_flag = True

        if not inner_early_stop_flag:
            # Check cost modification
            if recorder.cost == 0 and avg_loss_acc >= opt.atk_succ_threshold:
                recorder.cost_set_counter += 1
                if recorder.cost_set_counter >= opt.patience:
                    recorder.reset_state(opt)
            else:
                recorder.cost_set_counter = 0

            if avg_loss_acc >= opt.atk_succ_threshold:
                recorder.cost_up_counter += 1
                recorder.cost_down_counter = 0
            else:
                recorder.cost_up_counter = 0
                recorder.cost_down_counter += 1

            if recorder.cost_up_counter >= opt.patience:
                recorder.cost_up_counter = 0
                print("Up cost from {} to {}".format(recorder.cost, recorder.cost * recorder.cost_multiplier_up))
                recorder.cost *= recorder.cost_multiplier_up
                recorder.cost_up_flag = True

            elif recorder.cost_down_counter >= opt.patience:
                recorder.cost_down_counter = 0
                print("Down cost from {} to {}".format(recorder.cost, recorder.cost / recorder.cost_multiplier_down))
                recorder.cost /= recorder.cost_multiplier_down
                recorder.cost_down_flag = True

            # Save the final version
            if recorder.mask_best is None:
                recorder.mask_best = regression_model.get_raw_mask().detach()
                recorder.pattern_best = regression_model.get_raw_pattern().detach()

        return inner_early_stop_flag

    class opt:
        total_label = np.unique(val_dataset.targets).shape[0]
        input_height, input_width, input_channel = val_dataset[0][0].shape[1], val_dataset[0][0].shape[2], val_dataset[0][0].shape[0]
        EPSILON = 1e-7
        lr = 1e-1
        init_cost = 1e-3
        cost_multiplier = 2.0
        epoch = 1
        atk_succ_threshold = 99.0
        early_stop_threshold = 99.0
        early_stop = True
        patience = 5
    opt = opt()

    init_mask = np.ones((1, opt.input_height, opt.input_width)).astype(np.float32)
    init_pattern = np.ones((opt.input_channel, opt.input_height, opt.input_width)).astype(np.float32)

    masks = []
    patterns = []
    idx_mapping = {}

    for target_label in range(opt.total_label):
        print("----------------- Analyzing label: {} -----------------".format(target_label))
        opt.target_label = target_label
        recorder, opt = train(opt, init_mask, init_pattern, model, val_dataset)

        mask = recorder.mask_best
        masks.append(mask)
        pattern = recorder.pattern_best
        patterns.append(pattern)

        idx_mapping[target_label] = len(masks) - 1

    l1_norm_list = torch.stack([torch.sum(torch.abs(m)) for m in masks])
    print("{} labels found".format(len(l1_norm_list)))
    print("Norm values: {}".format(l1_norm_list))

    def outlier_detection(l1_norm_list, idx_mapping, opt):
        """Return labels whose mask norm lies more than 2 scaled MADs below the median, smallest first."""
        print("-" * 30)
        print("Determining whether model is backdoor")
        consistency_constant = 1.4826
        median = torch.median(l1_norm_list)
        mad = consistency_constant * torch.median(torch.abs(l1_norm_list - median))
        # NOTE(review): if mad == 0 the anomaly indices below become inf/nan.
        min_mad = torch.abs(torch.min(l1_norm_list) - median) / mad

        print("Median: {}, MAD: {}".format(median, mad))
        print("Anomaly index: {}".format(min_mad))

        if min_mad < 2:
            print("Not a backdoor model")
        else:
            print("This is a backdoor model")

        flag_list = []
        for y_label in idx_mapping:
            if l1_norm_list[idx_mapping[y_label]] > median:
                continue
            if torch.abs(l1_norm_list[idx_mapping[y_label]] - median) / mad > 2:
                flag_list.append((y_label, l1_norm_list[idx_mapping[y_label]]))

        if len(flag_list) > 0:
            flag_list = sorted(flag_list, key=lambda x: x[1])

        print(
            "Flagged label list: {}".format(",".join(["{}: {}".format(y_label, l_norm) for y_label, l_norm in flag_list]))
        )

        return [y_label for y_label, _ in flag_list]

    poi_label_list = outlier_detection(l1_norm_list, idx_mapping, opt)

    if len(poi_label_list) == 0:
        return model

    class unlearning_ds(Dataset):
        """Wraps ``dataset``, blending the trigger into a fixed random ``patch_ratio`` of
        samples (labels untouched) and clamping every image to [-1, 1]."""
        def __init__(self, dataset, mask, trigger, patch_ratio):
            self.dataset = dataset
            # A set gives O(1) membership in __getitem__ (a list made each epoch O(n^2)).
            self.patch_set = set(random.sample(list(np.arange(len(dataset))), int(len(dataset) * patch_ratio)))
            self.mask = mask
            self.trigger = trigger

        def __getitem__(self, idx):
            # Fetch once so the dataset transform runs a single time per sample.
            sample = self.dataset[idx]
            image, label = sample[0], sample[1]
            if idx in self.patch_set:
                image = (image + self.mask * (self.trigger - image))
            image = torch.clamp(image, -1, 1)
            return (image, label)

        def __len__(self):
            return len(self.dataset)

    for i in poi_label_list:
        curr_masks = masks[i].cpu()
        curr_pattern = patterns[i].cpu()
        ul_set = unlearning_ds(val_dataset, curr_masks, curr_pattern, 0.2)
        ul_loader = torch.utils.data.DataLoader(ul_set, batch_size=128, num_workers=4, shuffle=True)

        model.train()
        outer_opt = torch.optim.SGD(params=model.parameters(), lr=8e-2)
        criterion = nn.CrossEntropyLoss()
        for _ in range(10):
            correct = 0
            total = 0
            for inputs, targets in ul_loader:
                inputs, targets = inputs.cuda(), targets.type(torch.LongTensor).cuda()
                outer_opt.zero_grad()
                outputs = model(inputs)
                loss = criterion(outputs, targets)
                loss.backward()
                outer_opt.step()

                _, predicted = outputs.max(1)
                total += targets.size(0)
                correct += predicted.eq(targets).sum().item()
            print('Unlearn Acc: %.3f%% (%d/%d)'
                  % (100. * correct / total, correct, total))

    return model

In [None]:
# Get the defended model
models = test_defense(neural_cleanse)

# Test all attacks
# Each case below reloads only the evaluation splits (the freshly loaded
# poisoned model is discarded into `_`) and scores the defended model from
# `models`, whose index order matches the case order in `test_defense`.
print("----------------- Defense result for Neural Cleanse -----------------")
## Test Pubfig all2all
print("----------------- Testing defense result: PubFig all2all -----------------")
_, test_dataset, asr_dataset, pacc_dataset, _ = PubFig_all2all()
model = models[0]
print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))
print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))
print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))

## Test CIFAR-10 SIG
print("----------------- Testing defense result: CIFAR-10 SIG -----------------")
_, test_dataset, asr_dataset, pacc_dataset, _ = CIFAR10_SIG()
model = models[1]
print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))
print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))
print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))

## Test Tiny-Imagenet Narcissus
print("----------------- Testing defense result: Tiny-Imagenet Narcissus -----------------")
_, test_dataset, asr_dataset, pacc_dataset, _ = TinyImangeNet_Narcissus()
model = models[2]
print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))
print('ASR %.3f%%' % (100 * get_results(model, asr_dataset)))
print('PACC %.3f%%' % (100 * get_results(model, pacc_dataset)))

## Test GTSRB WaNet & Smooth
# This case returns (WaNet, Smooth) pairs of ASR and PACC datasets.
print("----------------- Testing defense result: GTSRB WaNet & Smooth -----------------")
_, test_dataset, asr_dataset, pacc_dataset, _ = GTSRB_WaNetFrequency()
model = models[3]
print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))
print('WaNet ASR %.3f%%' % (100 * get_results(model, asr_dataset[0])))
print('WaNet PACC %.3f%%' % (100 * get_results(model, pacc_dataset[0])))
print('Smooth ASR %.3f%%' % (100 * get_results(model, asr_dataset[1])))
print('Smooth PACC %.3f%%' % (100 * get_results(model, pacc_dataset[1])))

## Test STL-10
# Clean case: only clean accuracy is reported.
print("----------------- Testing defense result: STL-10 -----------------")
_, test_dataset, _, _, _ = STL10_Clean()
model = models[4]
print('ACC：%.3f%%' % (100 * get_results(model, test_dataset)))

# 4. Implement your defense method

In [None]:
def clean_model(net, val_dataset, allow_time):
    """Return a defended (Trojan-removed) version of ``net``.

    This function is passed to ``test_defense`` for local evaluation, and its
    source is what gets uploaded on submission. The placeholder below performs
    no defense: it hands back the input model unchanged.

    Args:
        net: the model to defend (presumably one of the poisoned competition
            models -- confirm against ``test_defense``).
        val_dataset: data available to the defense (presumably a clean
            validation split -- TODO confirm).
        allow_time: time budget for the defense (units not visible here,
            presumably seconds -- TODO confirm).

    Returns:
        The defended model; currently ``net`` itself.
    """
    # TODO: implement the defense; until then this is an identity pass-through.
    defended_net = net
    return defended_net

In [None]:
# Get the defended model
models = test_defense(clean_model)


def _print_metric(template, net, data):
    # Score ``net`` on ``data`` with the project's ``get_results`` helper and
    # print the score scaled by 100 through the given %-template.
    print(template % (100 * get_results(net, data)))


# Test all attack
# Single-trigger benchmarks: (section banner, dataset loader, index into ``models``).
# Each loader is unpacked as (_, test, asr, pacc, _).
_single_trigger_suites = (
    ("----------------- Testing defense result: PubFig all2all -----------------", PubFig_all2all, 0),
    ("----------------- Testing defense result: CIFAR-10 SIG -----------------", CIFAR10_SIG, 1),
    ("----------------- Testing defense result: Tiny-Imagenet Narcissus -----------------", TinyImangeNet_Narcissus, 2),
)
for banner, load_suite, model_idx in _single_trigger_suites:
    print(banner)
    _, test_dataset, asr_dataset, pacc_dataset, _ = load_suite()
    model = models[model_idx]
    _print_metric('ACC：%.3f%%', model, test_dataset)
    _print_metric('ASR %.3f%%', model, asr_dataset)
    _print_metric('PACC %.3f%%', model, pacc_dataset)

## Test GTSRB WaNet & Smooth
print("----------------- Testing defense result: GTSRB WaNet & Smooth -----------------")
# This suite carries two triggers: asr_dataset / pacc_dataset are indexed
# [0] for WaNet and [1] for Smooth (per the labels printed below).
_, test_dataset, asr_dataset, pacc_dataset, _ = GTSRB_WaNetFrequency()
model = models[3]
_print_metric('ACC：%.3f%%', model, test_dataset)
_print_metric('WaNet ASR %.3f%%', model, asr_dataset[0])
_print_metric('WaNet PACC %.3f%%', model, pacc_dataset[0])
_print_metric('Smooth ASR %.3f%%', model, asr_dataset[1])
_print_metric('Smooth PACC %.3f%%', model, pacc_dataset[1])

## Test STL-10
# Clean-only benchmark: just the test accuracy is reported.
print("----------------- Testing defense result: STL-10 -----------------")
_, test_dataset, _, _, _ = STL10_Clean()
model = models[4]
_print_metric('ACC：%.3f%%', model, test_dataset)

## **5. Submit your code**

## **The submission portal will close at 0:00 AoE (Anywhere on Earth). The final score is based on your last submission, so please make sure your last submission is valid!**

In [None]:
#@title Enter your submission information here and run this block!
#@markdown Do not use spaces or special characters!
submission_name = "" #@param {type:"string"}
#@markdown The unique_participant_number has been sent to your email during registration.
unique_participant_number = "" #@param {type:"string"}
import inspect
import re
import time

import requests

# Fail fast instead of uploading a file that cannot be attributed: with an
# empty field the generated name degenerates to e.g. "_1700000000_.py".
if not unique_participant_number:
    raise ValueError("unique_participant_number is empty; fill in the form field above.")
if not re.fullmatch(r"[\w-]+", submission_name):
    raise ValueError(
        "submission_name must be non-empty and contain only letters, digits, "
        "'_' or '-' (no spaces or special characters)."
    )

# Upload name: <participant>_<unix timestamp>_<submission name>.py
file_name = unique_participant_number + "_" + str(int(time.time())) + "_" + submission_name + ".py"


def get_code(name):
    """Return the source text of the object ``name`` via ``inspect.getsource``."""
    return inspect.getsource(name)


# Only the source of ``clean_model`` is written to the file that gets uploaded.
with open(file_name, "w", encoding="utf-8") as fp:
    fp.write(get_code(clean_model))

url = "http://95.217.244.39:20001/uploader"
# Open the upload inside ``with`` so the handle is closed once the request is
# done, and bound the wait so an unresponsive server cannot hang the cell.
with open('./' + file_name, 'rb') as upload:
    req = requests.request("POST", url=url, files={'file': upload}, timeout=120)
print(req.text)