In [4]:
!pip install kronfluence
!pip install datasets
!pip install timm
!pip install torchmetrics



In [5]:
import torch
import os
import sys
import timm
import numpy as np
from os import makedirs
from os.path import exists
helper_dir = "/data/andy_lee/github/test/corrective-unlearning-bench/src"
sys.path.append(helper_dir)
from utils import seed_everything, SubsetSequentialSampler, get_targeted_classes
import resnet, methods

In [6]:
sys.path.append(helper_dir)
from dataset import load_dataset, DatasetWrapper, manip_dataset, get_deletion_set

In [7]:
# Select the 'file_system' tensor-sharing strategy for torch.multiprocessing.
torch.multiprocessing.set_sharing_strategy('file_system')

# Fix the random seed used throughout the notebook.
seed_everything(seed=3017)

# The rest of the notebook moves the model to the GPU, so stop early without one.
assert torch.cuda.is_available()

In [8]:
 # Main Arguments
# Dataset name. Choices listed by the original argument parser:
# 'CIFAR10', 'CIFAR100', 'PCAM', 'LFWPeople', 'CelebA', 'DermNet', 'Pneumonia'.
dataset = 'CIFAR10'
# Architecture name, resolved later via getattr(resnet, ...). Listed choices:
# 'resnet9', 'resnet20', 'resnet32', 'resnet44', 'resnet56', 'resnet110', 'resnetwide28x10', 'vitb16'.
# NOTE(review): the name `model` is rebound to the instantiated network in a later cell.
model ='resnet9'
# How the training set is manipulated: 'randomlabelswap', 'interclasslabelswap' or 'poisoning'.
dataset_method = 'poisoning'
# 'poisoning': all with trigger -> class=0
# Unlearning method, resolved later via getattr(methods, ...). Listed choices:
# 'Naive', 'EU', 'CF', 'Scrub', 'BadT', 'SSD', 'ActivationClustering', 'SpectralSignature',
# 'InfluenceFunction'; this notebook additionally dispatches on 'FlippingInfluence'.
unlearn_method = 'FlippingInfluence'
num_classes = 10 # number of classes; listed choices: 2, 10, 100
forget_set_size = 500 # number of samples to be manipulated
patch_size = 3 # poisoning patch is patch_size x patch_size, placed at the bottom-right corner of the image
deletion_size = None # number of samples to be deleted; None is replaced by forget_set_size in a later cell

# Optimizer Params
batch_size = 512 # batch size used by every DataLoader in this notebook
pretrain_iters = 7500 # number of pretraining iterations (the original note said "epochs"; the name and the step-based logs suggest steps — confirm)
pretrain_lr = 0.025 # learning rate for pretraining
wd = 0.0005 # presumably weight decay (the original note said "learning rate", which looks like a copy-paste slip) — confirm against `methods`
# unlearn
unlearn_iters = 1000 # number of unlearning iterations (the logged run stops at step 1001)
unlearn_lr =  0.025 # learning rate for unlearning

# Method-specific hyperparameters; their values appear in the experiment log-directory name.
# NOTE(review): their exact meaning is defined in `methods`, which is not shown here.
k = -1
kd_T = 4
alpha = 0.001
msteps = 400

# Defaults
data_dir = '/data/andy_lee/github/test/corrective-unlearning-bench/files/data/' # dataset root passed to load_dataset
# NOTE(review): `save_dir` is rebound to a .npy file path in the unlearning cell further down.
save_dir = '/data/andy_lee/github/test/corrective-unlearning-bench/files/logs/' # directory for logs and saved index files
exp_name = 'unlearn' # experiment name stored on the options object
device = 'cuda' # device string stored on the options object

