In [1]:
import numpy as np
import os
import argparse
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as trn
import torchvision.datasets as dset
import torch.nn.functional as F
from tqdm.notebook import tqdm
from utils.wrn import WideResNet
import utils.attacks as attacks
from utils.detector import Detector, gram_margin_loss
from tqdm.notebook import tqdm

## Args

In [2]:
parser = argparse.ArgumentParser(description='Trains a CIFAR Classifier',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--dataset', '-d', type=str, default='cifar10', choices=['cifar10', 'cifar100'],
                    help='Choose between CIFAR-10, CIFAR-100.')
parser.add_argument('--model', '-m', type=str, default='wrn',
                    choices=['allconv', 'wrn'], help='Choose architecture.')
# Optimization options
parser.add_argument('--epochs', '-e', type=int, default=100, help='Number of epochs to train.')
parser.add_argument('--learning_rate', '-lr', type=float, default=0.1, help='The initial learning rate.')
parser.add_argument('--batch_size', '-b', type=int, default=128, help='Batch size.')
parser.add_argument('--test_bs', type=int, default=256)
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.')
parser.add_argument('--decay', type=float, default=0.0005, help='Weight decay (L2 penalty).')
# WRN Architecture
parser.add_argument('--layers', default=40, type=int, help='total number of layers')
parser.add_argument('--widen-factor', default=2, type=int, help='widen factor')
parser.add_argument('--droprate', default=0.0, type=float, help='dropout probability')
# Checkpoints
parser.add_argument('--save', '-s', type=str, default='./snapshots/baseline', help='Folder to save checkpoints.')
parser.add_argument('--load', '-l', type=str, default='', help='Checkpoint path to resume / test.')
parser.add_argument('--test', '-t', action='store_true', help='Test only flag.')
# Acceleration
parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.')
parser.add_argument('--gpu', type=int, default=1, help='0 = CPU.')
parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.')
args = parser.parse_args(["--save", "checkpoints_margin_01_v2/", "--gpu", "2", "-b", "128"])

## Initialization

In [3]:
state = {k: v for k, v in args._get_kwargs()}
print(state)

torch.manual_seed(1)
np.random.seed(1)

# # mean and standard deviation of channels of CIFAR-10 images
# mean = [x / 255 for x in [125.3, 123.0, 113.9]]
# std = [x / 255 for x in [63.0, 62.1, 66.7]]

train_transform = trn.Compose([trn.RandomHorizontalFlip(), trn.RandomCrop(32, padding=4),
                               trn.ToTensor()])
test_transform = trn.Compose([trn.ToTensor()])

train_data = dset.CIFAR10('~/datasets/cifarpy', train=True, transform=train_transform, download=True)
test_data = dset.CIFAR10('~/datasets/cifarpy', train=False, transform=test_transform, download=True)
num_classes = 10

train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size, shuffle=True,
    num_workers=args.prefetch, pin_memory=True)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=args.test_bs, shuffle=False,
    num_workers=args.prefetch, pin_memory=True)



normalize = trn.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
detector_data_transform = trn.Compose([trn.ToTensor(), normalize])

data_train = list(torch.utils.data.DataLoader(
        dset.CIFAR10('~/datasets/cifarpy', 
                     train=True, 
                     transform=detector_data_transform, 
                     download=True),
        batch_size=1, shuffle=False))

data_test = list(torch.utils.data.DataLoader(
        dset.CIFAR10('~/datasets/cifarpy', 
                     train=False, 
                     transform=detector_data_transform, 
                     download=True),
        batch_size=1, shuffle=False))


{'gpu': 2, 'layers': 40, 'model': 'wrn', 'batch_size': 128, 'test': False, 'test_bs': 256, 'learning_rate': 0.1, 'dataset': 'cifar10', 'droprate': 0.0, 'decay': 0.0005, 'epochs': 100, 'save': 'checkpoints_margin_01_v2/', 'load': '', 'momentum': 0.9, 'ngpu': 1, 'widen_factor': 2, 'prefetch': 2}
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Create model
if args.model == 'allconv':
    net = AllConvNet(num_classes)
else:
    net = WideResNet(args.layers, num_classes, args.widen_factor, dropRate=args.droprate)
    net.load_state_dict(torch.load("benchmark_ckpts/cifar10_reg_training_99.pt"))

start_epoch = 0

# Restore model if desired
if args.load != '':
    for i in range(1000 - 1, -1, -1):
        model_name = os.path.join(args.load, args.dataset + '_' + args.model +
                                  '_baseline_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
            net.load_state_dict(torch.load(model_name))
            print('Model restored! Epoch:', i)
            start_epoch = i + 1
            break

    if start_epoch == 0:
        assert False, "could not resume"

if args.ngpu > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

if args.ngpu > 0:
    torch.cuda.set_device(args.gpu)
    net.cuda()
    torch.cuda.manual_seed(1)

# cudnn.benchmark = True  # fire on all cylinders

optimizer = torch.optim.SGD(
    net.parameters(), state['learning_rate'], momentum=state['momentum'],
    weight_decay=state['decay'], nesterov=True)

def cosine_annealing(step, total_steps, lr_max, lr_min):
    return lr_min + (lr_max - lr_min) * 0.5 * (
            1 + np.cos(step / total_steps * np.pi))


scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: cosine_annealing(
        step,
        args.epochs * len(train_loader),
        1,  # since lr_lambda computes multiplicative factor
        1e-6 / args.learning_rate))

