In [1]:
import os
# GPU / XLA runtime configuration through environment variables. They are
# written here, before `jax` is imported further down in this cell.
# NOTE(review): presumably they must be in place before JAX initialises its
# backend, and "platform" + fraction "1.0" control how GPU memory is
# allocated -- confirm against the JAX GPU memory allocation docs.
os.environ["CUDA_VISIBLE_DEVICES"]="0"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import numpy as onp
import jax
import jax.numpy as np

from jax import lax, random
from jax.api import grad, jit, vmap
from jax.config import config
from jax.experimental import optimizers
from jax.experimental.stax import logsoftmax

config.update('jax_enable_x64', True)

from neural_tangents import stax

from functools import partial

# Attacking
from cleverhans.utils import clip_eta, one_hot

# Plotting
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('pdf', 'svg')
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable

from utils import *

sns.set_style(style='white')
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# diag_reg:
#     Scalar strength of the diagonal regularization applied to
#     `k_train_train`: `k_train_train + diag_reg * I` is the matrix used
#     during Cholesky factorization or eigendecomposition.
diag_reg = 1e-5

# data

In [2]:
# Experiment configuration. Every tunable used by the cells below lives here.
DATASET = 'cifar10'      # name handed to `get_dataset`; selects `image_shape` below
class_num   = 10         # width of the final Dense layer of the networks

train_size = 45000       # leading slice of the shuffled training pool
valid_size = 5000        # slice taken right after the training slice
test_size = None         # `x[:None]` keeps the whole test split
test_batch_size  = 16
eps = 0.03               # attack budget (was assigned twice with the same value)

batch_size = 100
epochs = 2000

# Per-sample image shape, used when reshaping the data splits.
if DATASET == 'mnist':
    image_shape = (28, 28, 1)
elif DATASET == 'cifar10':
    image_shape = (32, 32, 3)
else:
    # Fail here with a clear message instead of leaving `image_shape = None`
    # and crashing later on `*image_shape`.
    raise ValueError("Unsupported DATASET {!r}; expected 'mnist' or 'cifar10'.".format(DATASET))

In [3]:
# Fetch the dataset with `do_flatten_and_normalize=False` and turn each of the
# four returned pieces into a plain NumPy array.
# NOTE(review): the two `None` arguments are presumably size limits meaning
# "take everything" -- confirm against `get_dataset`.
x_train_all, y_train_all, x_test_all, y_test_all = map(
    onp.array,
    get_dataset(DATASET, None, None, do_flatten_and_normalize=False),
)

In [4]:
# Shuffle the full training pool (inputs and labels are passed together) with
# a fixed seed before it is split into train / validation slices.
# NOTE(review): `shaffle` is not defined in this notebook; presumably it comes
# from `from utils import *` and applies one permutation to both arrays -- confirm.
seed = 0
x_train_all, y_train_all = shaffle(x_train_all, y_train_all, seed)

In [5]:
# Carve the shuffled pool into consecutive train / validation slices and
# truncate the test split (`test_size = None` keeps all of it).
x_train, y_train = x_train_all[:train_size], y_train_all[:train_size]

x_valid, y_valid = (
    x_train_all[train_size:train_size + valid_size],
    y_train_all[train_size:train_size + valid_size],
)

x_test, y_test = x_test_all[:test_size], y_test_all[:test_size]

In [6]:
# Reshape every split to (-1, *image_shape) so the convolutional layers
# receive explicit image tensors.
x_train = x_train.reshape((-1, *image_shape))
x_valid = x_valid.reshape((-1, *image_shape))
x_test = x_test.reshape((-1, *image_shape))

# model

In [7]:
def correct(mean, ys):
    """Mark which predictions agree with their labels.

    Both arguments are reduced with `argmax` over the last axis and the
    resulting index arrays are compared element-wise.

    Args:
        mean: array of per-class scores.
        ys: array of per-class targets (e.g. one-hot rows).

    Returns:
        Boolean array that is True wherever the two argmax indices match.
    """
    predicted_class = onp.argmax(mean, axis=-1)
    target_class = onp.argmax(ys, axis=-1)
    return predicted_class == target_class

In [8]:
def ConvBlock(channels, W_std, b_std, strides=(1,1)):
    """Build a 3x3 'SAME'-padded convolution followed by a ReLU.

    Args:
        channels: number of output channels of the convolution.
        W_std: `W_std` forwarded to `stax.Conv`.
        b_std: `b_std` forwarded to `stax.Conv`.
        strides: convolution strides, (1, 1) by default.

    Returns:
        The `stax.serial` composition of the two layers.
    """
    conv_layer = stax.Conv(
        out_chan=channels,
        filter_shape=(3, 3),
        strides=strides,
        padding='SAME',
        W_std=W_std,
        b_std=b_std,
    )
    activation = stax.Relu(do_backprop=True)
    return stax.serial(conv_layer, activation)

