In [1]:
import numpy as np
import torch
import torch.nn as nn
import logging
import argparse
import os
from resnet_cifar import resnet18 as resnet18_cifar

from poison_tool_cifar import get_test_loader, get_train_loader, split_dataset
import unlearn

In [2]:
# Select the compute device: the GPU when CUDA is available, otherwise the CPU.
if torch.cuda.is_available():
    torch.backends.cudnn.enabled = True
    device = torch.device('cuda')
else:
    device = torch.device('cpu')

# Reproducibility settings.
# cuDNN autotuning (benchmark) is left off and deterministic kernels are
# requested; the earlier `benchmark = True` assignment was removed because it
# was unconditionally overwritten by the `False` below.
import random

seed = 98
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
# Seed every RNG this notebook may draw from: Python, PyTorch and NumPy.
random.seed(seed)
torch.manual_seed(seed)
np.random.seed(seed)

In [3]:
# Route every record (DEBUG and above) to both `output.log` and the
# notebook's stream output, using a single timestamped message format.
log_handlers = [logging.FileHandler('output.log'), logging.StreamHandler()]
logging.basicConfig(
    level=logging.DEBUG,
    format='[%(asctime)s] - %(message)s',
    datefmt='%Y/%m/%d %H:%M:%S',
    handlers=log_handlers,
)

# Logger named after the current module; it is handed to the unlearning code below.
logger = logging.getLogger(__name__)

# Experiment configuration. The options are declared through argparse so the
# same definitions can be reused from a script, but inside the notebook they
# are parsed from an empty argv, i.e. `args` always holds the defaults below.
parser = argparse.ArgumentParser()

# Backdoor / trigger description
parser.add_argument('--target_label', type=int, default=0, help='class of target label')
parser.add_argument('--trigger_type', type=str, default='gridTrigger', help='type of backdoor trigger')
parser.add_argument('--target_type', type=str, default='all2one', help='type of backdoor label')
parser.add_argument('--trig_w', type=int, default=3, help='width of trigger pattern')
parser.add_argument('--trig_h', type=int, default=3, help='height of trigger pattern')

# Data
parser.add_argument('--dataset', type=str, default='CIFAR10', help='type of dataset')
# `ratio` is a fraction, so it must be parsed as a float: with `type=int`
# any explicit value such as `--ratio 0.05` was rejected by argparse.
parser.add_argument('--ratio', type=float, default=0.01, help='ratio of defense data')
parser.add_argument('--batch_size', type=int, default=128)
parser.add_argument('--num_classes', type=int, default=10)

# Model
parser.add_argument('--backdoor_model_path', type=str,
                    default='weights/ResNet18-ResNet-BadNets-target0-portion0.1-epoch80.tar',
                    help='path of backdoored model')
parser.add_argument('--output_model_path', type=str,
                    default=None, help='path of unlearned backdoored model')
parser.add_argument('--arch', type=str, default='resnet18',
                    choices=['resnet18', 'resnet34', 'resnet50', 'resnet101', 'resnet152', 'MobileNetV2',
                             'vgg19_bn'])

# Parse an empty list instead of sys.argv: the Jupyter kernel's own command
# line arguments would otherwise be handed to this parser.
args = parser.parse_args([])

In [4]:
# Dataset utilities used by the rest of this cell
from torchvision import transforms, datasets
from torchvision.datasets import CIFAR10
from torch.utils.data import random_split, DataLoader, Dataset

# Per-channel statistics passed to `transforms.Normalize` below
MEAN_CIFAR10 = (0.4914, 0.4822, 0.4465)
STD_CIFAR10 = (0.2023, 0.1994, 0.2010)

# Training-time pipeline: random 32x32 crop after 4 px padding, random
# horizontal flip, tensor conversion, then channel-wise normalisation.
train_augmentations = [
    transforms.RandomCrop(32, padding=4),
    transforms.RandomHorizontalFlip(),
    transforms.ToTensor(),
    transforms.Normalize(MEAN_CIFAR10, STD_CIFAR10),
]
tf_train = transforms.Compose(train_augmentations)


# Root directory of the local CIFAR-10 copy. The default is the original
# workstation path; set the CIFAR10_ROOT environment variable to run the
# notebook on another machine without editing this cell.
CIFAR10_ROOT = os.environ.get(
    'CIFAR10_ROOT',
    '/media/user/8961e245-931a-4871-9f74-9df58b1bd938/server/lyg/LfF-master(2)/data/CIFAR10',
)
clean_train = CIFAR10(root=CIFAR10_ROOT, train=True, download=False, transform=tf_train)

# Split off a small subset of the clean training data to use as defense data.
# The fraction and batch size come from `args` instead of repeated literals,
# so they cannot drift from the parsed configuration (defaults: 0.01 / 128).
_, split_set = split_dataset(clean_train, frac=args.ratio)
defense_data_loader = DataLoader(split_set, batch_size=args.batch_size, shuffle=True, num_workers=4)

