In [1]:
%load_ext autoreload
%autoreload 2

import numpy as np
import os
import argparse
import time
import torch
import torch.nn as nn
import torch.backends.cudnn as cudnn
import torchvision.transforms as trn
import torchvision.datasets as dset
import torch.nn.functional as F
from tqdm.notebook import tqdm
from utils.wrn import WideResNet
import utils.attacks as attacks
from utils.detector import ScoreDetector
from tqdm.notebook import tqdm

## Args

In [2]:
parser = argparse.ArgumentParser(description='Trains a CIFAR Classifier',
                                 formatter_class=argparse.ArgumentDefaultsHelpFormatter)
parser.add_argument('--dataset', '-d', type=str, default='cifar10', choices=['cifar10', 'cifar100'],
                    help='Choose between CIFAR-10, CIFAR-100.')
parser.add_argument('--model', '-m', type=str, default='wrn',
                    choices=['allconv', 'wrn'], help='Choose architecture.')
# Optimization options
parser.add_argument('--epochs', '-e', type=int, default=100, help='Number of epochs to train.')
parser.add_argument('--learning_rate', '-lr', type=float, default=0.1, help='The initial learning rate.')
parser.add_argument('--batch_size', '-b', type=int, default=128, help='Batch size.')
parser.add_argument('--test_bs', type=int, default=256)
parser.add_argument('--momentum', type=float, default=0.9, help='Momentum.')
parser.add_argument('--decay', type=float, default=0.0005, help='Weight decay (L2 penalty).')
# WRN Architecture
parser.add_argument('--layers', default=40, type=int, help='total number of layers')
parser.add_argument('--widen-factor', default=2, type=int, help='widen factor')
parser.add_argument('--droprate', default=0.0, type=float, help='dropout probability')
# Checkpoints
parser.add_argument('--save', '-s', type=str, default='./snapshots/baseline', help='Folder to save checkpoints.')
parser.add_argument('--load', '-l', type=str, default='', help='Checkpoint path to resume / test.')
parser.add_argument('--test', '-t', action='store_true', help='Test only flag.')
# Acceleration
parser.add_argument('--ngpu', type=int, default=1, help='0 = CPU.')
parser.add_argument('--gpu', type=int, default=1, help='0 = CPU.')
parser.add_argument('--prefetch', type=int, default=2, help='Pre-fetching threads.')
args = parser.parse_args(["--save", "checkpoints_score_v2/", "--gpu", "3", "-b", "128", "--test_bs", "256"])

## Initialization

In [3]:
state = {k: v for k, v in args._get_kwargs()}
print(state)

torch.manual_seed(1)
np.random.seed(1)

# # mean and standard deviation of channels of CIFAR-10 images
# mean = [x / 255 for x in [125.3, 123.0, 113.9]]
# std = [x / 255 for x in [63.0, 62.1, 66.7]]

train_transform = trn.Compose([trn.RandomHorizontalFlip(), trn.RandomCrop(32, padding=4),
                               trn.ToTensor()])
test_transform = trn.Compose([trn.ToTensor()])

train_data = dset.CIFAR10('~/datasets/cifarpy', train=True, transform=train_transform, download=True)
test_data = dset.CIFAR10('~/datasets/cifarpy', train=False, transform=test_transform, download=True)
num_classes = 10

train_loader = torch.utils.data.DataLoader(
    train_data, batch_size=args.batch_size, shuffle=True,
    num_workers=args.prefetch, pin_memory=True)
test_loader = torch.utils.data.DataLoader(
    test_data, batch_size=args.test_bs, shuffle=False,
    num_workers=args.prefetch, pin_memory=True)



normalize = trn.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
detector_data_transform = trn.Compose([trn.ToTensor(), normalize])

data_train = list(torch.utils.data.DataLoader(
        dset.CIFAR10('~/datasets/cifarpy', 
                     train=True, 
                     transform=detector_data_transform, 
                     download=True),
        batch_size=1, shuffle=False))

data_test = list(torch.utils.data.DataLoader(
        dset.CIFAR10('~/datasets/cifarpy', 
                     train=False, 
                     transform=detector_data_transform, 
                     download=True),
        batch_size=1, shuffle=False))