def ConvGroup(n, channels, stride, W_std, b_std, last_stride=False):
    """Chain several `ConvBlock`s with the same width.

    Args:
        n: number of blocks in the group.
        channels: output channels of every block.
        stride: strides used by the blocks.
        W_std: `W_std` forwarded to every block.
        b_std: `b_std` forwarded to every block.
        last_stride: when True, the group is `n - 1` blocks with `stride`
            followed by one block with strides (2, 2); otherwise all `n`
            blocks use `stride`.

    Returns:
        The `stax.serial` composition of the blocks.
    """
    if last_stride:
        block_strides = [stride] * (n - 1) + [(2, 2)]
    else:
        block_strides = [stride] * n

    group = [ConvBlock(channels, W_std, b_std, s) for s in block_strides]
    return stax.serial(*group)
        
def VGG19_stride(class_num=class_num):
    """Build a VGG19-style network whose groups end in a stride-(2, 2) block.

    Five `ConvGroup`s (2, 2, 4, 4, 4 blocks with 64, 128, 256, 512, 512
    channels) are followed by Flatten, Dense(512) + ReLU and a final
    Dense(class_num).

    Args:
        class_num: width of the final Dense layer.

    Returns:
        The `stax.serial` composition of all layers.
    """
    # (number of blocks, channels) for each convolutional group, in order.
    group_specs = ((2, 64), (2, 128), (4, 256), (4, 512), (4, 512))
    conv_groups = [
        ConvGroup(n=depth, channels=width, stride=(1,1), W_std=0.1, b_std=0.18, last_stride=True)
        for depth, width in group_specs
    ]
    classifier = [
        stax.Flatten(),
        stax.Dense(512),
        stax.Relu(do_backprop=True),
        stax.Dense(class_num),
    ]
    return stax.serial(*conv_groups, *classifier)

def simple_net(class_num=class_num):
    """Build a small CNN: three 64-channel conv blocks and a two-layer head.

    All weight layers use `W_std = sqrt(2)`; the conv blocks use `b_std = 0.0`.

    Args:
        class_num: width of the final Dense layer.

    Returns:
        The `stax.serial` composition of all layers.
    """
    w_std = onp.sqrt(2)
    layers = (
        ConvGroup(n=3, channels=64, stride=(1,1), W_std=w_std, b_std=0.0, last_stride=False),
        stax.Flatten(),
        stax.Dense(512, W_std=w_std),
        stax.Relu(do_backprop=True),
        stax.Dense(class_num, W_std=w_std),
    )
    return stax.serial(*layers)

In [9]:
# b = 0.18
# W = 1.76

# b_std = np.sqrt(b)
# W_std = np.sqrt(W)

In [10]:
# init_fn, apply_fn, kernel_fn = stax.serial(stax.Conv(64, (5, 5), (2, 2), padding="SAME", W_std=W_std, b_std=b_std), 
#                                            stax.Relu(do_backprop=True),
#                                            stax.Conv(64, (5, 5), (2, 2), padding="SAME", W_std=W_std, b_std=b_std), 
#                                            stax.Relu(do_backprop=True),
#                                            stax.Flatten(),
#                                            stax.Dense(384, W_std=W_std, b_std=b_std),
#                                            stax.Relu(do_backprop=True),
#                                            stax.Dense(192, W_std=W_std, b_std=b_std),
#                                            stax.Relu(do_backprop=True),
#                                            stax.Dense(class_num))

In [11]:
@jit
def l2_loss_v1(logits, labels, weighting=1):
    """Half the (optionally weighted) sum of squared errors.

    This is the TensorFlow-style L2 loss: no square root is taken.

    Args:
        logits: model outputs.
        labels: targets, broadcastable against `logits`.
        weighting: factor multiplied into the squared errors before summing.

    Returns:
        Scalar `sum(weighting * (logits - labels)**2) / 2`.
    """
    squared_error = (logits - labels) ** 2
    weighted_total = np.sum(squared_error * weighting)
    return weighted_total / 2
    