In [9]:
class args_specify:
    """Plain attribute container standing in for the parsed command-line options.

    Every constructor argument is stored unchanged on the instance under the
    same name. Later cells attach further attributes (for example ``max_lr``
    and ``train_iters``) to the same object.

    Args:
        dataset, model, dataset_method, unlearn_method: experiment selectors.
        num_classes, forget_set_size, patch_size, deletion_size: data settings.
        batch_size, pretrain_iters, pretrain_lr, wd: pretraining settings.
        unlearn_iters, unlearn_lr: unlearning settings.
        data_dir, save_dir: filesystem locations.
        k, kd_T, alpha, msteps: method-specific hyperparameters.
        exp_name: experiment name.
        device: device string; defaults to "cuda" when CUDA is available at
            class-definition time, otherwise "cpu".
    """

    def __init__(
        self,
        dataset, model, dataset_method, unlearn_method,
        num_classes, forget_set_size, patch_size, deletion_size,
        batch_size, pretrain_iters, pretrain_lr, wd,
        unlearn_iters, unlearn_lr,
        data_dir, save_dir,
        k, kd_T, alpha, msteps,
        exp_name,
        device="cuda" if torch.cuda.is_available() else "cpu",
    ):
        # Experiment selectors.
        self.dataset, self.model = dataset, model
        self.dataset_method, self.unlearn_method = dataset_method, unlearn_method

        # Data manipulation / deletion settings.
        self.num_classes, self.forget_set_size = num_classes, forget_set_size
        self.patch_size, self.deletion_size = patch_size, deletion_size

        # Optimisation settings for the pretraining and unlearning stages.
        self.batch_size = batch_size
        self.pretrain_iters, self.pretrain_lr = pretrain_iters, pretrain_lr
        self.wd = wd
        self.unlearn_iters, self.unlearn_lr = unlearn_iters, unlearn_lr

        # Filesystem locations and experiment name.
        self.data_dir, self.save_dir = data_dir, save_dir
        self.exp_name = exp_name

        # Method-specific hyperparameters.
        self.k, self.kd_T = k, kd_T
        self.alpha, self.msteps = alpha, msteps

        self.device = device
# Bundle the notebook-level settings into a single options object.
# Keyword arguments keep each value tied to its parameter name.
opt = args_specify(
    dataset=dataset,
    model=model,
    dataset_method=dataset_method,
    unlearn_method=unlearn_method,
    num_classes=num_classes,
    forget_set_size=forget_set_size,
    patch_size=patch_size,
    deletion_size=deletion_size,
    batch_size=batch_size,
    pretrain_iters=pretrain_iters,
    pretrain_lr=pretrain_lr,
    wd=wd,
    unlearn_iters=unlearn_iters,
    unlearn_lr=unlearn_lr,
    data_dir=data_dir,
    save_dir=save_dir,
    k=k,
    kd_T=kd_T,
    alpha=alpha,
    msteps=msteps,
    exp_name=exp_name,
    device=device,
)

In [10]:
# Sanity check: display the unlearning learning rate stored on `opt`.
opt.unlearn_lr

0.025

In [11]:
# Sanity check: display the architecture name, manipulation type and manipulated-set size.
opt.model, opt.dataset_method, opt.forget_set_size

('resnet9', 'poisoning', 500)

In [12]:
# Build the network and place it on the configured device (rather than a
# hard-coded `.cuda()`), matching how the pretrained-checkpoint branch later
# moves the model with `model.to(opt.device)`.
# NOTE(review): this rebinds the notebook-level name `model` from the
# architecture string to the network object; re-running the cell that builds
# `opt` after this one would pass the network instead of the string.
model = getattr(resnet, opt.model)(opt.num_classes).to(opt.device)

# Load the clean datasets together with the training labels and `max_val`
# (used below as the pixel value of the poisoning patch).
train_set, train_noaug_set, test_set, train_labels, max_val = load_dataset(
    dataset=opt.dataset,
    root=opt.data_dir,
)
test_loader = torch.utils.data.DataLoader(
    test_set,
    batch_size=opt.batch_size,
    shuffle=False,
    num_workers=4,
    pin_memory=True,
)

# Select which training samples are manipulated (`manip_idx`) and which are
# left untouched (`untouched_idx`); `manip_dict` is handed to DatasetWrapper.
manip_dict, manip_idx, untouched_idx = manip_dataset(
    dataset=opt.dataset,
    train_labels=train_labels,
    method=opt.dataset_method,
    manip_set_size=opt.forget_set_size,
    save_dir=opt.save_dir,
)
print('==> Loaded the dataset! (clean)')

Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
==> Loaded the dataset! (clean)


In [13]:
# DataLoader settings shared by every evaluation loader in this cell.
eval_loader_kwargs = dict(batch_size=opt.batch_size, num_workers=4, pin_memory=True)