{'layers': 40, 'momentum': 0.9, 'load': '', 'test': False, 'test_bs': 256, 'save': 'checkpoints_score_v2/', 'droprate': 0.0, 'ngpu': 1, 'learning_rate': 0.1, 'batch_size': 128, 'gpu': 3, 'prefetch': 2, 'widen_factor': 2, 'epochs': 100, 'model': 'wrn', 'dataset': 'cifar10', 'decay': 0.0005}
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified
Files already downloaded and verified


In [4]:
# Create model
if args.model == 'allconv':
    net = AllConvNet(num_classes)
else:
    net = WideResNet(args.layers, num_classes, args.widen_factor, dropRate=args.droprate)
    net.load_state_dict(torch.load("benchmark_ckpts/cifar10_reg_training_99.pt"))

start_epoch = 0

# Restore model if desired
if args.load != '':
    for i in range(1000 - 1, -1, -1):
        model_name = os.path.join(args.load, args.dataset + '_' + args.model +
                                  '_baseline_epoch_' + str(i) + '.pt')
        if os.path.isfile(model_name):
            net.load_state_dict(torch.load(model_name))
            print('Model restored! Epoch:', i)
            start_epoch = i + 1
            break

    if start_epoch == 0:
        assert False, "could not resume"

if args.ngpu > 1:
    net = torch.nn.DataParallel(net, device_ids=list(range(args.ngpu)))

if args.ngpu > 0:
    torch.cuda.set_device(args.gpu)
    net.cuda()
    torch.cuda.manual_seed(1)

# cudnn.benchmark = True  # fire on all cylinders

optimizer = torch.optim.SGD(
    net.parameters(), state['learning_rate'], momentum=state['momentum'],
    weight_decay=state['decay'], nesterov=True)

def cosine_annealing(step, total_steps, lr_max, lr_min):
    return lr_min + (lr_max - lr_min) * 0.5 * (
            1 + np.cos(step / total_steps * np.pi))


scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=lambda step: cosine_annealing(
        step,
        args.epochs * len(train_loader),
        1,  # since lr_lambda computes multiplicative factor
        1e-6 / args.learning_rate))

## Training

In [5]:
score_scale = 1000
max_score = 0.125
adversary = attacks.PGD_score(score_scale = score_scale,
                                epsilon=8./255, 
                                num_steps=10, 
                                step_size=2./255,
                                max_score=max_score,
                                verbose=False)

def train():
    net.train()
    loss_avg, loss_gram_avg = 0.0, 0.0
    i = 0
    
    dt = ScoreDetector()
    
    for bx, by in tqdm(train_loader):
        bx, by = bx.cuda(), by.cuda()
        
        net.eval()
        adv_bx = adversary(net, bx, by)
        net.train()
                        
        # forward
        logits_reg, feats_reg = net.gram_forward(bx * 2 - 1)
        logits_adv, feats_adv = net.gram_forward(adv_bx * 2 - 1)

        # backward
        optimizer.zero_grad()

        loss_reg = F.cross_entropy(logits_reg, by)
        loss_adv = F.cross_entropy(logits_adv, by)
        
        loss_gram = dt.score(feats_adv)
        loss_gram_reg = dt.score(feats_reg)
        
        margin = F.softplus(0.3 + loss_gram_reg - loss_gram, beta=100)
        
        loss = loss_reg + 10 * margin + F.softplus(loss_gram_reg - .125, beta=100)
#         print(loss_gram, 10 * margin, F.softplus(loss_gram_reg - .125, beta=100))
        loss.backward()
                
        optimizer.step()
        scheduler.step()
        
        i += 1
        
        # exponential moving average
        loss_avg = loss_avg * 0.8 + float(loss) * 0.2
        loss_gram_avg = loss_gram_avg * 0.8 + float(loss_gram) * 0.2
    
    state['train_loss'] = loss_avg
    state["gram_train_loss"] = loss_gram_avg
    
    print("Train Loss:", state["train_loss"])
    print("Train Gram: ", state["gram_train_loss"])

# test function
def test():
    net.eval()
        
    acc_reg, acc_adv, auroc = [], [], []
    
    dt = ScoreDetector()
    with torch.no_grad():
        for bx, by in tqdm(test_loader):
            bx, by = bx.cuda(), by.cuda()
            adv_bx = adversary(net, bx, by)

            # forward
            logits_reg, feats_reg = net.gram_forward(bx * 2 - 1)
            logits_adv, feats_adv = net.gram_forward(adv_bx * 2 - 1)
                        
            a = dt.calc_auroc(feats_reg, feats_adv)
            auroc.append(a)
                        
            acc_reg.append((by==torch.max(logits_reg,dim=1)[1]).cpu().numpy().mean())
            acc_adv.append((by==torch.max(logits_adv,dim=1)[1]).cpu().numpy().mean())
            
            
    state['test_accuracy'] = np.mean(acc_reg)
    state["adversarial_accuracy"] = np.mean(acc_adv)
    state['auroc'] = np.mean(auroc)