@jit
def l2_loss_v2(logits, lables):
    """Norm of the residual `logits - lables` (L2 loss with the square root).

    Args:
        logits: model outputs.
        lables: targets, broadcastable against `logits`. The parameter keeps
            its existing spelling so keyword callers continue to work.

    Returns:
        Scalar `np.linalg.norm(logits - lables)` using the default norm.
    """
    # The body previously read an undefined name `labels`, which raised
    # NameError on the first call; it must use the `lables` parameter.
    return np.linalg.norm(logits - lables)

@jit
def cross_entropy_loss(logits, lables):
    """Cross-entropy between `logits` and target weights `lables`.

    The log-softmax of `logits` is multiplied element-wise by `lables` and the
    negated mean over *all* elements (batch and class axes alike) is returned.

    Args:
        logits: unnormalised model outputs.
        lables: per-class target weights (e.g. one-hot rows).

    Returns:
        Scalar loss.
    """
    log_probabilities = logsoftmax(logits)
    weighted = log_probabilities * lables
    return -np.mean(weighted)
    
@jit
def mse_loss(logits, lables):
    """Half the mean squared error between `logits` and `lables`.

    Args:
        logits: model outputs.
        lables: targets, broadcastable against `logits`.

    Returns:
        Scalar `0.5 * mean((logits - lables)**2)` over all elements.
    """
    residual = logits - lables
    mean_squared = np.mean(residual ** 2)
    return 0.5 * mean_squared

In [12]:
# Instantiate the `simple_net` architecture. The returned triple is unpacked
# into the parameter initialiser, the forward function and the kernel function
# that the cells below use.
init_fn, apply_fn, kernel_fn = simple_net(class_num)

In [13]:
# Jit-compile the forward pass. This rebinds `apply_fn`, so re-running the
# cell wraps the already-jitted function again.
apply_fn = jit(apply_fn)

In [14]:
# Initialise the network parameters from a fixed PRNG seed.
key = random.PRNGKey(0)
key, net_key = random.split(key)
# Take the input shape from the configured `image_shape` rather than a
# hard-coded (32, 32, 3), so the cell stays correct when DATASET is 'mnist'.
_, params = init_fn(net_key, (-1, *image_shape))

In [15]:
# Plain SGD from `jax.experimental.optimizers`. `optimizers.sgd` is unpacked
# into the (init, update, get_params) functions used by the training loop.
learning_rate = 1e0
# training_steps = 3200

opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
# Jit-compile the optimizer update step.
opt_update = jit(opt_update)

In [16]:
# loss = jit(lambda params, x, y: -np.mean(logsoftmax(apply_fn(params, x)) * y))
def loss(params, x, y):
    """Half mean-squared error of the network output against `y`."""
    residual = apply_fn(params, x) - y
    return 0.5 * np.mean(residual ** 2)


@jit
def grad_loss(params, x, y):
    """Gradient of `loss` with respect to `params` (jit-compiled)."""
    return grad(loss)(params, x, y)

In [17]:
def unison_shuffled_copies(a, b):
    """Return `a` and `b` reordered by one shared random permutation.

    Args:
        a: indexable array.
        b: indexable array with the same length as `a`.

    Returns:
        Tuple `(a[p], b[p])` where `p` is a permutation drawn from NumPy's
        global random state.

    Raises:
        AssertionError: if the two lengths differ.
    """
    count = len(a)
    assert count == len(b)
    order = onp.random.permutation(count)
    return a[order], b[order]

In [18]:
# Wrap the freshly initialised parameters in the optimizer state that the
# training loop reads back through `get_params` and advances with `opt_update`.
opt_state = opt_init(params)

In [None]:
# Per-epoch history (one averaged value per epoch).
train_losses = []
train_accuracy = []

valid_losses = []
valid_accuracy = []

# valid = (x_valid, y_valid)
steps_per_epoch = train_size//batch_size

