In [1]:
import os
# Expose only GPU 0 to CUDA. This is set here, before TensorFlow is imported in the next cell.
os.environ["CUDA_VISIBLE_DEVICES"]="0"

In [2]:
import pickle
import time
import tensorflow as tf
import numpy as np

from dl_spectral_normalization import dl_utils
from dl_spectral_normalization import adversarial as ad

%matplotlib inline

# Load dataset

We provide the code for downloading and loading one of three types of datasets:
- CIFAR10
- MNIST
- SVHN

In [3]:
# CIFAR10: load the data, then split the first 500 test samples off as a validation set
from get_cifar10 import get_cifar10_dataset
Xtr, Ytr, Xtt, Ytt = get_cifar10_dataset(0, n_samps=50000)
val_set = dict(X=Xtt[:500], Y=Ytt[:500])
Xtt = Xtt[500:]
Ytt = Ytt[500:]

In [None]:
# MNIST: load the data, then split the first 500 test samples off as a validation set
# NOTE: MNIST images are single-channel, so many dl_spectral_normalization
# functions need num_channels=1 when you use this dataset
from tensorflow.examples.tutorials.mnist import input_data
mnist = input_data.read_data_sets("MNIST_data/", one_hot=False)

# Images are reshaped to channels-last 28x28x1; labels are cast to float
Xtr = mnist.train.images.reshape(-1, 28, 28, 1)
Xtt = mnist.test.images.reshape(-1, 28, 28, 1)
Ytr = mnist.train.labels.astype(float)
Ytt = mnist.test.labels.astype(float)

val_set = dict(X=Xtt[:500], Y=Ytt[:500])
Xtt = Xtt[500:]
Ytt = Ytt[500:]

In [None]:
# SVHN: load the data, then split the first 500 test samples off as a validation set
from get_cifar10 import get_svhn_dataset
Xtr, Ytr, Xtt, Ytt = get_svhn_dataset(0)
val_set = dict(X=Xtt[:500], Y=Ytt[:500])
Xtt = Xtt[500:]
Ytt = Ytt[500:]

# Select network

Please see `dl_spectral_normalization/models` for the full list of provided models. The example below selects one of the networks trained in the paper.

In [4]:
# The AlexNet model module, aliased to `model`
from dl_spectral_normalization.models import alexnet as model
# Architecture passed as `arch` to the training and evaluation helpers in this notebook
arch = model.alexnet_sn

# Train network

In [5]:
def train_network(Xtr, Ytr, val_set, arch, save_dir, 
                  beta=1,
                  adv=None, order=2, eps=0.3,
                  opt='momentum', lr_initial=0.01,
                  num_epochs=200, save_every=25,
                  gpu_prop=0.3, retrain=False):
    """
    Wrapper for training a network using dl_spectral_normalization

    If save_dir already exists, training is skipped unless retrain=True,
    in which case the existing directory is deleted first.

    Inputs
    ----------------------------------------------------------------------
        Xtr:           Training data (channels last format); the number of
                       channels is read from Xtr.shape[-1]
        Ytr:           Training labels (not one-hot encoded)
        val_set:       Dict with keys 'X' and 'Y' for the validation
                       data and labels
        arch:          One of the architectures from
                       dl_spectral_normalization/models
        save_dir:      Directory to save weights and TensorBoard logs
        beta:          Amount of spectral normalization (max spectral 
                       norm of a layer). Use beta=np.inf for no
                       spectral normalization
        adv:           Adversarial training scheme: 'erm', 'fgm', 
                       'pgm', or 'wrm'
        order:         Order of attack (np.inf, 1, or 2) for FGM or PGM
        eps:           Magnitude of attack during training
        opt:           Optimizer type ('adam' or 'momentum')
        lr_initial:    Initial learning rate
        num_epochs:    Number of epochs to train for
        save_every:    Save weights every this many epochs
        gpu_prop:      Proportion of GPU to allocate for training process
        retrain:       Whether or not to delete the existing weights and
                       retrain the network

    Returns
    ----------------------------------------------------------------------
        None (the result of the underlying training call is discarded)

    Raises
    ----------------------------------------------------------------------
        OSError if retrain=True and the existing save_dir cannot be removed
    """
    # Local import so this cell does not depend on edits to the import cell
    import shutil

    if os.path.isdir(save_dir):
        if not retrain:
            # Weights already exist and the caller did not ask to retrain
            return
        # Remove the old run without going through a shell: a path with
        # spaces or shell metacharacters must not be split or executed.
        if os.path.islink(save_dir):
            # Drop only the link, never the directory it points to
            os.unlink(save_dir)
        else:
            shutil.rmtree(save_dir)

    print('eps = %.4f, saving weights to %s'%(eps, save_dir))
    dl_utils.build_graph_and_train(Xtr, Ytr, save_dir, arch,
                                   val_set=val_set,
                                   num_channels=Xtr.shape[-1],
                                   beta=beta,
                                   adv=adv, order=order, eps=eps,
                                   opt=opt, lr_initial=lr_initial,
                                   num_epochs=num_epochs, save_every=save_every,
                                   gpu_prop=gpu_prop, 
                                   batch_size=128,
                                   early_stop_acc=0.999,
                                   early_stop_acc_num=5)