In [None]:
# Make save directory
if not os.path.exists(args.save):
    os.makedirs(args.save)
if not os.path.isdir(args.save):
    raise Exception('%s is not a dir' % args.save)

with open(os.path.join(args.save, args.dataset + '_' + args.model +
                                  '_baseline_training_results.csv'), 'w') as f:
    f.write('epoch,time(s),train_loss,test_loss,test_error(%),gram_auroc\n')

print('Beginning Training!\n')


# Main loop
for epoch in range(start_epoch, args.epochs):
    state['epoch'] = epoch

    begin_epoch = time.time()
    
    print("1. Training")
    if epoch != 0:
        train()
        
    if epoch % 1 == 0:
        net.eval()
        
        print("2. Testing")
        try:
            test()
        except Exception as e:
            print("Failed test")
            print(e)

        # Save model
        torch.save(net.state_dict(),
                   os.path.join(args.save, args.dataset + '_' + args.model +
                                '_baseline_epoch_' + str(epoch) + '.pt'))
        # Let us not waste space and delete the previous model
        prev_path = os.path.join(args.save, args.dataset + '_' + args.model +
                                 '_baseline_epoch_' + str(epoch - 3) + '.pt')
        if os.path.exists(prev_path): os.remove(prev_path)

        # Show results
        print('Epoch {0:3d} | Time {1:5d} | Adversarial Acc {2:.3f} | Test Error {3:.2f} | Auroc {4:.2f}'.format(
            (epoch + 1),
            int(time.time() - begin_epoch),
            100. * state['adversarial_accuracy'],
            100 - 100. * state['test_accuracy'],
            state["auroc"])
        )

Beginning Training!

1. Training
2. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch   1 | Time    75 | Adversarial Acc 1.982 | Test Error 8.85 | Auroc 0.35
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.8402351846031824
Train Gram:  0.8673814932214721
2. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch   2 | Time   526 | Adversarial Acc 23.975 | Test Error 76.63 | Auroc 0.01
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.1945355896217449
Train Gram:  0.8171536451112773
2. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch   3 | Time   524 | Adversarial Acc 3.789 | Test Error 51.77 | Auroc 0.01
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.0870154956604166
Train Gram:  0.6056019706991955
2. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch   4 | Time   524 | Adversarial Acc 7.871 | Test Error 52.28 | Auroc 0.01
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.9372988877929431
Train Gram:  0.6123679008877403
2. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch   5 | Time   524 | Adversarial Acc 49.189 | Test Error 36.04 | Auroc 0.17
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.0914226718626279
Train Gram:  2.788936739425285
2. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch   6 | Time   524 | Adversarial Acc 37.598 | Test Error 56.56 | Auroc 0.01
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.0835934189011605
Train Gram:  1.7986078349875758
2. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch   7 | Time   524 | Adversarial Acc 42.646 | Test Error 39.74 | Auroc 0.02
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.7617553325139671
Train Gram:  2.6577586311131096
2. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch   8 | Time   524 | Adversarial Acc 30.342 | Test Error 31.53 | Auroc 0.00
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.7228442171382742
Train Gram:  2.7636413404617364
2. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch   9 | Time   524 | Adversarial Acc 2.725 | Test Error 34.42 | Auroc 0.00
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 0.6271235805593255
Train Gram:  1.0036295578817391
2. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch  10 | Time   524 | Adversarial Acc 0.039 | Test Error 26.54 | Auroc 0.00
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.3344229587441139
Train Gram:  2.636668697859734
2. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch  11 | Time   524 | Adversarial Acc 16.826 | Test Error 35.55 | Auroc 0.06
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))


Train Loss: 1.8813555306243885
Train Gram:  1.0806115278809427
2. Testing


HBox(children=(FloatProgress(value=0.0, max=40.0), HTML(value='')))


Epoch  12 | Time   524 | Adversarial Acc 20.947 | Test Error 70.43 | Auroc 0.08
1. Training


HBox(children=(FloatProgress(value=0.0, max=391.0), HTML(value='')))