for i in range(epochs):
    # Per-step metrics of the current epoch; averaged after the inner loop.
    train_epoch_losses = []
    train_epoch_accuracy = []
    
    valid_epoch_losses = []
    valid_epoch_accuracy = []
    
    # Re-shuffle the training set at the start of every epoch.
    x_train, y_train = shaffle(x_train, y_train)
    
    for batch in range(steps_per_epoch):
        
        batch_slice = slice(batch*batch_size, (batch+1)*batch_size)
        _x_train = x_train[batch_slice]
        _y_train = y_train[batch_slice]
        
        # Metrics below are measured on the parameters *before* this step's update.
        params = get_params(opt_state)
        opt_state = opt_update(i*steps_per_epoch + batch, grad_loss(params, _x_train, _y_train), opt_state)
        
        # One forward pass per split; loss and accuracy are both derived from
        # the same logits instead of re-running the network for each metric.
        # `mse_loss` is the same formula as the training objective `loss`;
        # keep the two in sync if the objective is changed.
        # NOTE(review): the whole validation set is still evaluated on every
        # step, which dominates the cost of an epoch.
        train_logits = apply_fn(params, _x_train)
        valid_logits = apply_fn(params, x_valid)
        
        train_epoch_losses.append(mse_loss(train_logits, _y_train))
        valid_epoch_losses.append(mse_loss(valid_logits, y_valid))
        
        train_epoch_accuracy.append(onp.average(correct(train_logits, _y_train)))
        valid_epoch_accuracy.append(onp.average(correct(valid_logits, y_valid)))
    
    epoch_train_loss = onp.average(train_epoch_losses)
    epoch_valid_loss = onp.average(valid_epoch_losses)
    epoch_train_acc = onp.average(train_epoch_accuracy)
    epoch_valid_acc = onp.average(valid_epoch_accuracy)
    
    print("epoch %3d: train loss %3.5f, valid loss %3.5f train acc %.5f valid acc %.5f"%\
          (i, 
           epoch_train_loss, 
           epoch_valid_loss, 
           epoch_train_acc, 
           epoch_valid_acc))
    
    train_losses.append(epoch_train_loss)
    train_accuracy.append(epoch_train_acc)
    
    valid_losses.append(epoch_valid_loss)
    valid_accuracy.append(epoch_valid_acc)

epoch   0: train loss 0.05752, valid loss 0.05757 train acc 0.11738 valid acc 0.11790
epoch   1: train loss 0.04709, valid loss 0.04717 train acc 0.15482 valid acc 0.15461
epoch   2: train loss 0.04564, valid loss 0.04571 train acc 0.17729 valid acc 0.17667
epoch   3: train loss 0.04494, valid loss 0.04501 train acc 0.19047 valid acc 0.19307
epoch   4: train loss 0.04447, valid loss 0.04451 train acc 0.20238 valid acc 0.20489
epoch   5: train loss 0.04409, valid loss 0.04412 train acc 0.21222 valid acc 0.21398
epoch   6: train loss 0.04380, valid loss 0.04382 train acc 0.22020 valid acc 0.22363
epoch   7: train loss 0.04357, valid loss 0.04359 train acc 0.22933 valid acc 0.23001
epoch   8: train loss 0.04337, valid loss 0.04338 train acc 0.23764 valid acc 0.23589
epoch   9: train loss 0.04320, valid loss 0.04321 train acc 0.24327 valid acc 0.24170
epoch  10: train loss 0.04305, valid loss 0.04306 train acc 0.24813 valid acc 0.24651
epoch  11: train loss 0.04291, valid loss 0.04292 trai

epoch  96: train loss 0.03887, valid loss 0.03899 train acc 0.40622 valid acc 0.38131
epoch  97: train loss 0.03885, valid loss 0.03897 train acc 0.40684 valid acc 0.38153
epoch  98: train loss 0.03883, valid loss 0.03895 train acc 0.40758 valid acc 0.38174
epoch  99: train loss 0.03880, valid loss 0.03893 train acc 0.40727 valid acc 0.38236
epoch 100: train loss 0.03878, valid loss 0.03891 train acc 0.40840 valid acc 0.38270
epoch 101: train loss 0.03876, valid loss 0.03889 train acc 0.40909 valid acc 0.38327
epoch 102: train loss 0.03874, valid loss 0.03887 train acc 0.40962 valid acc 0.38317
epoch 103: train loss 0.03872, valid loss 0.03885 train acc 0.41007 valid acc 0.38445
epoch 104: train loss 0.03870, valid loss 0.03883 train acc 0.41153 valid acc 0.38458
epoch 105: train loss 0.03868, valid loss 0.03881 train acc 0.41100 valid acc 0.38481
epoch 106: train loss 0.03865, valid loss 0.03879 train acc 0.41231 valid acc 0.38538
epoch 107: train loss 0.03864, valid loss 0.03877 trai