In [6]:
# Directory in which we save weights
# NOTE(review): absolute, machine-specific path that also names the dataset and
# architecture; update it when switching either one.
dirname = '/data/save_weights_tf1.10.1/cifar10/alexnet/'

# List of betas to sweep through (np.inf means no spectral normalization)
beta_list = np.array([np.inf, 1.0, 1.3, 1.6, 2.0, 4.0])

# Specify the amount of perturbation to use during training
# C2 is the mean, over training samples, of each sample's L2 norm
# (square root of the sum of its squared entries).
C2 = np.mean([np.sqrt(np.sum(np.square(i))) for i in Xtr])
gamma = 0.002*C2 # for MNIST, use 0.04*C2
# eps value handed to train_network for the 'wrm' scheme
eps_wrm = 1./(2*gamma)
# eps value handed to train_network for the 'fgm' and 'pgm' schemes (5% of C2)
eps = 0.05*C2

In [7]:
# Train one network per (training scheme, beta) pair.
# Each entry is (scheme name, extra keyword arguments for train_network).
# ERM passes no eps, so train_network falls back to its default value.
training_schemes = [
    ('erm', {}),
    ('fgm', {'eps': eps}),
    ('pgm', {'eps': eps}),
    ('wrm', {'eps': eps_wrm}),
]

for scheme, scheme_kwargs in training_schemes:
    for beta in beta_list:
        save_dir = os.path.join(dirname, '%s_beta%s'%(scheme, beta))
        train_network(Xtr, Ytr, val_set, arch, save_dir,
                      adv=scheme, beta=beta, **scheme_kwargs)

# Test trained networks with various-magnitude attacks

In [8]:
def generate_adv_attack_curves(X, Y, arch, eps_list, defense, attack, resultsfile, beta_list, dirname,
                               load_epoch=None, num_channels=3, order=2, opt='momentum'):
    """
    Evaluate trained networks against an attack at several magnitudes

    For every beta in beta_list, the weights stored under
    dirname/<defense>_beta<beta> are tested with `attack` at each magnitude
    in eps_list. Results are cached in `resultsfile` (a pickle) and written
    back after each beta, so an interrupted sweep can be resumed.

    Inputs
    ----------------------------------------------------------------------
        X, Y:          Data and labels to attack
        arch:          Architecture the weights were trained with
        eps_list:      Attack magnitudes to sweep over
        defense:       Name of the training scheme; used only to build the
                       weights directory name
        attack:        Attack passed to ad.test_net_against_adv_examples
                       as `method`
        resultsfile:   Path of the pickle file used as the results cache
        beta_list:     Betas whose networks should be evaluated
        dirname:       Parent directory holding the saved weights
        load_epoch, num_channels, order, opt:
                       Forwarded unchanged to
                       ad.test_net_against_adv_examples

    Returns
    ----------------------------------------------------------------------
        Dict mapping beta -> array with one entry per value in eps_list

    NOTE: the cache is keyed by beta only. Results cached for a different
    eps_list are returned as-is; delete resultsfile to recompute.
    """
    if os.path.isfile(resultsfile):
        # NOTE(security): unpickling can execute arbitrary code; only point
        # resultsfile at caches written by this function.
        with open(resultsfile, 'rb') as cache:
            adv_results = pickle.load(cache)
    else:
        adv_results = {}

    for beta in beta_list:
        if beta in adv_results:
            # Already evaluated in a previous run
            continue
        save_dir = os.path.join(dirname, '%s_beta%s'%(defense, beta))

        adv_accs = np.zeros(len(eps_list))
        for i, eps in enumerate(eps_list):
            # eps must be forwarded; otherwise every point on the curve is
            # evaluated at the same attack magnitude.
            adv_accs[i] = ad.test_net_against_adv_examples(X, Y, save_dir, arch,
                                                           beta=beta, method=attack,
                                                           load_epoch=load_epoch,
                                                           num_channels=num_channels,
                                                           order=order,
                                                           opt=opt,
                                                           eps=eps)
        adv_results[beta] = adv_accs

        # Checkpoint after each beta so a crash does not lose finished work
        with open(resultsfile, 'wb') as cache:
            pickle.dump(adv_results, cache)

    return adv_results

In [9]:
# Eps attack values to sweep over: six evenly spaced magnitudes from 0 to 5 inclusive
eps_list = np.linspace(0, 5, 6)

In [10]:
# Attack each trained defense on the test split.
# Each entry is (defense name, attack name used in the results filename, attack function).
attack_suite = [
    ('erm', 'pgm', ad.pgm),
    ('fgm', 'fgm', ad.fgm),
    ('pgm', 'pgm', ad.pgm),
    ('wrm', 'wrm', ad.wrm),
]

for defense_name, attack_name, attack_fn in attack_suite:
    resultsfile = os.path.join(
        dirname, '%s_defense_%s_attacks_testset.pickle'%(defense_name, attack_name))
    adv_results = generate_adv_attack_curves(Xtt, Ytt, arch, eps_list, defense_name, attack_fn,
                                             resultsfile, beta_list, dirname)