# Un-augmented training set wrapped in 'test' mode.
wtrain_noaug_cleanL_set = DatasetWrapper(train_noaug_set, manip_dict, mode='test')
train_test_loader = torch.utils.data.DataLoader(wtrain_noaug_cleanL_set, shuffle=False, **eval_loader_kwargs)

eval_loaders = {}
if opt.dataset_method == 'poisoning':
    # The poisoning patch uses the dataset's max value and the configured patch size.
    corrupt_val = np.array(max_val)
    corrupt_size = opt.patch_size

    # Train and test sets wrapped in 'test_adversarial' mode.
    wtrain_noaug_adv_cleanL_set = DatasetWrapper(train_noaug_set, manip_dict, mode='test_adversarial', corrupt_val=corrupt_val, corrupt_size=corrupt_size)
    adversarial_train_loader = torch.utils.data.DataLoader(wtrain_noaug_adv_cleanL_set, shuffle=True, **eval_loader_kwargs)
    wtest_adv_cleanL_set = DatasetWrapper(test_set, manip_dict, mode='test_adversarial', corrupt_val=corrupt_val, corrupt_size=corrupt_size)
    adversarial_test_loader = torch.utils.data.DataLoader(wtest_adv_cleanL_set, shuffle=True, **eval_loader_kwargs)
    eval_loaders['adv_test'] = adversarial_test_loader

    # With poisoning, the subset loaders below read from the adversarial wrapper.
    subset_eval_set = wtrain_noaug_adv_cleanL_set
else:
    adversarial_train_loader, adversarial_test_loader, corrupt_val, corrupt_size = None, None, None, None
    subset_eval_set = wtrain_noaug_cleanL_set

# Build each subset loader exactly once, on the dataset chosen above (previously
# both were built on the clean wrapper and then rebuilt when poisoning).
untouched_noaug_cleanL_loader = torch.utils.data.DataLoader(subset_eval_set, shuffle=False, sampler=SubsetSequentialSampler(untouched_idx), **eval_loader_kwargs)
manip_noaug_cleanL_loader = torch.utils.data.DataLoader(subset_eval_set, shuffle=False, sampler=SubsetSequentialSampler(manip_idx), **eval_loader_kwargs)

In [14]:
# Always evaluate on the manipulated training samples.
eval_loaders['manip'] = manip_noaug_cleanL_loader

if opt.dataset_method == 'interclasslabelswap':
    # Collect the test-set positions whose label is one of the two targeted classes.
    classes = get_targeted_classes(opt.dataset)
    indices = []
    for batch_idx, (inputs, target) in enumerate(test_loader):
        in_targeted_classes = (target == classes[0]) | (target == classes[1])
        # Convert within-batch positions to dataset positions; this relies on
        # `test_loader` iterating in order (it is created with shuffle=False).
        batch_offset = batch_idx * test_loader.batch_size
        positions = batch_offset + torch.where(in_targeted_classes)[0]
        indices.extend(positions.tolist())
    eval_loaders['unseen_forget'] = torch.utils.data.DataLoader(
        test_set,
        batch_size=opt.batch_size,
        shuffle=False,
        sampler=SubsetSequentialSampler(indices),
        num_workers=4,
        pin_memory=True,
    )

# Training set wrapped in 'pretrain' mode, shuffled for the pretraining stage.
wtrain_manip_set = DatasetWrapper(
    train_set,
    manip_dict,
    mode='pretrain',
    corrupt_val=corrupt_val,
    corrupt_size=corrupt_size,
)
pretrain_loader = torch.utils.data.DataLoader(
    wtrain_manip_set,
    batch_size=opt.batch_size,
    shuffle=True,
    num_workers=4,
    pin_memory=True,
)

Pretraining

In [15]:
# Stage 1: Pretraining
# Directory for this pretraining configuration; created if it does not exist yet.
opt.pretrain_file_prefix = opt.save_dir+'/'+opt.dataset+'_'+opt.model+'_'+opt.dataset_method+'_'+str(opt.forget_set_size)+'_'+str(opt.patch_size)+'_'+str(opt.pretrain_iters)+'_'+str(opt.pretrain_lr)
makedirs(opt.pretrain_file_prefix, exist_ok=True)