epoch 192: train loss 0.03728, valid loss 0.03753 train acc 0.44429 valid acc 0.42396
epoch 193: train loss 0.03727, valid loss 0.03752 train acc 0.44598 valid acc 0.42420
epoch 194: train loss 0.03726, valid loss 0.03751 train acc 0.44564 valid acc 0.42444
epoch 195: train loss 0.03725, valid loss 0.03749 train acc 0.44660 valid acc 0.42467
epoch 196: train loss 0.03723, valid loss 0.03748 train acc 0.44600 valid acc 0.42540
epoch 197: train loss 0.03722, valid loss 0.03747 train acc 0.44731 valid acc 0.42529
epoch 198: train loss 0.03721, valid loss 0.03746 train acc 0.44636 valid acc 0.42561
epoch 199: train loss 0.03719, valid loss 0.03745 train acc 0.44771 valid acc 0.42608
epoch 200: train loss 0.03718, valid loss 0.03744 train acc 0.44822 valid acc 0.42660
epoch 201: train loss 0.03717, valid loss 0.03743 train acc 0.44827 valid acc 0.42684
epoch 202: train loss 0.03716, valid loss 0.03741 train acc 0.44911 valid acc 0.42726
epoch 203: train loss 0.03714, valid loss 0.03741 trai

epoch 288: train loss 0.03620, valid loss 0.03658 train acc 0.47036 valid acc 0.45051
epoch 289: train loss 0.03619, valid loss 0.03657 train acc 0.47082 valid acc 0.45037
epoch 290: train loss 0.03618, valid loss 0.03656 train acc 0.47160 valid acc 0.45095
epoch 291: train loss 0.03617, valid loss 0.03655 train acc 0.47211 valid acc 0.45140
epoch 292: train loss 0.03616, valid loss 0.03654 train acc 0.47200 valid acc 0.45156
epoch 293: train loss 0.03615, valid loss 0.03654 train acc 0.47249 valid acc 0.45128
epoch 294: train loss 0.03614, valid loss 0.03653 train acc 0.47218 valid acc 0.45182
epoch 295: train loss 0.03613, valid loss 0.03652 train acc 0.47193 valid acc 0.45212
epoch 296: train loss 0.03612, valid loss 0.03651 train acc 0.47249 valid acc 0.45228
epoch 297: train loss 0.03611, valid loss 0.03650 train acc 0.47289 valid acc 0.45244
epoch 298: train loss 0.03610, valid loss 0.03649 train acc 0.47278 valid acc 0.45245
epoch 299: train loss 0.03610, valid loss 0.03648 trai

epoch 384: train loss 0.03536, valid loss 0.03585 train acc 0.48944 valid acc 0.46947
epoch 385: train loss 0.03535, valid loss 0.03585 train acc 0.48976 valid acc 0.46949
epoch 386: train loss 0.03535, valid loss 0.03584 train acc 0.48984 valid acc 0.46995
epoch 387: train loss 0.03534, valid loss 0.03583 train acc 0.49051 valid acc 0.46992
epoch 388: train loss 0.03533, valid loss 0.03582 train acc 0.49142 valid acc 0.47000
epoch 389: train loss 0.03532, valid loss 0.03582 train acc 0.49087 valid acc 0.47042
epoch 390: train loss 0.03531, valid loss 0.03581 train acc 0.49087 valid acc 0.47027
epoch 391: train loss 0.03531, valid loss 0.03581 train acc 0.49033 valid acc 0.47083
epoch 392: train loss 0.03530, valid loss 0.03580 train acc 0.49053 valid acc 0.47108
epoch 393: train loss 0.03529, valid loss 0.03579 train acc 0.49136 valid acc 0.47188
epoch 394: train loss 0.03528, valid loss 0.03579 train acc 0.49227 valid acc 0.47141
epoch 395: train loss 0.03528, valid loss 0.03578 trai

epoch 480: train loss 0.03467, valid loss 0.03526 train acc 0.50576 valid acc 0.48399
epoch 481: train loss 0.03467, valid loss 0.03525 train acc 0.50567 valid acc 0.48354
epoch 482: train loss 0.03466, valid loss 0.03525 train acc 0.50571 valid acc 0.48394
epoch 483: train loss 0.03465, valid loss 0.03524 train acc 0.50582 valid acc 0.48383
epoch 484: train loss 0.03464, valid loss 0.03524 train acc 0.50660 valid acc 0.48387
epoch 485: train loss 0.03464, valid loss 0.03523 train acc 0.50647 valid acc 0.48408
epoch 486: train loss 0.03463, valid loss 0.03523 train acc 0.50622 valid acc 0.48399
epoch 487: train loss 0.03462, valid loss 0.03522 train acc 0.50689 valid acc 0.48457
epoch 488: train loss 0.03462, valid loss 0.03522 train acc 0.50598 valid acc 0.48452
epoch 489: train loss 0.03461, valid loss 0.03521 train acc 0.50682 valid acc 0.48423
epoch 490: train loss 0.03461, valid loss 0.03520 train acc 0.50698 valid acc 0.48465
epoch 491: train loss 0.03460, valid loss 0.03520 trai

