In [1]:
import os
os.environ["CUDA_VISIBLE_DEVICES"]="0"
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import numpy as onp
import jax
import jax.numpy as np

from jax import lax, random
from jax.api import grad, jit, vmap
from jax.config import config
from jax.experimental import optimizers
from jax.experimental.stax import logsoftmax

config.update('jax_enable_x64', True)

from neural_tangents import stax

from functools import partial

# Attacking
from cleverhans.utils import clip_eta, one_hot

# Plotting
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('pdf', 'svg')
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable

from utils import *

sns.set_style(style='white')
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

"""
diag_reg:
    a scalar representing the strength of the diagonal regularization for
    `k_train_train`, i.e. computing `k_train_train + diag_reg * I` during
    Cholesky factorization or eigendecomposition.
"""
diag_reg = 1e-5

# data

In [38]:
# Which dataset to use; only 'mnist' and 'cifar10' are handled below.
DATASET = 'cifar10'
class_num = 10

# Number of samples kept for each split, and the test-time batch size.
train_size = 4096
valid_size = 512
test_size = 128
test_batch_size = 16

# SGD training configuration.
batch_size = 256
epochs = 1000

# Perturbation size. It used to be assigned twice in this cell with the same
# value; it is now assigned once. The attack-hyperparameter cell sets it again
# per dataset.
eps = 0.03

# Per-sample image shape (height, width, channels) for the selected dataset.
if DATASET == 'mnist':
    image_shape = (28, 28, 1)
elif DATASET == 'cifar10':
    image_shape = (32, 32, 3)
else:
    # Fail here rather than later with an opaque error when `image_shape`
    # (previously left as None) is unpacked in the reshape cell.
    raise ValueError("Unsupported DATASET {!r}; expected 'mnist' or 'cifar10'.".format(DATASET))

In [16]:
# Load the dataset and convert each returned split to a NumPy array.
# `get_dataset` is not defined in this notebook (presumably it comes from
# `from utils import *`); the two `None` arguments presumably request the full
# train/test sets -- TODO confirm against utils.
x_train_all, y_train_all, x_test_all, y_test_all = tuple(onp.array(x) for x in get_dataset(DATASET, None, None, 
                                                                                  do_flatten_and_normalize=False))

In [17]:
# shuffle
# `shaffle` [sic] is not defined in this notebook (presumably a helper from
# `utils`); it is expected to return the inputs and labels shuffled together
# using `seed` -- TODO confirm.
# NOTE(review): this rebinds x_train_all / y_train_all, so re-running the cell
# shuffles the already-shuffled arrays rather than the freshly loaded ones.
seed = 0
x_train_all, y_train_all = shaffle(x_train_all, y_train_all, seed)

In [18]:
# down sample
x_train = x_train_all[:train_size]
y_train = y_train_all[:train_size]

x_valid = x_train_all[train_size:train_size + valid_size]
y_valid = y_train_all[train_size:train_size + valid_size]

x_test = x_test_all[:test_size]
y_test = y_test_all[:test_size]

In [19]:
# Reshape every split to (num_samples, *image_shape); the leading -1 lets the
# sample count be inferred from the array size.
x_train, x_valid, x_test = x_train.reshape((-1, *image_shape)), x_valid.reshape((-1, *image_shape)), x_test.reshape((-1, *image_shape))

# model

In [20]:
def correct(mean, ys):
    """Return an element-wise boolean mask of correct predictions.

    Args:
        mean: array of predicted scores; the class is taken as the argmax
            over the last axis.
        ys: array of targets with the same layout; the true class is the
            argmax over the last axis.

    Returns:
        Boolean array that is True where predicted and true classes agree.
    """
    predicted_class = onp.argmax(mean, axis=-1)
    true_class = onp.argmax(ys, axis=-1)
    return predicted_class == true_class

In [22]:
def ConvBlock(channels, W_std, b_std, strides=(1,1)):
    """Build a 3x3 'SAME'-padded convolution followed by a ReLU.

    Args:
        channels: number of output channels of the convolution.
        W_std: weight standard deviation passed to `stax.Conv`.
        b_std: bias standard deviation passed to `stax.Conv`.
        strides: convolution strides (defaults to (1, 1)).

    Returns:
        The layer produced by `stax.serial` for the conv + ReLU pair.
    """
    conv_layer = stax.Conv(out_chan=channels, filter_shape=(3,3), strides=strides,
                           padding='SAME', W_std=W_std, b_std=b_std)
    activation = stax.Relu(do_backprop=True)
    return stax.serial(conv_layer, activation)