## Training

In [5]:
margin = 10
margin_scale = .01
adversary = attacks.PGD_margin(epsilon=8./255, 
                               num_steps=10, 
                               step_size=2./255, 
                               margin=20,
                               margin_scale=1.0).cuda()

def train():
    net.train()
    loss_avg, loss_gram_avg = 0.0, 0.0
    i = 0
    
    for bx, by in tqdm(train_loader):
        bx, by = bx.cuda(), by.cuda()
                        
        # forward
        logits_reg, feats_reg = net.gram_forward(bx * 2 - 1)
        
        adv_bx = adversary(net, bx, by, feats_reg)
        logits_adv, feats_adv = net.gram_forward(adv_bx * 2 - 1)

        # backward
        optimizer.zero_grad()

        loss_reg = F.cross_entropy(logits_reg, by)
#         loss_adv = F.cross_entropy(logits_adv, by)
        
        loss_gram = margin_scale * gram_margin_loss(feats_reg, feats_adv, margin=margin).cuda()
        loss = loss_reg + loss_gram
                                                
        loss.backward()
        
        optimizer.step()
        scheduler.step()
        
        i += 1
        
        # exponential moving average
        loss_avg = loss_avg * 0.8 + float(loss) * 0.2
        loss_gram_avg = loss_gram_avg * 0.8 + float(loss_gram) * 0.2
    
    state['train_loss'] = loss_avg
    state["gram_train_loss"] = loss_gram_avg
    
    print("Train Loss:", state["train_loss"])
    print("Train Gram: ", state["gram_train_loss"])

# test function
def test(detector = None):
    net.eval()
        
    acc_reg, acc_adv, auroc, auroc_failed = [], [], [], []
    with torch.no_grad():
        for bx, by in tqdm(test_loader):
            bx, by = bx.cuda(), by.cuda()

            # forward
            logits_reg, feats_reg = net.gram_forward(bx * 2 - 1)
            
            adv_bx = adversary(net, bx, by, feats_reg)
            logits_adv, feats_adv = net.gram_forward(adv_bx * 2 - 1)
            
            logits_reg, logits_adv, by = logits_reg.cpu(), logits_adv.cpu(), by.cpu()
            
#             loss_reg = F.cross_entropy(logits_reg, by)
#             print(loss_reg)
            
            a, a_f = detector.compute_auroc_advs(logits_adv, feats_adv, by)
            auroc.append(a)
            auroc_failed.append(a_f)
            
            acc_reg.append((by==torch.max(logits_reg,dim=1)[1]).cpu().numpy().mean())
            acc_adv.append((by==torch.max(logits_adv,dim=1)[1]).cpu().numpy().mean())
            
    state['test_accuracy'] = np.mean(acc_reg)
    state["adversarial_accuracy"] = np.mean(acc_adv)
    state['auroc'] = np.mean(auroc)
    state["auroc_failed"] = np.mean(auroc_failed)

In [6]:
# Make save directory
if not os.path.exists(args.save):
    os.makedirs(args.save)
if not os.path.isdir(args.save):
    raise Exception('%s is not a dir' % args.save)

with open(os.path.join(args.save, args.dataset + '_' + args.model +
                                  '_baseline_training_results.csv'), 'w') as f:
    f.write('epoch,time(s),train_loss,test_loss,test_error(%),gram_auroc\n')

print('Beginning Training!\n')

# Main loop
for epoch in range(start_epoch, args.epochs):
    state['epoch'] = epoch

    begin_epoch = time.time()
    
    print("1. Training")
    if True or epoch != 0:
        train()
    print("2. Initializing Detector")
    
    if epoch % 3 == 0:
        net.eval()
        detector = Detector(net, data_train, data_test, args.test_bs, pbar=None)
        print("3. Testing")
        try:
            test(detector)
        except Exception as e:
            print("Failed test")
            print(e)

        # Save model
        torch.save(net.state_dict(),
                   os.path.join(args.save, args.dataset + '_' + args.model +
                                '_baseline_epoch_' + str(epoch) + '.pt'))
        # Let us not waste space and delete the previous model
        prev_path = os.path.join(args.save, args.dataset + '_' + args.model +
                                 '_baseline_epoch_' + str(epoch - 3) + '.pt')
        if os.path.exists(prev_path): os.remove(prev_path)

        # Show results