epoch 576: train loss 0.03408, valid loss 0.03474 train acc 0.51907 valid acc 0.49604
epoch 577: train loss 0.03407, valid loss 0.03474 train acc 0.51900 valid acc 0.49642
epoch 578: train loss 0.03406, valid loss 0.03473 train acc 0.51996 valid acc 0.49623
epoch 579: train loss 0.03406, valid loss 0.03472 train acc 0.51944 valid acc 0.49646
epoch 580: train loss 0.03405, valid loss 0.03472 train acc 0.52018 valid acc 0.49639
epoch 581: train loss 0.03404, valid loss 0.03472 train acc 0.51882 valid acc 0.49643
epoch 582: train loss 0.03404, valid loss 0.03471 train acc 0.52022 valid acc 0.49669
epoch 583: train loss 0.03403, valid loss 0.03471 train acc 0.51960 valid acc 0.49716
epoch 584: train loss 0.03403, valid loss 0.03470 train acc 0.51953 valid acc 0.49689
epoch 585: train loss 0.03402, valid loss 0.03469 train acc 0.52067 valid acc 0.49718
epoch 586: train loss 0.03402, valid loss 0.03469 train acc 0.52076 valid acc 0.49735
epoch 587: train loss 0.03401, valid loss 0.03468 trai

epoch 672: train loss 0.03353, valid loss 0.03428 train acc 0.53204 valid acc 0.50865
epoch 673: train loss 0.03353, valid loss 0.03427 train acc 0.53053 valid acc 0.50879
epoch 674: train loss 0.03352, valid loss 0.03427 train acc 0.53120 valid acc 0.50878
epoch 675: train loss 0.03351, valid loss 0.03426 train acc 0.53104 valid acc 0.50883
epoch 676: train loss 0.03351, valid loss 0.03426 train acc 0.53204 valid acc 0.50892
epoch 677: train loss 0.03350, valid loss 0.03425 train acc 0.53140 valid acc 0.50918
epoch 678: train loss 0.03350, valid loss 0.03425 train acc 0.53084 valid acc 0.50887
epoch 679: train loss 0.03350, valid loss 0.03424 train acc 0.53171 valid acc 0.50920
epoch 680: train loss 0.03349, valid loss 0.03424 train acc 0.53267 valid acc 0.50950
epoch 681: train loss 0.03348, valid loss 0.03424 train acc 0.53213 valid acc 0.50938
epoch 682: train loss 0.03348, valid loss 0.03423 train acc 0.53280 valid acc 0.50931
epoch 683: train loss 0.03347, valid loss 0.03423 trai

# loss

# attack algorithms

In [15]:
def fast_gradient_method(model_fn, kernel_fn, obj_fn, grads_fn, x_train=None, y_train=None, x_test=None, 
                         y=None, t=None, loss_weighting=None, fx_train_0=0., fx_test_0=0., eps=0.3, 
                         norm=np.inf, clip_min=None, clip_max=None, targeted=False):
    """Single-step fast gradient attack on `x_test`.

    The gradient is obtained from `grads_fn` and turned into a perturbation of
    size `eps` under the chosen norm, which is added to `x_test`.

    Args:
        model_fn: unused here; kept for interface compatibility.
        kernel_fn: forwarded to `grads_fn`.
        obj_fn: objective selector; only 'untargeted' is supported.
        grads_fn: callable invoked as
            `grads_fn(x_train, x_test, y_train, y, kernel_fn, t)` that returns
            the gradient array (batch axis first).
        x_train, y_train: training data forwarded to `grads_fn`.
        x_test: inputs to perturb.
        y: labels forwarded to `grads_fn`.
        t: forwarded to `grads_fn`.
        loss_weighting, fx_train_0, fx_test_0, targeted: unused here; kept
            for interface compatibility.
        eps: perturbation size (max-norm bound for `np.inf`, L2 norm for 2).
        norm: `np.inf` or 2.
        clip_min, clip_max: optional value range for the result; both or
            neither must be given.

    Returns:
        The perturbed inputs `x_test + perturbation`, clipped if requested.

    Raises:
        ValueError: if `norm` is not `np.inf` or 2, or `obj_fn` is unsupported.
        AssertionError: if only one of `clip_min` / `clip_max` is given.
    """
    if norm not in [np.inf, 2]:
        raise ValueError("Norm order must be either np.inf or 2.")

    # test independent
    if obj_fn != 'untargeted':
        # The previous message named 'train'/'test', which this function
        # never accepted.
        raise ValueError("Objective function must be 'untargeted'.")
    grads = grads_fn(x_train, x_test, y_train, y, kernel_fn, t)

    if norm == np.inf:
        perturbation = eps * np.sign(grads)
    else:
        # norm == 2 (norm == 1 can never reach here: it is rejected above).
        # Normalise each sample's gradient to unit L2 norm, guarding against
        # division by zero, then scale it to length `eps`. The `eps` factor
        # was previously missing, so the step size ignored `eps`.
        eps_div = 1e-12
        reduce_axes = tuple(range(1, len(grads.shape)))
        square = np.maximum(eps_div, np.sum(np.square(grads), axis=reduce_axes, keepdims=True))
        perturbation = eps * grads / np.sqrt(square)

    adv_x = x_test + perturbation

    # If clipping is needed, reset all values outside of [clip_min, clip_max]
    if (clip_min is not None) or (clip_max is not None):
        # We don't currently support one-sided clipping
        assert clip_min is not None and clip_max is not None
        # Bounds are passed positionally: the `a_min=`/`a_max=` keywords are
        # not accepted by every JAX version.
        adv_x = np.clip(adv_x, clip_min, clip_max)

    return adv_x