def ConvGroup(n, channels, stride, W_std, b_std, last_stride=False):
    """Stack `n` ConvBlocks with the same channel count.

    Args:
        n: number of ConvBlocks in the group.
        channels: output channels of every block.
        stride: strides used by the blocks.
        W_std: weight standard deviation forwarded to each block.
        b_std: bias standard deviation forwarded to each block.
        last_stride: when True, the final block uses strides (2, 2) instead of
            `stride`, and the first n-1 blocks use `stride`.

    Returns:
        The blocks composed with `stax.serial`.
    """
    # One stride entry per block, in the order the blocks are applied.
    if last_stride:
        per_block_strides = [stride] * (n - 1) + [(2, 2)]
    else:
        per_block_strides = [stride] * n

    layers = [ConvBlock(channels, W_std, b_std, block_stride)
              for block_stride in per_block_strides]
    return stax.serial(*layers)
        
def VGG19_stride(class_num=class_num):
    """Build a VGG19-style network whose conv groups end in a strided block.

    Five ConvGroups (2, 2, 4, 4, 4 blocks with 64, 128, 256, 512, 512
    channels) are each built with `last_stride=True`, followed by Flatten and
    three Dense layers (4096, 4096, `class_num`) with ReLUs in between.

    Args:
        class_num: width of the final Dense layer; the default is bound to the
            notebook-level `class_num` at definition time.

    Returns:
        The layers composed with `stax.serial`.
    """
    # (number of blocks, channels) for each conv group, in application order.
    group_specs = [(2, 64), (2, 128), (4, 256), (4, 512), (4, 512)]
    conv_groups = [
        ConvGroup(n=depth, channels=width, stride=(1,1), W_std=0.1, b_std=0.18, last_stride=True)
        for depth, width in group_specs
    ]
    classifier = [
        stax.Flatten(),
        stax.Dense(4096), stax.Relu(do_backprop=True),
        stax.Dense(4096), stax.Relu(do_backprop=True),
        stax.Dense(class_num),
    ]
    return stax.serial(*conv_groups, *classifier)

def simple_net(class_num=class_num):
    """Build a small conv net: three 64-channel ConvBlocks, then two Dense layers.

    All weights use W_std = sqrt(2); the conv blocks use b_std = 0.0.

    Args:
        class_num: width of the final Dense layer; the default is bound to the
            notebook-level `class_num` at definition time.

    Returns:
        The layers composed with `stax.serial`.
    """
    w_std = onp.sqrt(2)
    conv_features = ConvGroup(n=3, channels=64, stride=(1,1), W_std=w_std, b_std=0.0,
                              last_stride=False)
    hidden = stax.Dense(4096, W_std=w_std)
    readout = stax.Dense(class_num, W_std=w_std)
    return stax.serial(conv_features, stax.Flatten(), hidden, stax.Relu(do_backprop=True), readout)

In [27]:
@jit
def l2_loss_v1(logits, labels, weighting=1):
    """Half the (optionally weighted) sum of squared errors.

    This is the TensorFlow-style L2 loss, i.e. no square root is taken.

    Args:
        logits: model outputs.
        labels: targets, broadcastable against `logits`.
        weighting: multiplicative weight applied to the squared errors.

    Returns:
        Scalar `sum(weighting * (logits - labels)**2) / 2`.
    """
    squared_error = (logits - labels) ** 2
    weighted_total = np.sum(squared_error * weighting)
    return weighted_total / 2
    
@jit
def l2_loss_v2(logits, lables):
    """Plain L2 loss: the norm of the residual `logits - lables`.

    Args:
        logits: model outputs.
        lables: targets, broadcastable against `logits`. The misspelled name is
            kept so existing keyword callers keep working.

    Returns:
        Scalar `np.linalg.norm(logits - lables)`.
    """
    # The body used to read `labels`, which is not a parameter of this
    # function, so tracing failed with a NameError.
    return np.linalg.norm(logits - lables)

@jit
def cross_entropy_loss(logits, lables):
    """Negative mean of `logsoftmax(logits) * lables`.

    The mean runs over every element of the product, not only over the batch
    axis.

    Args:
        logits: unnormalised model outputs.
        lables: target weights with the same shape as `logits`.

    Returns:
        Scalar loss value.
    """
    log_probs = logsoftmax(logits)
    weighted_log_probs = log_probs * lables
    return -np.mean(weighted_log_probs)
    
@jit
def mse_loss(logits, lables):
    """Half the mean squared error between `logits` and `lables`.

    Args:
        logits: model outputs.
        lables: targets, broadcastable against `logits`.

    Returns:
        Scalar `0.5 * mean((logits - lables)**2)`.
    """
    residual = logits - lables
    mean_squared = np.mean(residual ** 2)
    return 0.5 * mean_squared

In [40]:
# Instantiate the small conv net; `stax.serial` yields the
# (init_fn, apply_fn, kernel_fn) triple unpacked here.
init_fn, apply_fn, kernel_fn = simple_net(class_num)

In [41]:
# Compile the forward pass. NOTE: this rebinds `apply_fn`, so re-running the
# cell wraps the already-jitted function again.
apply_fn = jit(apply_fn)

In [42]:
# Fixed seed so parameter initialisation is reproducible.
key = random.PRNGKey(88888)
key, net_key = random.split(key)
# Derive the input shape from the configured `image_shape` rather than
# hard-coding (32, 32, 3), so initialisation also matches DATASET == 'mnist'.
# For cifar10 this is the same (-1, 32, 32, 3) as before.
_, params = init_fn(net_key, (-1, *image_shape))