# Checkpoint that decides whether pretraining has to run again.
pretrained_model_path = opt.pretrain_file_prefix + '/Naive_pretrainmodel/model.pth'

if not exists(pretrained_model_path):
    opt.max_lr, opt.train_iters = opt.pretrain_lr, opt.pretrain_iters
    expname, unlearn_method = opt.exp_name, opt.unlearn_method

    # We now actually pretrain by calling unlearn(), misnomer
    opt.unlearn_method, opt.exp_name = 'Naive', 'pretrainmodel'
    try:
        method = getattr(methods, opt.unlearn_method)(opt=opt, model=model)
        method.unlearn(train_loader=pretrain_loader, test_loader=test_loader)
        method.compute_and_save_results(train_test_loader, test_loader, adversarial_train_loader, adversarial_test_loader)
    finally:
        # Restore the real experiment settings even when pretraining fails, so a
        # re-run does not continue with 'Naive' / 'pretrainmodel' left on `opt`.
        opt.exp_name, opt.unlearn_method = expname, unlearn_method
else:
    print('==> Loading the pretrained model!')
    # map_location loads the tensors onto the configured device regardless of
    # the device they were saved from.
    model.load_state_dict(torch.load(pretrained_model_path, map_location=opt.device))
    model.to(opt.device)
    print('==> Loaded the pretrained model!')

==> Loading the pretrained model!
==> Loaded the pretrained model!


In [16]:
#forget set
# When no deletion size is configured, fall back to the manipulated-set size.
if opt.deletion_size is None:
    opt.deletion_size = opt.forget_set_size

# Split the training indices into the samples to forget and the samples to retain.
forget_idx, retain_idx = get_deletion_set(opt.deletion_size, manip_dict, train_size=len(train_labels), dataset=opt.dataset, method=opt.dataset_method, save_dir=opt.save_dir)

# From here on the optimiser settings on `opt` are those of the unlearning stage.
opt.max_lr, opt.train_iters = opt.unlearn_lr, opt.unlearn_iters

if opt.deletion_size != len(manip_dict):
    # Only part of the manipulated set is deleted, so evaluate on that subset too.
    # Pick the source dataset first so the loader is built once (previously it was
    # built on the clean wrapper and then rebuilt when poisoning).
    delete_eval_set = wtrain_noaug_adv_cleanL_set if opt.dataset_method == 'poisoning' else wtrain_noaug_cleanL_set
    delete_noaug_cleanL_loader = torch.utils.data.DataLoader(delete_eval_set, batch_size=opt.batch_size, shuffle=False, sampler=SubsetSequentialSampler(forget_idx), num_workers=4, pin_memory=True)
    eval_loaders['delete'] = delete_noaug_cleanL_loader

In [17]:
from torch.utils.data.sampler import SubsetRandomSampler

In [18]:
# Sanity check: display the unlearning settings, including `max_lr` as set in the deletion-set cell.
opt.unlearn_lr, opt.unlearn_iters, opt.unlearn_method, opt.max_lr

(0.025, 1000, 'FlippingInfluence', 0.025)

In [19]:
# Stage 2: Unlearning
# 'EU' and 'CF' are both served by the ApplyK class; every other method is
# looked up in `methods` under its own name.
method_class_name = 'ApplyK' if opt.unlearn_method in ['EU', 'CF'] else opt.unlearn_method
method = getattr(methods, method_class_name)(opt=opt, model=model)

# Training set wrapped in 'pretrain' mode, with the forget indices passed as `delete_idx`.
wtrain_delete_set = DatasetWrapper(
    train_set,
    manip_dict,
    mode='pretrain',
    corrupt_val=corrupt_val,
    corrupt_size=corrupt_size,
    delete_idx=forget_idx,
)

# Get the dataloaders
unlearn_loader_kwargs = dict(batch_size=opt.batch_size, num_workers=4, pin_memory=True)
# Retained samples only, drawn in random order.
retain_loader = torch.utils.data.DataLoader(wtrain_delete_set, shuffle=False, sampler=SubsetRandomSampler(retain_idx), **unlearn_loader_kwargs)
# Whole wrapped training set, shuffled and in original order respectively.
train_loader = torch.utils.data.DataLoader(wtrain_delete_set, shuffle=True, **unlearn_loader_kwargs)
train_loader_no_shuffle = torch.utils.data.DataLoader(wtrain_delete_set, shuffle=False, **unlearn_loader_kwargs)
# Forget samples only, drawn in random order.
forget_loader = torch.utils.data.DataLoader(wtrain_delete_set, shuffle=False, sampler=SubsetRandomSampler(forget_idx), **unlearn_loader_kwargs)