In [16]:
def iter_fast_gradient_method(model_fn, kernel_fn, obj_fn, grads_fn, x_train=None, y_train=None,
                               x_test=None, y=None, t=None, loss_weighting=None, fx_train_0=0., fx_test_0=0., 
                               eps=0.3, eps_iter=0.03, nb_iter=10, norm=np.inf, clip_min=None, clip_max=None, 
                               targeted=False, rand_init=None, rand_minmax=0.3, key=None):
    """Iterative fast gradient attack (repeated FGM steps with projection).

    Starting from `x_test` (optionally plus a random offset), the function
    applies `fast_gradient_method` with step size `eps_iter` for `nb_iter`
    iterations, projecting the accumulated perturbation back with
    `clip_eta(eta, norm, eps)` after every step.

    Args:
        model_fn, kernel_fn, obj_fn, grads_fn, x_train, y_train, y, t,
        loss_weighting, fx_train_0, fx_test_0, targeted: forwarded unchanged
            to `fast_gradient_method`.
        x_test: clean inputs to attack.
        eps: total perturbation budget.
        eps_iter: step size of each iteration; must not exceed `eps`.
        nb_iter: number of iterations.
        norm: `np.inf` or 2.
        clip_min, clip_max: optional value range for the adversarial inputs.
        rand_init: when truthy, start from a uniform random perturbation.
        rand_minmax: currently overwritten with `eps` when `rand_init` is set
            (see the note in the body).
        key: optional `jax.random` PRNG key for the random start. When None,
            the module-level `new_key` is used, as before.

    Returns:
        The adversarial inputs after the last iteration.

    Raises:
        AssertionError: if `eps_iter > eps`.
        NotImplementedError: if `norm == 1`.
        ValueError: if `norm` is neither `np.inf` nor 2.
    """
    assert eps_iter <= eps, (eps_iter, eps)
    if norm == 1:
        raise NotImplementedError("It's not clear that FGM is a good inner loop"
                                  " step for PGD when norm=1, because norm=1 FGM "
                                  " changes only one pixel at a time. We need "
                                  " to rigorously test a strong norm=1 PGD "
                                  "before enabling this feature.")
    if norm not in [np.inf, 2]:
        raise ValueError("Norm order must be either np.inf or 2.")

    x = x_test
    needs_clip = clip_min is not None or clip_max is not None

    # Initialize loop variables
    if rand_init:
        # NOTE(review): the `rand_minmax` argument is always replaced by `eps`
        # here, so the value passed by the caller has no effect. Left as-is
        # to keep existing results unchanged.
        rand_minmax = eps
        # Prefer an explicitly supplied key; fall back to the global `new_key`
        # (not defined in the visible part of the notebook) for backward
        # compatibility.
        init_key = new_key if key is None else key
        eta = random.uniform(init_key, x.shape, minval=-rand_minmax, maxval=rand_minmax)
    else:
        eta = np.zeros_like(x)

    # Clip eta
    eta = clip_eta(eta, norm, eps)
    adv_x = x + eta
    if needs_clip:
        # Bounds passed positionally; `a_min=`/`a_max=` are not accepted by
        # every JAX version.
        adv_x = np.clip(adv_x, clip_min, clip_max)

    for _ in range(nb_iter):
        adv_x = fast_gradient_method(model_fn, kernel_fn, obj_fn, grads_fn, x_train, y_train, adv_x, 
                                     y, t, loss_weighting, fx_train_0, fx_test_0, eps_iter, norm, 
                                     clip_min, clip_max, targeted)

        # Clipping perturbation eta to norm norm ball
        eta = adv_x - x
        eta = clip_eta(eta, norm, eps)
        adv_x = x + eta

        # Redo the clipping.
        # FGM already did it, but subtracting and re-adding eta can add some
        # small numerical error.
        if needs_clip:
            adv_x = np.clip(adv_x, clip_min, clip_max)

    return adv_x