In [43]:
# Step size for plain SGD.
learning_rate = 0.001
# training_steps = 3200

# `optimizers.sgd` returns the (init, update, get_params) triple; the update
# step is jit-compiled because it runs once per mini-batch.
opt_init, opt_update, get_params = optimizers.sgd(learning_rate)
opt_update = jit(opt_update)

In [44]:
# Training objective: negative mean of logsoftmax(f(x)) * y over every element
# of the (batch, class) product -- the same expression as `cross_entropy_loss`
# applied to the network output.
loss = jit(lambda params, x, y: -np.mean(logsoftmax(apply_fn(params, x)) * y))
# Gradient of `loss` with respect to the parameters held in an optimizer state.
grad_loss = jit(lambda state, x, y: grad(loss)(get_params(state), x, y))

In [45]:
def unison_shuffled_copies(a, b, rng=None):
    """Shuffle two equal-length arrays with one shared random permutation.

    Args:
        a: first array; indexed along its first axis.
        b: second array; must have the same length as `a`.
        rng: optional random source exposing `permutation(n)` (for example
            `numpy.random.Generator` or `numpy.random.RandomState`). When
            omitted, the global NumPy RNG is used, exactly as before.

    Returns:
        Tuple `(a[p], b[p])` where `p` is the sampled permutation, so rows of
        `a` and `b` stay aligned.

    Raises:
        AssertionError: if `a` and `b` differ in length.
    """
    assert len(a) == len(b), "a and b must have the same length"
    # Passing an explicit `rng` makes the shuffle reproducible.
    sampler = onp.random if rng is None else rng
    p = sampler.permutation(len(a))
    return a[p], b[p]

In [46]:
# Wrap the freshly initialised parameters in an optimizer state. Re-run this
# cell to restart training from `params`.
opt_state = opt_init(params)

In [47]:
# Mini-batch SGD training loop with per-epoch logging.
#
# NOTE(review): the recorded output of this cell shows train accuracy around
# 0.09 and valid accuracy fixed at 0.08008 for every visible epoch, with the
# loss moving only from about 0.242 to about 0.237. With class_num = 10 that is
# near chance level, so the learning rate / loss scaling deserve a second look.

# Per-epoch history (one entry appended per epoch).
train_losses = []
train_accuracy = []

valid_losses = []
valid_accuracy = []

# The validation set is evaluated in full, as a single batch.
valid = (x_valid, y_valid)