#         with open(os.path.join(args.save, args.dataset + '_' + args.model +
#                                           '_baseline_training_results.csv'), 'a') as f:
#             f.write('%03d,%05d,%0.6f,%0.5f,%0.2f,%0.2f\n' % (
#                 (epoch + 1),
#                 time.time() - begin_epoch,
#                 state['train_loss'],
#                 state['test_loss'],
#                 100 - 100. * state['test_accuracy'],
#                 state["gram_auroc"],
#             ))

        # # print state with rounded decimals
#         print({k: round(v, 4) if isinstance(v, float) else v for k, v in state.items()})

        print('Epoch {0:3d} | Time {1:5d} | Adversarial Acc {2:.3f} | Test Error {3:.2f} | Auroc {4:.2f} | Auroc Failed {5:.2f}'.format(
            (epoch + 1),
            int(time.time() - begin_epoch),
            100. * state['adversarial_accuracy'],
            100 - 100. * state['test_accuracy'],
            state["auroc"],
            state["auroc_failed"])
        )

Beginning Training!

1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.7913821676130637
Train Gram:  0.19707728419828713
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch   1 | Time   816 | Adversarial Acc 4.600 | Test Error 89.03 | Auroc 0.95 | Auroc Failed 0.79
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.6699727184642073
Train Gram:  7.909669701534643e-05
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.5535070585762809
Train Gram:  1.4907454774489498e-05
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.5314681018322593
Train Gram:  0.0007240688786418161
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch   4 | Time   852 | Adversarial Acc 46.768 | Test Error 33.52 | Auroc 0.95 | Auroc Failed 0.95
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.45203345042426774
Train Gram:  0.00021674332237194366
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




Train Loss: 0.4449444509893572
Train Gram:  0.00877116702957116
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.2057934906505317
Train Gram:  0.5918784362231386
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.44206507859257105
Train Gram:  8.161665327640239e-05
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch  16 | Time   851 | Adversarial Acc 37.490 | Test Error 46.72 | Auroc 0.91 | Auroc Failed 0.94
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




Train Loss: 0.40817718211110415
Train Gram:  0.0030432729486918916
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch  25 | Time   851 | Adversarial Acc 14.277 | Test Error 47.41 | Auroc 1.00 | Auroc Failed 0.99
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.38313715528701586
Train Gram:  2.093297599731018e-09
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.4429411008779613
Train Gram:  0.034749410188698614
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




Train Loss: 0.35460031468192005
Train Gram:  0.0025175668598697454
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.3579064772329712
Train Gram:  0.004429865706720382
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch  37 | Time   827 | Adversarial Acc 23.008 | Test Error 41.87 | Auroc 1.00 | Auroc Failed 0.98
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.323064867049347
Train Gram:  0.009610328691872227
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




Train Loss: 0.3206868879530559
Train Gram:  0.000154877720092354
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.330219128496827
Train Gram:  3.964499742631023e-06
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.2727501449003139
Train Gram:  1.8099934813767777e-09
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch  49 | Time   826 | Adversarial Acc 49.912 | Test Error 20.35 | Auroc 0.99 | Auroc Failed 0.97
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




Train Loss: 0.2188882919325303
Train Gram:  2.002447153483616e-05
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch  58 | Time   828 | Adversarial Acc 46.045 | Test Error 16.86 | Auroc 0.99 | Auroc Failed 0.97
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.28373283238867364
Train Gram:  0.00016574668918438088
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.29765844016547033
Train Gram:  0.00622690873988721
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.2637848323624955
Train Gram:  0.0022536467280243884
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch  61 | Time   827 | Adversarial Acc 17.510 | Test Error 28.94 | Auroc 1.00 | Auroc Failed 0.99
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

IOPub message rate exceeded.
The notebook server will temporarily stop sending output
to the client in order to avoid crashing it.
To change this limit, set the config variable
`--NotebookApp.iopub_msg_rate_limit`.

Current values:
NotebookApp.iopub_msg_rate_limit=1000.0 (msgs/sec)
NotebookApp.rate_limit_window=3.0 (secs)




Train Loss: 0.14895798723382056
Train Gram:  0.00023316907033347567
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.1752824242308163
Train Gram:  7.277001096199719e-07
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch  70 | Time   830 | Adversarial Acc 37.881 | Test Error 18.25 | Auroc 0.96 | Auroc Failed 0.97
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.1397221987722918
Train Gram:  1.102895119848606e-05
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.1383455114977576
Train Gram:  1.8495879454453853e-05
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.1290866377842466
Train Gram:  4.775666526220871e-06
2. Initializing Detector
3. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch  73 | Time   830 | Adversarial Acc 46.084 | Test Error 20.23 | Auroc 0.98 | Auroc Failed 0.99
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.11551447079425942
Train Gram:  2.5700277017296866e-07
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.11126085108826844
Train Gram:  1.9456966026875922e-05
2. Initializing Detector
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))

KeyboardInterrupt: 