# Attack Hyperparameters

In [17]:
# Attack budget per dataset. Any other DATASET keeps the `eps` that is
# already defined by the configuration cell.
if DATASET == 'mnist':
    eps = 0.3
elif DATASET == 'cifar10':
    eps = 0.03

# Per-iteration step sizes for 10 / 100 / 1000-step attacks: the budget split
# evenly over the steps, with a 10% margin. They are now derived once from
# `eps`, so every dataset gets the same set of names (`eps_iter_1000` used to
# be missing for cifar10).
eps_iter_10 = (eps/10)*1.1
eps_iter_100 = (eps/100)*1.1
eps_iter_1000 = (eps/1000)*1.1

In [18]:
def evaluate_accuracy(x_train, x_test, y_test, model_fn, kernel_fn, t=None, attack_type=None, ntk_train_train=None):
    """Print the test accuracy and return the per-sample correctness mask.

    Args:
        x_train: training inputs forwarded to `model_fn`.
        x_test: test inputs to score.
        y_test: targets compared against the predictions via `correct`.
        model_fn: callable invoked as
            `model_fn(kernel_fn, x_train, x_test, t=..., ntk_train_train=...)`
            returning `(train_predictions, test_predictions)`.
        kernel_fn: forwarded to `model_fn`.
        t: forwarded to `model_fn`.
        attack_type: label shown in the printed line; may be None.
        ntk_train_train: forwarded to `model_fn`.

    Returns:
        Boolean array that is True for every correctly classified test sample.
    """
    # Only the test predictions are scored; the train output is discarded.
    _, y_test_predict = model_fn(kernel_fn, x_train, x_test, 
                                 t=t, ntk_train_train=ntk_train_train)
    
    selected_table = correct(y_test_predict, y_test)
    # `{!s}` instead of `{:s}`: the default `attack_type=None` cannot be
    # formatted with `:s` and used to raise TypeError. String labels print
    # exactly as before.
    print("Accuray({!s}): {:.2f}".format(attack_type, onp.mean(selected_table)))
    
    return selected_table

In [19]:
def evaluate_robustness(x_train, x_test, y_test, model_fn, kernel_fn, selected_table, t=None, 
                        attack_type=None, ntk_train_train=None):
    """Print the accuracy on the samples selected by `selected_table`.

    Predictions for `x_test` are computed with `model_fn`; only the rows
    picked by the boolean mask `selected_table` (e.g. the result of
    `evaluate_accuracy`) are scored.

    Args:
        x_train: training inputs forwarded to `model_fn`.
        x_test: test inputs to score.
        y_test: targets for `x_test`.
        model_fn: callable invoked as
            `model_fn(kernel_fn, x_train, x_test, t=..., ntk_train_train=...)`
            returning `(train_predictions, test_predictions)`.
        kernel_fn: forwarded to `model_fn`.
        selected_table: mask/index selecting the samples to evaluate.
        t: forwarded to `model_fn`.
        attack_type: label shown in the printed line; may be None.
        ntk_train_train: forwarded to `model_fn`.

    Returns:
        None; the result is only printed.
    """
    # Only the test predictions are scored; the train output is discarded.
    _, y_test_predict = model_fn(kernel_fn, x_train, x_test,
                                 t=t, ntk_train_train=ntk_train_train)
    
    selection = onp.asarray(selected_table)
    y_test_predict_select = onp.asarray(y_test_predict)[selection]
    y_test_select = y_test[selection]
    # `{!s}` instead of `{:s}`: the default `attack_type=None` cannot be
    # formatted with `:s` and used to raise TypeError.
    print("Robustness({!s}): {:.2f}".format(attack_type, onp.mean(correct(y_test_predict_select, y_test_select))))
    
    return

# adv_x generation