for i in range(epochs):
    # Per-batch metrics for the current epoch; averaged at the end of the epoch.
    train_epoch_losses = []
    train_epoch_accuracy = []
    
    valid_epoch_losses = []
    valid_epoch_accuracy = []
    
    # Integer division: if train_size is not a multiple of batch_size, the
    # trailing samples are not visited in this epoch.
    for batch in range(train_size//batch_size):
        
        train = (x_train[batch*batch_size:batch*batch_size+batch_size], 
                 y_train[batch*batch_size:batch*batch_size+batch_size])
        
        # The first argument is the global step index across all epochs.
        opt_state = opt_update(i*(train_size//batch_size) + batch, grad_loss(opt_state, *train), opt_state)
        
        # Metrics are measured AFTER the update, on the batch just trained on.
        # NOTE(review): the whole validation set is re-evaluated after every
        # batch (loss and accuracy each run their own forward pass).
        train_epoch_losses.append(loss(get_params(opt_state), *train))
        valid_epoch_losses.append(loss(get_params(opt_state), *valid))
        
        train_correctness = onp.argmax(apply_fn(get_params(opt_state), train[0]), 1) == onp.argmax(train[1], 1)
        train_epoch_accuracy.append(onp.average(train_correctness))
        
        valid_correctness = onp.argmax(apply_fn(get_params(opt_state), valid[0]), 1) == onp.argmax(valid[1], 1)
        valid_epoch_accuracy.append(onp.average(valid_correctness))
    
    print("epoch %3d: train loss %3.5f, valid loss %3.5f train acc %.5f valid acc %.5f"%\
          (i, onp.average(train_epoch_losses), 
           onp.average(valid_epoch_losses), 
           onp.average(train_epoch_accuracy), 
           onp.average(valid_epoch_accuracy)))
    
    train_losses.append(onp.average(train_epoch_losses))
    train_accuracy.append(onp.average(train_epoch_accuracy))
    
    valid_losses.append(onp.average(valid_epoch_losses))
    valid_accuracy.append(onp.average(valid_epoch_accuracy))
    
    # Reshuffle for the next epoch using the unseeded global NumPy RNG.
    # NOTE(review): this rebinds the notebook-level x_train / y_train, so later
    # cells see the order left by the final epoch.
    x_train, y_train = unison_shuffled_copies(x_train, y_train)

epoch   0: train loss 0.24208, valid loss 0.24338 train acc 0.09473 valid acc 0.08008
epoch   1: train loss 0.24208, valid loss 0.24337 train acc 0.09473 valid acc 0.08008
epoch   2: train loss 0.24207, valid loss 0.24337 train acc 0.09473 valid acc 0.08008
epoch   3: train loss 0.24206, valid loss 0.24336 train acc 0.09473 valid acc 0.08008
epoch   4: train loss 0.24206, valid loss 0.24335 train acc 0.09473 valid acc 0.08008
epoch   5: train loss 0.24205, valid loss 0.24335 train acc 0.09473 valid acc 0.08008
epoch   6: train loss 0.24204, valid loss 0.24334 train acc 0.09473 valid acc 0.08008
epoch   7: train loss 0.24204, valid loss 0.24333 train acc 0.09473 valid acc 0.08008
epoch   8: train loss 0.24203, valid loss 0.24333 train acc 0.09497 valid acc 0.08008
epoch   9: train loss 0.24202, valid loss 0.24332 train acc 0.09497 valid acc 0.08008
epoch  10: train loss 0.24202, valid loss 0.24331 train acc 0.09497 valid acc 0.08008
epoch  11: train loss 0.24201, valid loss 0.24331 trai

epoch  96: train loss 0.24145, valid loss 0.24276 train acc 0.09521 valid acc 0.08008
epoch  97: train loss 0.24144, valid loss 0.24275 train acc 0.09521 valid acc 0.08008
epoch  98: train loss 0.24144, valid loss 0.24274 train acc 0.09521 valid acc 0.08008
epoch  99: train loss 0.24143, valid loss 0.24274 train acc 0.09521 valid acc 0.08008
epoch 100: train loss 0.24143, valid loss 0.24273 train acc 0.09521 valid acc 0.08008
epoch 101: train loss 0.24142, valid loss 0.24273 train acc 0.09521 valid acc 0.08008
epoch 102: train loss 0.24141, valid loss 0.24272 train acc 0.09521 valid acc 0.08008
epoch 103: train loss 0.24141, valid loss 0.24271 train acc 0.09521 valid acc 0.08008
epoch 104: train loss 0.24140, valid loss 0.24271 train acc 0.09521 valid acc 0.08008
epoch 105: train loss 0.24139, valid loss 0.24270 train acc 0.09521 valid acc 0.08008
epoch 106: train loss 0.24139, valid loss 0.24270 train acc 0.09521 valid acc 0.08008
epoch 107: train loss 0.24138, valid loss 0.24269 trai

epoch 192: train loss 0.24086, valid loss 0.24218 train acc 0.09497 valid acc 0.08008
epoch 193: train loss 0.24086, valid loss 0.24217 train acc 0.09497 valid acc 0.08008
epoch 194: train loss 0.24085, valid loss 0.24217 train acc 0.09497 valid acc 0.08008
epoch 195: train loss 0.24085, valid loss 0.24216 train acc 0.09497 valid acc 0.08008
epoch 196: train loss 0.24084, valid loss 0.24216 train acc 0.09497 valid acc 0.08008
epoch 197: train loss 0.24083, valid loss 0.24215 train acc 0.09497 valid acc 0.08008
epoch 198: train loss 0.24083, valid loss 0.24214 train acc 0.09497 valid acc 0.08008
epoch 199: train loss 0.24082, valid loss 0.24214 train acc 0.09497 valid acc 0.08008
epoch 200: train loss 0.24082, valid loss 0.24213 train acc 0.09497 valid acc 0.08008
epoch 201: train loss 0.24081, valid loss 0.24213 train acc 0.09497 valid acc 0.08008
epoch 202: train loss 0.24081, valid loss 0.24212 train acc 0.09497 valid acc 0.08008
epoch 203: train loss 0.24080, valid loss 0.24212 trai

epoch 288: train loss 0.24032, valid loss 0.24164 train acc 0.09424 valid acc 0.08008
epoch 289: train loss 0.24031, valid loss 0.24164 train acc 0.09424 valid acc 0.08008
epoch 290: train loss 0.24031, valid loss 0.24163 train acc 0.09424 valid acc 0.08008
epoch 291: train loss 0.24030, valid loss 0.24162 train acc 0.09424 valid acc 0.08008
epoch 292: train loss 0.24030, valid loss 0.24162 train acc 0.09424 valid acc 0.08008
epoch 293: train loss 0.24029, valid loss 0.24161 train acc 0.09424 valid acc 0.08008
epoch 294: train loss 0.24029, valid loss 0.24161 train acc 0.09424 valid acc 0.08008
epoch 295: train loss 0.24028, valid loss 0.24160 train acc 0.09424 valid acc 0.08008
epoch 296: train loss 0.24028, valid loss 0.24160 train acc 0.09424 valid acc 0.08008
epoch 297: train loss 0.24027, valid loss 0.24159 train acc 0.09424 valid acc 0.08008
epoch 298: train loss 0.24026, valid loss 0.24159 train acc 0.09424 valid acc 0.08008
epoch 299: train loss 0.24026, valid loss 0.24158 trai

epoch 384: train loss 0.23981, valid loss 0.24114 train acc 0.09399 valid acc 0.08008
epoch 385: train loss 0.23981, valid loss 0.24113 train acc 0.09399 valid acc 0.08008
epoch 386: train loss 0.23980, valid loss 0.24113 train acc 0.09399 valid acc 0.08008
epoch 387: train loss 0.23980, valid loss 0.24112 train acc 0.09399 valid acc 0.08008
epoch 388: train loss 0.23979, valid loss 0.24112 train acc 0.09399 valid acc 0.08008
epoch 389: train loss 0.23979, valid loss 0.24111 train acc 0.09399 valid acc 0.08008
epoch 390: train loss 0.23978, valid loss 0.24111 train acc 0.09399 valid acc 0.08008
epoch 391: train loss 0.23978, valid loss 0.24110 train acc 0.09399 valid acc 0.08008
epoch 392: train loss 0.23977, valid loss 0.24110 train acc 0.09399 valid acc 0.08008
epoch 393: train loss 0.23977, valid loss 0.24109 train acc 0.09399 valid acc 0.08008
epoch 394: train loss 0.23976, valid loss 0.24109 train acc 0.09375 valid acc 0.08008
epoch 395: train loss 0.23976, valid loss 0.24108 trai

epoch 480: train loss 0.23934, valid loss 0.24066 train acc 0.09302 valid acc 0.08008
epoch 481: train loss 0.23933, valid loss 0.24066 train acc 0.09302 valid acc 0.08008
epoch 482: train loss 0.23933, valid loss 0.24065 train acc 0.09302 valid acc 0.08008
epoch 483: train loss 0.23932, valid loss 0.24065 train acc 0.09302 valid acc 0.08008
epoch 484: train loss 0.23932, valid loss 0.24064 train acc 0.09302 valid acc 0.08008
epoch 485: train loss 0.23931, valid loss 0.24064 train acc 0.09326 valid acc 0.08008
epoch 486: train loss 0.23931, valid loss 0.24063 train acc 0.09326 valid acc 0.08008
epoch 487: train loss 0.23930, valid loss 0.24063 train acc 0.09326 valid acc 0.08008
epoch 488: train loss 0.23930, valid loss 0.24062 train acc 0.09326 valid acc 0.08008
epoch 489: train loss 0.23929, valid loss 0.24062 train acc 0.09326 valid acc 0.08008
epoch 490: train loss 0.23929, valid loss 0.24061 train acc 0.09326 valid acc 0.08008
epoch 491: train loss 0.23929, valid loss 0.24061 trai

epoch 576: train loss 0.23889, valid loss 0.24022 train acc 0.09302 valid acc 0.08008
epoch 577: train loss 0.23889, valid loss 0.24022 train acc 0.09302 valid acc 0.08008
epoch 578: train loss 0.23888, valid loss 0.24021 train acc 0.09302 valid acc 0.08008
epoch 579: train loss 0.23888, valid loss 0.24021 train acc 0.09302 valid acc 0.08008
epoch 580: train loss 0.23888, valid loss 0.24020 train acc 0.09302 valid acc 0.08008
epoch 581: train loss 0.23887, valid loss 0.24020 train acc 0.09302 valid acc 0.08008
epoch 582: train loss 0.23887, valid loss 0.24019 train acc 0.09302 valid acc 0.08008
epoch 583: train loss 0.23886, valid loss 0.24019 train acc 0.09302 valid acc 0.08008
epoch 584: train loss 0.23886, valid loss 0.24018 train acc 0.09302 valid acc 0.08008
epoch 585: train loss 0.23885, valid loss 0.24018 train acc 0.09302 valid acc 0.08008
epoch 586: train loss 0.23885, valid loss 0.24018 train acc 0.09302 valid acc 0.08008
epoch 587: train loss 0.23884, valid loss 0.24017 trai

epoch 672: train loss 0.23848, valid loss 0.23981 train acc 0.09302 valid acc 0.08008
epoch 673: train loss 0.23847, valid loss 0.23980 train acc 0.09302 valid acc 0.08008
epoch 674: train loss 0.23847, valid loss 0.23980 train acc 0.09302 valid acc 0.08008
epoch 675: train loss 0.23846, valid loss 0.23979 train acc 0.09302 valid acc 0.08008
epoch 676: train loss 0.23846, valid loss 0.23979 train acc 0.09302 valid acc 0.08008
epoch 677: train loss 0.23846, valid loss 0.23979 train acc 0.09302 valid acc 0.08008
epoch 678: train loss 0.23845, valid loss 0.23978 train acc 0.09302 valid acc 0.08008
epoch 679: train loss 0.23845, valid loss 0.23978 train acc 0.09302 valid acc 0.08008
epoch 680: train loss 0.23844, valid loss 0.23977 train acc 0.09302 valid acc 0.08008
epoch 681: train loss 0.23844, valid loss 0.23977 train acc 0.09302 valid acc 0.08008
epoch 682: train loss 0.23843, valid loss 0.23976 train acc 0.09326 valid acc 0.08008
epoch 683: train loss 0.23843, valid loss 0.23976 trai

epoch 768: train loss 0.23809, valid loss 0.23942 train acc 0.09326 valid acc 0.08008
epoch 769: train loss 0.23808, valid loss 0.23941 train acc 0.09326 valid acc 0.08008
epoch 770: train loss 0.23808, valid loss 0.23941 train acc 0.09326 valid acc 0.08008
epoch 771: train loss 0.23807, valid loss 0.23941 train acc 0.09326 valid acc 0.08008
epoch 772: train loss 0.23807, valid loss 0.23940 train acc 0.09326 valid acc 0.08008
epoch 773: train loss 0.23807, valid loss 0.23940 train acc 0.09326 valid acc 0.08008
epoch 774: train loss 0.23806, valid loss 0.23940 train acc 0.09326 valid acc 0.08008
epoch 775: train loss 0.23806, valid loss 0.23939 train acc 0.09326 valid acc 0.08008
epoch 776: train loss 0.23805, valid loss 0.23939 train acc 0.09326 valid acc 0.08008
epoch 777: train loss 0.23805, valid loss 0.23938 train acc 0.09326 valid acc 0.08008
epoch 778: train loss 0.23805, valid loss 0.23938 train acc 0.09326 valid acc 0.08008
epoch 779: train loss 0.23804, valid loss 0.23938 trai

epoch 864: train loss 0.23772, valid loss 0.23906 train acc 0.09253 valid acc 0.08008
epoch 865: train loss 0.23771, valid loss 0.23905 train acc 0.09253 valid acc 0.08008
epoch 866: train loss 0.23771, valid loss 0.23905 train acc 0.09253 valid acc 0.08008
epoch 867: train loss 0.23771, valid loss 0.23904 train acc 0.09253 valid acc 0.08008
epoch 868: train loss 0.23770, valid loss 0.23904 train acc 0.09229 valid acc 0.08008
epoch 869: train loss 0.23770, valid loss 0.23904 train acc 0.09229 valid acc 0.08008
epoch 870: train loss 0.23770, valid loss 0.23903 train acc 0.09229 valid acc 0.08008
epoch 871: train loss 0.23769, valid loss 0.23903 train acc 0.09229 valid acc 0.08008
epoch 872: train loss 0.23769, valid loss 0.23903 train acc 0.09229 valid acc 0.08008
epoch 873: train loss 0.23768, valid loss 0.23902 train acc 0.09229 valid acc 0.08008
epoch 874: train loss 0.23768, valid loss 0.23902 train acc 0.09253 valid acc 0.08008
epoch 875: train loss 0.23768, valid loss 0.23902 trai

epoch 960: train loss 0.23737, valid loss 0.23871 train acc 0.09204 valid acc 0.08008
epoch 961: train loss 0.23737, valid loss 0.23871 train acc 0.09204 valid acc 0.08008
epoch 962: train loss 0.23736, valid loss 0.23871 train acc 0.09229 valid acc 0.08008
epoch 963: train loss 0.23736, valid loss 0.23870 train acc 0.09229 valid acc 0.08008
epoch 964: train loss 0.23736, valid loss 0.23870 train acc 0.09229 valid acc 0.08008
epoch 965: train loss 0.23735, valid loss 0.23870 train acc 0.09229 valid acc 0.08008
epoch 966: train loss 0.23735, valid loss 0.23869 train acc 0.09229 valid acc 0.08008
epoch 967: train loss 0.23735, valid loss 0.23869 train acc 0.09229 valid acc 0.08008
epoch 968: train loss 0.23734, valid loss 0.23869 train acc 0.09229 valid acc 0.08008
epoch 969: train loss 0.23734, valid loss 0.23868 train acc 0.09229 valid acc 0.08008
epoch 970: train loss 0.23734, valid loss 0.23868 train acc 0.09229 valid acc 0.08008
epoch 971: train loss 0.23733, valid loss 0.23868 trai

# loss

# attack algorithms

In [15]:
def fast_gradient_method(model_fn, kernel_fn, obj_fn, grads_fn, x_train=None, y_train=None, x_test=None, 
                         y=None, t=None, loss_weighting=None, fx_train_0=0., fx_test_0=0., eps=0.3, 
                         norm=np.inf, clip_min=None, clip_max=None, targeted=False):
    """Single-step gradient attack on `x_test`.

    The gradient is obtained from `grads_fn(x_train, x_test, y_train, y,
    kernel_fn, t)` and turned into a perturbation of size `eps` under the
    chosen norm.

    Args:
        model_fn, loss_weighting, fx_train_0, fx_test_0, targeted: accepted for
            signature compatibility; not used by this function.
        kernel_fn: forwarded to `grads_fn`.
        obj_fn: objective selector; only the string 'untargeted' is accepted.
        grads_fn: callable returning the gradient with respect to `x_test`.
        x_train, y_train, y, t: forwarded to `grads_fn`.
        x_test: inputs to perturb.
        eps: perturbation size. For `norm=np.inf` every coordinate moves by
            `eps`; for `norm=2` each sample's perturbation has L2 norm `eps`.
        norm: `np.inf` or `2`.
        clip_min, clip_max: optional bounds for the adversarial examples; both
            must be given together.

    Returns:
        The adversarial examples `x_test + perturbation`, clipped if requested.

    Raises:
        ValueError: if `norm` is not `np.inf` or 2, or `obj_fn` is unsupported.
        AssertionError: if only one of `clip_min` / `clip_max` is provided.
    """
    if norm not in [np.inf, 2]:
        raise ValueError("Norm order must be either np.inf or 2.")

    if obj_fn != 'untargeted':
        # The previous message listed 'train'/'test', which this function
        # never accepted.
        raise ValueError("Objective function must be 'untargeted', got {!r}.".format(obj_fn))

    grads = grads_fn(x_train, x_test, y_train, y, kernel_fn, t)

    if norm == np.inf:
        perturbation = eps * np.sign(grads)
    else:
        # norm == 2 (guaranteed by the check above; the old `norm == 1` branch
        # was unreachable). Normalise each sample's gradient over all
        # non-batch axes, guarding against division by zero.
        eps_div = 1e-12
        axis = tuple(range(1, len(grads.shape)))
        square = np.maximum(eps_div, np.sum(np.square(grads), axis=axis, keepdims=True))
        # Scale the unit-norm direction by eps; the scaling was missing before,
        # so `eps` had no effect on L2 attacks.
        perturbation = eps * grads / np.sqrt(square)

    adv_x = x_test + perturbation

    # If clipping is needed, reset all values outside of [clip_min, clip_max]
    if (clip_min is not None) or (clip_max is not None):
        # We don't currently support one-sided clipping
        assert clip_min is not None and clip_max is not None
        # Bounds are passed positionally: the `a_min`/`a_max` keywords are
        # deprecated in newer JAX releases.
        adv_x = np.clip(adv_x, clip_min, clip_max)

    return adv_x

In [16]:
def iter_fast_gradient_method(model_fn, kernel_fn, obj_fn, grads_fn, x_train=None, y_train=None,
                               x_test=None, y=None, t=None, loss_weighting=None, fx_train_0=0., fx_test_0=0., 
                               eps=0.3, eps_iter=0.03, nb_iter=10, norm=np.inf, clip_min=None, clip_max=None, 
                               targeted=False, rand_init=None, rand_minmax=0.3, key=None):
    """Iterated fast gradient method (repeated FGM steps with projection).

    Runs `nb_iter` steps of `fast_gradient_method` with step size `eps_iter`.
    After every step the accumulated perturbation is projected back onto the
    `norm` ball of radius `eps` with `clip_eta`.

    Args:
        model_fn, kernel_fn, obj_fn, grads_fn, x_train, y_train, y, t,
        loss_weighting, fx_train_0, fx_test_0, targeted: forwarded unchanged to
            `fast_gradient_method`.
        x_test: clean inputs to attack.
        eps: radius of the allowed perturbation ball.
        eps_iter: step size of each inner FGM step; must not exceed `eps`.
        nb_iter: number of inner steps.
        norm: `np.inf` or `2`.
        clip_min, clip_max: optional bounds applied to the adversarial inputs.
        rand_init: when truthy, start from a uniform random perturbation in
            [-eps, eps] instead of zero.
        rand_minmax: accepted for signature compatibility; it is overwritten
            by `eps` when `rand_init` is set.
        key: JAX PRNG key for the random start. If omitted, the notebook-level
            `new_key` is used when it exists (legacy behaviour).

    Returns:
        Adversarial examples with the same shape as `x_test`.

    Raises:
        AssertionError: if `eps_iter > eps`.
        NotImplementedError: if `norm == 1`.
        ValueError: if `norm` is unsupported, or `rand_init` is set and no PRNG
            key is available.
    """
    assert eps_iter <= eps, (eps_iter, eps)
    if norm == 1:
        raise NotImplementedError("It's not clear that FGM is a good inner loop"
                                  " step for PGD when norm=1, because norm=1 FGM "
                                  " changes only one pixel at a time. We need "
                                  " to rigorously test a strong norm=1 PGD "
                                  "before enabling this feature.")
    if norm not in [np.inf, 2]:
        raise ValueError("Norm order must be either np.inf or 2.")

    x = x_test

    # Initialize loop variables
    if rand_init:
        if key is None:
            # The function used to read the global `new_key` unconditionally,
            # which raised NameError unless the notebook happened to define it.
            # Keep that lookup as a fallback so existing calls still work.
            try:
                key = new_key
            except NameError:
                raise ValueError("rand_init requires a PRNG key; pass `key=random.PRNGKey(...)`.") from None
        rand_minmax = eps
        eta = random.uniform(key, x.shape, minval=-rand_minmax, maxval=rand_minmax)
    else:
        eta = np.zeros_like(x)

    # Clip eta
    eta = clip_eta(eta, norm, eps)
    adv_x = x + eta
    if clip_min is not None or clip_max is not None:
        # Positional bounds: `a_min`/`a_max` keywords are deprecated in newer JAX.
        adv_x = np.clip(adv_x, clip_min, clip_max)

    for _ in range(nb_iter):
        adv_x = fast_gradient_method(model_fn, kernel_fn, obj_fn, grads_fn, x_train, y_train, adv_x,
                                     y, t, loss_weighting, fx_train_0, fx_test_0, eps_iter, norm,
                                     clip_min, clip_max, targeted)

        # Project the accumulated perturbation back onto the norm ball.
        eta = adv_x - x
        eta = clip_eta(eta, norm, eps)
        adv_x = x + eta

        # Redo the clipping.
        # FGM already did it, but subtracting and re-adding eta can add some
        # small numerical error.
        if clip_min is not None or clip_max is not None:
            adv_x = np.clip(adv_x, clip_min, clip_max)

    return adv_x

# Attack Hyperparameters

In [17]:
# Dataset-specific perturbation budget.
if DATASET == 'mnist':
    eps = 0.3
elif DATASET == 'cifar10':
    eps = 0.03

# Per-iteration step sizes: eps divided by 10 / 100 / 1000 and scaled by 1.1
# (presumably for attacks with that many iterations -- TODO confirm).
# `eps_iter_1000` used to be defined only for mnist; it is now defined for both
# supported datasets so later cells can rely on it.
if DATASET in ('mnist', 'cifar10'):
    eps_iter_10 = (eps/10)*1.1
    eps_iter_100 = (eps/100)*1.1
    eps_iter_1000 = (eps/1000)*1.1

In [18]:
def evaluate_accuracy(x_train, x_test, y_test, model_fn, kernel_fn, t=None, attack_type=None, ntk_train_train=None):
    """Print test accuracy and return the per-sample correctness mask.

    Args:
        x_train: training inputs forwarded to `model_fn`.
        x_test: test inputs to evaluate.
        y_test: test targets; compared with the predictions via `correct`.
        model_fn: callable invoked as `model_fn(kernel_fn, x_train, x_test,
            t=..., ntk_train_train=...)` and returning
            `(train_predictions, test_predictions)`.
        kernel_fn: forwarded to `model_fn`.
        t: forwarded to `model_fn`.
        attack_type: label included in the printed line (may be None).
        ntk_train_train: forwarded to `model_fn`.

    Returns:
        Boolean array, True where the test prediction matches the target.
    """
    # Only the test predictions are needed here.
    _, y_test_predict = model_fn(kernel_fn, x_train, x_test,
                                 t=t, ntk_train_train=ntk_train_train)

    selected_table = correct(y_test_predict, y_test)
    # '{}' instead of '{:s}': the default attack_type=None cannot be formatted
    # with ':s' and raised TypeError.
    # NOTE(review): "Accuray" is misspelled; left untouched in case the log
    # text is matched elsewhere.
    print("Accuray({}): {:.2f}".format(attack_type, onp.mean(selected_table)))

    return selected_table

In [19]:
def evaluate_robustness(x_train, x_test, y_test, model_fn, kernel_fn, selected_table, t=None, 
                        attack_type=None, ntk_train_train=None):
    """Print accuracy on `x_test` restricted to the samples in `selected_table`.

    Args:
        x_train: training inputs forwarded to `model_fn`.
        x_test: test inputs to evaluate (presumably adversarial -- the
            function itself does not perturb them).
        y_test: test targets.
        model_fn: callable invoked as `model_fn(kernel_fn, x_train, x_test,
            t=..., ntk_train_train=...)` and returning
            `(train_predictions, test_predictions)`.
        kernel_fn: forwarded to `model_fn`.
        selected_table: boolean mask selecting the samples to score, e.g. the
            mask returned by `evaluate_accuracy`.
        t: forwarded to `model_fn`.
        attack_type: label included in the printed line (may be None).
        ntk_train_train: forwarded to `model_fn`.

    Returns:
        None; the result is only printed.
    """
    _, y_test_predict = model_fn(kernel_fn, x_train, x_test,
                                 t=t, ntk_train_train=ntk_train_train)

    mask = onp.asarray(selected_table)
    y_test_predict_select = onp.asarray(y_test_predict)[mask]
    y_test_select = y_test[mask]

    if len(y_test_select) == 0:
        # Nothing selected: report nan directly instead of taking the mean of
        # an empty array, which emits a RuntimeWarning.
        robustness = float('nan')
    else:
        robustness = onp.mean(correct(y_test_predict_select, y_test_select))

    # '{}' instead of '{:s}': the default attack_type=None cannot be formatted
    # with ':s' and raised TypeError.
    print("Robustness({}): {:.2f}".format(attack_type, robustness))

    return

# adv_x generation