In [20]:
# all poisons
print(f"There are in total {len(forget_idx)} poisons")
# assume we know 10% of these poisons
# deletion set
known_percentage = 0.1
print(f"Randomly select {known_percentage * 100}% of them to form the deletion set...")
# Tensor.numpy() returns an array that shares memory with the tensor, so shuffle
# a copy: shuffling in place would also reorder `forget_idx` itself, which the
# forget loader's sampler and the dataset wrapper already hold.
forget_idx_np = forget_idx.numpy().copy()
np.random.shuffle(forget_idx_np)
# The first `known_percentage` of the shuffled indices are the known poisons.
delete_idx = torch.tensor(forget_idx_np[:int(known_percentage * len(forget_idx_np))])
delete_loader = torch.utils.data.DataLoader(wtrain_delete_set, batch_size=opt.batch_size, shuffle=False, sampler=SubsetRandomSampler(delete_idx), num_workers=4, pin_memory=True)

There are in total 500 poisons
Randomly select 10.0% of them to form the deletion set...


In [21]:
# start detection & unlearning
# Each family of methods takes a different combination of loaders.
if opt.unlearn_method in ('Naive', 'EU', 'CF', 'SpectralSignature'):
    method.unlearn(train_loader=retain_loader, test_loader=test_loader, eval_loaders=eval_loaders)
elif opt.unlearn_method == 'BadT':
    method.unlearn(train_loader=train_loader, test_loader=test_loader, eval_loaders=eval_loaders)
elif opt.unlearn_method in ('Scrub', 'SSD', 'ActivationClustering'):
    method.unlearn(train_loader=retain_loader, test_loader=test_loader, forget_loader=forget_loader, eval_loaders=eval_loaders)
elif opt.unlearn_method == 'InfluenceFunction':
    method.unlearn(train_loader=train_loader, test_loader=test_loader)
elif opt.unlearn_method == 'FlippingInfluence':
    # save detected indices
    # NOTE(review): this rebinds the notebook-level `save_dir` (the log directory
    # in the config cell) to a .npy file path; the later np.load cell reads this
    # name, so the rebinding is kept here.
    save_dir = '/data/andy_lee/github/test/corrective-unlearning-bench/files/poison_indices.npy'
    n_tolerate = 2
    method.unlearn(n_tolerate=n_tolerate, train_loader=train_loader_no_shuffle, test_loader=test_loader, deletion_loader=delete_loader, save_dir=save_dir) # no shuffle
else:
    # Without this branch an unlisted method would skip unlearning silently and
    # the results below would be reported for a model that was never unlearned.
    raise ValueError(f"Unsupported unlearn_method: {opt.unlearn_method!r}")

method.compute_and_save_results(train_test_loader, test_loader, adversarial_train_loader, adversarial_test_loader)
print('==> Experiment completed! Exiting..')