# Alternative source of defense data, kept for reference:
# defense_data_loader = get_train_loader(args)
clean_test_loader, bad_test_loader = get_test_loader(args)

logger.info('----------- Data Initialization --------------')
# Loaders bundled into the dict that is handed to `unlearn.Unlearning`
data_loader = {'defense_loader': defense_data_loader,
               'clean_test_loader': clean_test_loader,
               'bad_test_loader': bad_test_loader
               }

logger.info('----------- Model Initialization --------------')
net = resnet18_cifar(num_classes=args.num_classes, norm_layer=None)

# Checkpoint of the backdoored model. The default is the original local path;
# set the BACKDOOR_CKPT environment variable to point somewhere else.
# NOTE(review): `args.backdoor_model_path` is not consulted here — confirm
# whether that option was meant to supply this path.
BACKDOOR_CKPT = os.environ.get(
    'BACKDOOR_CKPT',
    '/media/user/HP USB321FD/ResNet18-ResNet-BadNets-target0-portion0.1-epoch80(1).tar',
)
# `map_location=device` remaps the stored tensors onto the selected device, so
# a checkpoint saved from a GPU still loads when the notebook falls back to CPU.
# torch.load unpickles the file: only load checkpoints from a trusted source.
checkpoint = torch.load(BACKDOOR_CKPT, map_location=device)
net.load_state_dict(checkpoint['state_dict'])
net = net.to(device)

logger.info('----------- Model Exposing Strategy --------------')

# The assignment below rebinds the global name `unlearn` from the imported
# module to the Unlearning instance. Importing the module under an alias right
# here keeps this cell re-runnable: previously a second execution evaluated
# `unlearn.Unlearning` on the instance and raised AttributeError.
import unlearn as unlearn_module

# The name `unlearn` is kept for the instance so later cells that use it still work.
unlearn = unlearn_module.Unlearning(args, logger, net, data_loader)
unlearn.do_expose()

total data size: 49500 images, split test size: 500 images, split ratio: 0.010000
==> Preparing test data..
Files already downloaded and verified
Generating testbad Imgs


100%|██████████████████████████████████| 10000/10000 [00:00<00:00, 24370.82it/s]


Injecting Over: 0Bad Imgs, 10000Clean Imgs
Generating testbad Imgs


100%|██████████████████████████████████| 10000/10000 [00:00<00:00, 19847.11it/s]
[2024/01/22 18:28:35] - ----------- Data Initialization --------------
[2024/01/22 18:28:35] - ----------- Model Initialization --------------


Injecting Over: 9000Bad Imgs, 1000Clean Imgs


[2024/01/22 18:28:37] - ----------- Model Exposing Strategy --------------
[2024/01/22 18:28:37] - Namespace(target_label=0, trigger_type='gridTrigger', target_type='all2one', trig_w=3, trig_h=3, dataset='CIFAR10', ratio=0.01, batch_size=128, num_classes=10, backdoor_model_path='weights/ResNet18-ResNet-BadNets-target0-portion0.1-epoch80.tar', output_model_path=None, arch='resnet18', print_every=500, unlearn_epochs=20, lr=0.002, sched_gamma=0.1, sched_ms=[20, 20], stop_acc=0.1, device='cuda')


full_acc: {'epoch': 0, 'lr': 1e-05, 'acc': 0.97, 'asr': 1.0, 'cls_pred': [0.116, 0.09, 0.112, 0.126, 0.092, 0.078, 0.1, 0.106, 0.088, 0.092]}
full_acc: {'epoch': 1, 'lr': 1e-05, 'acc': 0.97, 'asr': 1.0, 'cls_pred': [0.114, 0.092, 0.116, 0.128, 0.094, 0.07, 0.102, 0.108, 0.09, 0.086]}
full_acc: {'epoch': 2, 'lr': 1e-05, 'acc': 0.98, 'asr': 1.0, 'cls_pred': [0.112, 0.092, 0.114, 0.13, 0.092, 0.072, 0.102, 0.106, 0.092, 0.088]}
full_acc: {'epoch': 3, 'lr': 1e-05, 'acc': 0.96, 'asr': 1.0, 'cls_pred': [0.114, 0.092, 0.11, 0.13, 0.098, 0.066, 0.102, 0.108, 0.092, 0.088]}
full_acc: {'epoch': 4, 'lr': 1e-05, 'acc': 0.96, 'asr': 1.0, 'cls_pred': [0.124, 0.092, 0.118, 0.126, 0.092, 0.07, 0.102, 0.106, 0.086, 0.084]}
full_acc: {'epoch': 5, 'lr': 1e-05, 'acc': 0.96, 'asr': 1.0, 'cls_pred': [0.114, 0.094, 0.12, 0.136, 0.092, 0.068, 0.096, 0.106, 0.088, 0.086]}
full_acc: {'epoch': 6, 'lr': 1e-05, 'acc': 0.94, 'asr': 1.0, 'cls_pred': [0.124, 0.094, 0.124, 0.134, 0.092, 0.062, 0.098, 0.106, 0.088, 0.0