Fitting covariance matrices [1000/1000] 100%|██████████ [time left: 00:00, time spent: 00:25]
Performing Eigendecomposition [9/9] 100%|██████████ [time left: 00:00, time spent: 00:00]
Fitting Lambda matrices [1000/1000] 100%|██████████ [time left: 00:00, time spent: 00:40]
Computing pairwise scores (training gradient) [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]
Computing pairwise scores (query gradient) [1/1] 100%|██████████ [time left: 00:00, time spent: 00:01]
Computing pairwise scores (training gradient) [13/13] 100%|██████████ [time left: 00:00, time spent: 00:09]
Computing pairwise scores (query gradient) [1/1] 100%|██████████ [time left: 00:00, time spent: 00:09]
Computing pairwise scores (training gradient) [1/1] 100%|██████████ [time left: 00:00, time spent: 00:00]
Computing pairwise scores (query gradient) [1/1] 100%|██████████ [time left: 00:00, time spent: 00:01]
Computing pairwise scores (training gradient) [13/13] 100%|██████████ [time left: 00:00, time spe

remove samples: (533)
45059
6
12294
4106
34827
45067
49176
32
38
26662
6185
49194
4147
38964
57
12353
20547
10321
45146
36955
4190
45165
14456
28792
45179
22652
41085
10370
4229
26757
47244
6292
22676
4246
2199
155
18596
22693
45224
49322
4268
173
14511
4273
14517
6326
14519
22711
47294
6337
12481
39108
30920
43209
28877
10446
45261
6353
4306
16600
35035
220
8413
24798
49374
32997
14566
10472
22761
49387
47340
8431
24817
45298
33014
24828
41213
35072
43274
45322
41229
20750
271
22806
12579
41258
47405
39219
10550
8503
28983
12611
14659
33094
37191
39245
49485
6489
47449
6497
39275
39283
45429
33143
383
2437
29065
22923
33181
6560
14754
37284
6574
16818
25012
22965
29110
16823
27064
33206
449
18883
47555
22981
49610
35282
10707
39382
14811
20958
23014
47591
27113
39402
31211
29167
2558
8721
12820
45590
537
18976
45602
4643
39461
10793
10797
2607
25135
8757
29238
33333
6712
47674
47677
21062
19024
16979
8788
43609
12891
31325
39526
43624
41577
33387
17004
33396
47743
14976
6787
33415
293

100%|██████████| 97/97 [00:01<00:00, 50.16it/s]


Step: 97 Train Top1: 0.992


100%|██████████| 20/20 [00:00<00:00, 29.11it/s]


Step: 97 Val Top1: 85.57


100%|██████████| 97/97 [00:01<00:00, 56.25it/s]


Step: 194 Train Top1: 0.962


100%|██████████| 20/20 [00:00<00:00, 30.53it/s]


Step: 194 Val Top1: 87.57


100%|██████████| 97/97 [00:01<00:00, 55.90it/s]


Step: 291 Train Top1: 0.987


100%|██████████| 20/20 [00:00<00:00, 29.76it/s]


Step: 291 Val Top1: 89.98


100%|██████████| 97/97 [00:01<00:00, 56.14it/s]


Step: 388 Train Top1: 0.999


100%|██████████| 20/20 [00:00<00:00, 29.35it/s]


Step: 388 Val Top1: 91.25


100%|██████████| 97/97 [00:01<00:00, 56.38it/s]


Step: 485 Train Top1: 1.000


100%|██████████| 20/20 [00:00<00:00, 29.71it/s]


Step: 485 Val Top1: 91.68


100%|██████████| 97/97 [00:01<00:00, 54.86it/s]


Step: 582 Train Top1: 1.000


100%|██████████| 20/20 [00:00<00:00, 29.50it/s]


Step: 582 Val Top1: 91.71


100%|██████████| 97/97 [00:01<00:00, 56.05it/s]


Step: 679 Train Top1: 1.000


100%|██████████| 20/20 [00:00<00:00, 29.35it/s]


Step: 679 Val Top1: 91.76


100%|██████████| 97/97 [00:01<00:00, 56.42it/s]


Step: 776 Train Top1: 1.000


100%|██████████| 20/20 [00:00<00:00, 29.74it/s]


Step: 776 Val Top1: 91.69


100%|██████████| 97/97 [00:01<00:00, 55.29it/s]


Step: 873 Train Top1: 1.000


100%|██████████| 20/20 [00:00<00:00, 28.41it/s]


Step: 873 Val Top1: 91.80


100%|██████████| 97/97 [00:01<00:00, 54.96it/s]


Step: 970 Train Top1: 1.000


100%|██████████| 20/20 [00:00<00:00, 27.97it/s]


Step: 970 Val Top1: 91.74


 31%|███       | 30/97 [00:00<00:01, 52.58it/s]


Step: 1001 Train Top1: 1.000


100%|██████████| 20/20 [00:00<00:00, 27.81it/s]


Step: 1001 Val Top1: 91.72
/data/andy_lee/github/test/corrective-unlearning-bench/files/logs//CIFAR10_resnet9_poisoning_500_3_7500_0.025/500_FlippingInfluence_unlearn_1000_-1_4_0.001_400
==> Completed! Unlearning Time: [0.000]	


100%|██████████| 98/98 [00:02<00:00, 35.48it/s]
100%|██████████| 20/20 [00:00<00:00, 21.32it/s]
100%|██████████| 98/98 [00:02<00:00, 33.12it/s]
100%|██████████| 20/20 [00:00<00:00, 20.31it/s]

==> Experiment completed! Exiting..





In [25]:
# load detected indices
# NOTE(review): `save_dir` here is the .npy path assigned in the unlearning cell,
# not the log directory from the config cell.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects; only
# load trusted files. It may be unnecessary if the array holds plain integers —
# confirm how the file was saved.
detected_indices = np.load(save_dir, allow_pickle=True)
# Number of candidate indices that were saved.
detected_indices.shape

(533,)

In [26]:
# candidate poisons given by the algorithm
# Display the loaded candidate indices (the full array is printed).
detected_indices

array([    6,    32,    38,    57,   155,   173,   220,   271,   383,
         449,   537,   662,   743,   750,   784,   825,   915,   956,
         969,   975,   987,  1085,  1163,  1183,  1246,  1381,  1521,
        1570,  1604,  1713,  1714,  1790,  1841,  1903,  1936,  1959,
        1972,  2199,  2437,  2558,  2607,  2711,  2751,  2773,  3075,
        3080,  3136,  3225,  3372,  3470,  3482,  3511,  3635,  3788,
        3815,  3816,  3905,  3979,  4064,  4106,  4147,  4190,  4229,
        4246,  4268,  4273,  4306,  4643,  4950,  4953,  4959,  5038,
        5197,  5252,  5399,  5459,  5815,  5846,  5928,  5940,  6126,
        6185,  6292,  6326,  6337,  6353,  6489,  6497,  6560,  6574,
        6712,  6787,  6839,  7181,  7270,  7330,  7342,  7377,  7419,
        7654,  7687,  7705,  7728,  7774,  7795,  8004,  8144,  8162,
        8413,  8431,  8503,  8721,  8757,  8788,  8866,  9069,  9409,
        9485,  9593,  9598,  9630,  9760,  9961, 10008, 10076, 10115,
       10205, 10321,

In [30]:
# add those known poisons
# Everything slated for removal: the detected candidates plus the known poisons.
indices_to_be_removed = np.union1d(detected_indices, delete_idx)

# Manipulated samples that are NOT slated for removal (the misses).
missed_manip_idx = np.setdiff1d(manip_idx, indices_to_be_removed)

# true positives: manipulated samples that are slated for removal
true_positives_idx = np.setdiff1d(manip_idx, missed_manip_idx)
len(true_positives_idx), true_positives_idx

(491,
 array([    6,    32,    38,    57,   155,   173,   220,   271,   383,
          449,   537,   662,   743,   750,   783,   784,   825,   915,
          956,   969,   975,   987,  1085,  1163,  1183,  1246,  1521,
         1570,  1604,  1713,  1714,  1790,  1841,  1903,  1936,  1959,
         1972,  2199,  2437,  2558,  2607,  2711,  2751,  2773,  3080,
         3136,  3225,  3372,  3470,  3482,  3511,  3635,  3788,  3815,
         3816,  3905,  3979,  4064,  4106,  4190,  4229,  4246,  4268,
         4306,  4643,  4950,  4953,  4959,  5038,  5197,  5252,  5399,
         5459,  5815,  5846,  5928,  5940,  6126,  6185,  6292,  6326,
         6337,  6353,  6489,  6497,  6560,  6574,  6712,  6787,  6839,
         7181,  7330,  7342,  7419,  7654,  7705,  7728,  7774,  7795,
         8004,  8144,  8162,  8413,  8431,  8503,  8721,  8757,  8788,
         8866,  9409,  9485,  9593,  9598,  9630,  9760,  9961, 10076,
        10205, 10321, 10370, 10446, 10472, 10550, 10707, 10797, 10920,


In [31]:
# TPR
# True-positive rate: the share of manipulated samples that are slated for
# removal. The denominator is the manipulated set itself rather than a
# hard-coded 500, so the rate stays correct when forget_set_size changes.
len(true_positives_idx) / len(manip_idx)

0.982