In [1]:
import os

# Experiment configuration.
# Presumably the values tried for `time` across runs: 1e2 1e3 1e4 1e5
# Index of the GPU exposed to this process through CUDA_VISIBLE_DEVICES.
gpu_id = 0
# Time value(s) looped over by the generation cell; each entry is passed to
# gen_adv_x as `t`, and time[0] appears in the output filename.
time = [1e2]
# Weight of the predictive-variance term in test_loss_adv_mse; also appears
# in the output filename.
lamb = 10**0

# These variables are set before `import jax` below — presumably so the
# backend reads them when it initialises. See the JAX GPU memory allocation
# documentation for the exact effect of the two XLA_PYTHON_CLIENT_* settings.
os.environ["CUDA_VISIBLE_DEVICES"] = str(gpu_id)
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "1.0"
os.environ["XLA_PYTHON_CLIENT_ALLOCATOR"] = "platform"

import numpy as onp
import jax
import jax.numpy as np

from jax import lax, random
from jax.api import grad, jit, vmap
from jax.config import config
from jax.experimental import optimizers
from jax.experimental.stax import logsoftmax

config.update('jax_enable_x64', True)

from neural_tangents import stax

from functools import partial

# Attacking
from cleverhans.utils import clip_eta, one_hot

# Plotting
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('pdf', 'svg')
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable

from utils import *

# Plot styling: seaborn "white" style; `colors` holds the colour list of the
# current matplotlib property cycle.
sns.set_style(style='white')
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

"""
diag_reg:
    a scalar representing the strength of the diagonal regularization for
    `k_train_train`, i.e. computing `k_train_train + diag_reg * I` during
    Cholesky factorization or eigendecomposition.
"""
# Read by inv() and model_fn(), and used as the default `diag_reg` argument
# of test_loss_adv_mse().
diag_reg = 1e-5

# data

In [2]:
# Dataset / attack settings.
# DATASET, image_shape and train_size are not referenced by the visible
# cells; they are kept as a record of the setup.
DATASET = 'imagenet'
# Default output width of VGG19() and simple_net().
class_num   = 2
image_shape = (224, 224, 3)

train_size = None
# The generation loop runs test_size // test_batch_size iterations, each on
# a slice of test_batch_size test samples.
test_size = 100
test_batch_size = 1
# Step size of the signed-gradient perturbation applied in gen_adv_x.
eps = 0.03

In [3]:
def one_hot(x, k, dtype=np.float64):
    """Create a one-hot encoding of x of size k.

    Note: this definition shadows the `one_hot` imported from
    `cleverhans.utils` earlier in the notebook.

    Args:
        x: 1-D array-like of integer class indices (array, list, tuple, ...).
        k: number of classes, i.e. the width of every encoded row.
        dtype: dtype of the returned array.

    Returns:
        A NumPy array of shape `(len(x), k)` whose row `i` is 1 at column
        `x[i]` and 0 elsewhere. Indices outside `[0, k)` give an all-zero row.
    """
    # Coerce first so plain Python sequences work; indexing `x[:, None]`
    # directly fails on lists and tuples.
    indices = onp.asarray(x)
    return onp.array(indices[:, None] == onp.arange(k), dtype)

In [4]:
# Load the pre-exported dataset arrays from paths relative to the notebook's
# working directory. Only the first 2000 training entries are kept; the test
# arrays are loaded in full.
# NOTE(review): array shapes/dtypes are not visible here — presumably images
# matching `image_shape` and one-hot labels (correct() takes an argmax over
# the last label axis); confirm against the export script.
x_train = onp.load('../../Jimmy/ntk_nngp/source/dataset/imagenet_x_train.npy')[:2000]
y_train = onp.load('../../Jimmy/ntk_nngp/source/dataset/imagenet_y_train.npy')[:2000]

x_test = onp.load('../../Jimmy/ntk_nngp/source/dataset/imagenet_x_test.npy')
y_test = onp.load('../../Jimmy/ntk_nngp/source/dataset/imagenet_y_test.npy')

In [5]:
# Shuffle the training set.
# `shaffle` is not defined in the visible cells — presumably it comes from
# `from utils import *` and permutes x/y jointly using `seed`; TODO confirm.
# NOTE(review): x_train / y_train are rebound in place, so re-running this
# cell without reloading the data shuffles already-shuffled arrays.
seed = 0
x_train, y_train = shaffle(x_train, y_train, seed)

# model

In [6]:
def correct(mean, ys):
    """Flag, per sample, whether the predicted class matches the label.

    Both inputs are reduced with an argmax over their last axis and the
    resulting class indices are compared element-wise.

    Args:
        mean: per-class scores, classes on the last axis.
        ys: per-class targets (e.g. one-hot), classes on the last axis.

    Returns:
        Boolean array (or scalar for 1-D inputs): True where the argmax of
        `mean` equals the argmax of `ys`.
    """
    predicted_class = onp.argmax(mean, axis=-1)
    labelled_class = onp.argmax(ys, axis=-1)
    return predicted_class == labelled_class

In [7]:
# Initialisation hyperparameters. `b` and `W` are presumably the bias and
# weight variances, since their square roots are what the layers receive as
# standard deviations — confirm.
b = 0.18
W = 1.76

# Passed as the `b_std` / `W_std` arguments of the stax layers built below.
b_std = np.sqrt(b)
W_std = np.sqrt(W)

In [8]:
def ConvBlock(channels, W_std, b_std, strides=(1,1)):
    """Build a 3x3, 'SAME'-padded convolution followed by a ReLU.

    Args:
        channels: number of output channels of the convolution.
        W_std: value forwarded as the convolution's `W_std`.
        b_std: value forwarded as the convolution's `b_std`.
        strides: convolution strides.

    Returns:
        The `stax.serial` composition of the convolution and the ReLU.
    """
    conv_layer = stax.Conv(
        out_chan=channels,
        filter_shape=(3, 3),
        strides=strides,
        padding='SAME',
        W_std=W_std,
        b_std=b_std,
    )
    activation = stax.Relu(do_backprop=True)
    return stax.serial(conv_layer, activation)

def ConvGroup(n, channels, stride, W_std, b_std, last_stride=False):
    """Chain `n` ConvBlocks that share the same channel count.

    Args:
        n: number of blocks in the group.
        channels: output channels of every block.
        stride: strides used by the regular blocks.
        W_std: forwarded to every ConvBlock.
        b_std: forwarded to every ConvBlock.
        last_stride: when truthy, only `n - 1` regular blocks are built and
            one extra block with strides `(2, 2)` is appended at the end.

    Returns:
        The `stax.serial` composition of all blocks.
    """
    regular_count = n - 1 if last_stride else n
    layers = [ConvBlock(channels, W_std, b_std, stride) for _ in range(regular_count)]

    if last_stride:
        # Final block of the group uses a fixed (2, 2) stride.
        layers.append(ConvBlock(channels, W_std, b_std, (2, 2)))

    return stax.serial(*layers)
        
def VGG19(class_num=class_num):
    """Build the VGG19-style network: five conv groups and a two-layer head.

    The groups contain 2, 2, 4, 4 and 4 ConvBlocks with 64, 128, 256, 512 and
    512 channels. Every group uses stride (1, 1) and `last_stride=False`, and
    the module-level `W_std` / `b_std` are passed to each of them.

    NOTE(review): the two Dense layers are built without `W_std` / `b_std`,
    unlike `simple_net`, which passes `W_std` — confirm this is intended.

    Args:
        class_num: width of the final Dense layer. The default is the value
            of the module-level `class_num` when this function was defined.

    Returns:
        The `stax.serial` composition of all layers.
    """
    # (number of ConvBlocks, channels) for each group, in order.
    group_specs = [(2, 64), (2, 128), (4, 256), (4, 512), (4, 512)]
    conv_groups = [
        ConvGroup(n=depth, channels=width, stride=(1,1), W_std=W_std, b_std=b_std, last_stride=False)
        for depth, width in group_specs
    ]

    head = [
        stax.Flatten(),
        stax.Dense(512),
        stax.Relu(do_backprop=True),
        stax.Dense(class_num),
    ]
    return stax.serial(*conv_groups, *head)

def simple_net(class_num=class_num):
    """Build the small network: one 3-block conv group and a two-layer head.

    The conv group has 64 channels and stride (1, 1), followed by Flatten,
    Dense(512), ReLU and Dense(class_num). The module-level `W_std` is passed
    to the conv group and to both Dense layers; `b_std` only to the conv group.

    Args:
        class_num: width of the final Dense layer. The default is the value
            of the module-level `class_num` when this function was defined.

    Returns:
        The `stax.serial` composition of all layers.
    """
    conv_features = ConvGroup(n=3, channels=64, stride=(1,1), W_std=W_std, b_std=b_std, last_stride=False)
    flatten = stax.Flatten()
    hidden = stax.Dense(512, W_std=W_std)
    activation = stax.Relu(do_backprop=True)
    readout = stax.Dense(class_num, W_std=W_std)
    return stax.serial(conv_features, flatten, hidden, activation, readout)

In [9]:
# Instantiate the small network; the value returned by simple_net is unpacked
# into its (init_fn, apply_fn, kernel_fn) triple. Only kernel_fn is used by
# the visible cells that follow.
init_fn, apply_fn, kernel_fn = simple_net(class_num)

In [10]:
# Batched wrapper around kernel_fn (batch size 50, store_on_device=False).
# NOTE(review): `nt` is never imported in the visible cells (only
# `from neural_tangents import stax` is) — presumably it arrives through
# `from utils import *`; confirm. `batch_kernel_fn` itself is not referenced
# again in the visible cells.
batch_kernel_fn = nt.batch(kernel_fn, batch_size=50, store_on_device=False)

In [11]:
def model_fn(kernel_fn, x_train=None, x_test=None, fx_train_0=0., fx_test_0=0., t=None, ntk_train_train=None):
    """Predict with an NTK gradient-descent-MSE predictor.

    Reads the module-level `y_train` (it is not a parameter) and the
    module-level `diag_reg` when building the predictor.

    Args:
        kernel_fn: kernel function, called as `kernel_fn(x1, x2, 'ntk')`.
        x_train: training inputs.
        x_test: test inputs.
        fx_train_0: initial outputs on the training inputs.
        fx_test_0: initial outputs on the test inputs.
        t: time argument forwarded to the predictor.
        ntk_train_train: optional precomputed train/train NTK; computed from
            `x_train` when None.

    Returns:
        Whatever the predictor returns for
        `(t, fx_train_0, fx_test_0, ntk_test_train)`.
    """
    # Train/train kernel: reuse the caller's matrix when one is supplied.
    k_train_train = ntk_train_train
    if k_train_train is None:
        k_train_train = kernel_fn(x_train, x_train, 'ntk')

    k_test_train = kernel_fn(x_test, x_train, 'ntk')

    # Mean-only predictor (no covariance).
    predictor = nt.predict.gradient_descent_mse(k_train_train, y_train, diag_reg=diag_reg)

    # fx_train_0, fx_test_0 = (0, 0) for infinite width
    return predictor(t, fx_train_0, fx_test_0, k_test_train)

In [12]:
# 1 testing accuracy: 0.5100
# 2 testing accuracy: 0.6400
# 4 testing accuracy: 0.7800
# 8 testing accuracy: 0.8500
# 16 testing accuracy: 0.8800
# 32 testing accuracy: 0.8900
# 64 testing accuracy: 0.9000
# 128 testing accuracy: 0.9000
# 256 testing accuracy: 0.8900
# 512 testing accuracy: 0.9000
# 1024 testing accuracy: 0.9100
# 2048 testing accuracy: 0.9000
# 4096 testing accuracy: 0.9000
# 8192 testing accuracy: 0.9000
# 16384 testing accuracy: 0.9000

# loss

In [13]:
@jit
def l2_loss_v1(logits, labels, weighting=1):
    """Half the weighted sum of squared errors.

    This is the TensorFlow-style L2 loss (no square root), with each squared
    residual multiplied by `weighting` before the sum.

    Args:
        logits: predictions.
        labels: targets, broadcastable against `logits`.
        weighting: scalar or array multiplied into the squared residuals.

    Returns:
        Scalar `sum((logits - labels)**2 * weighting) / 2`.
    """
    residual = logits - labels
    weighted_sq = (residual ** 2) * weighting
    return np.sum(weighted_sq) / 2
    
@jit
def l2_loss_v2(logits, lables):
    """Normal L2 loss: the norm of the residual `logits - lables`.

    The parameter keeps its original spelling (`lables`) so keyword callers
    keep working. The body previously read the undefined name `labels`.

    Args:
        logits: predictions.
        lables: targets, broadcastable against `logits`.

    Returns:
        Scalar `np.linalg.norm(logits - lables)`.
    """
    return np.linalg.norm(logits - lables)

@jit
def cross_entropy_loss(logits, lables):
    """Summed cross-entropy between `logits` and the targets `lables`.

    Args:
        logits: unnormalised class scores.
        lables: target weights per class (e.g. one-hot), same shape.

    Returns:
        Scalar `-sum(logsoftmax(logits) * lables)`.
    """
    log_probs = logsoftmax(logits)
    weighted = log_probs * lables
    return -np.sum(weighted)
    
@jit
def mse_loss(logits, lables):
    """Half the mean squared error between `logits` and `lables`.

    Args:
        logits: predictions.
        lables: targets, broadcastable against `logits`.

    Returns:
        Scalar `0.5 * mean((logits - lables)**2)`.
    """
    residual = logits - lables
    mean_sq = np.mean(residual ** 2)
    return 0.5 * mean_sq

# attack algorithms

In [14]:
def fast_gradient_method_batch(model_fn, kernel_fn, obj_fn, grads_fn, ntk_train_train, x_train=None, y_train=None, 
                               x_test=None, y=None, t=None, loss_weighting=None, fx_train_0=0., fx_test_0=0., eps=0.3, 
                               norm=np.inf, clip_min=None, clip_max=None, targeted=False):
    """Compute the input gradients for one fast-gradient-method step.

    Despite the name, this returns only the raw gradients produced by
    `grads_fn`; applying the signed step and clipping is left to the caller.

    Args:
        model_fn: unused; kept for interface compatibility.
        kernel_fn: kernel function forwarded to `grads_fn`.
        obj_fn: objective selector; only 'untargeted' is supported.
        grads_fn: callable invoked as
            `grads_fn(x_train, x_test, y_train, y, kernel_fn, ntk_train_train, t)`.
        ntk_train_train: forwarded to `grads_fn`.
        x_train: training inputs forwarded to `grads_fn`.
        y_train: training targets forwarded to `grads_fn`.
        x_test: inputs the gradients are taken with respect to.
        y: labels for `x_test`.
        t: time argument forwarded to `grads_fn`.
        loss_weighting, fx_train_0, fx_test_0, eps, clip_min, clip_max,
        targeted: accepted but unused.
        norm: validated only; must be `np.inf` or 2.

    Returns:
        The value returned by `grads_fn`.

    Raises:
        ValueError: if `norm` is not `np.inf` or 2, or if `obj_fn` is not
            'untargeted'.
    """
    if norm not in [np.inf, 2]:
        raise ValueError("Norm order must be either np.inf or 2.")

    if obj_fn != 'untargeted':
        # The old message listed train/test objectives that are never accepted.
        raise ValueError("Objective function must be 'untargeted'.")

    return grads_fn(x_train, x_test, y_train, y, kernel_fn, ntk_train_train, t)

# adv_x generation

In [15]:
def inv(k):
    """Invert the square matrix `k` after adding `diag_reg` to its diagonal.

    `diag_reg` is read from module scope.

    Args:
        k: square matrix; `k.shape[0]` sizes the identity that is added.

    Returns:
        `onp.linalg.inv(k + diag_reg * I)`.
    """
    regularised = k + diag_reg * onp.eye(k.shape[0])
    return onp.linalg.inv(regularised)

In [16]:
def test_loss_adv_mse(x_train, x_test, y_train, y, kernel_fn, ntk_train_train=None, t=None, diag_reg=diag_reg):
    """Adversarial objective for a single test input.

    Builds an ensemble predictor from `kernel_fn`, `x_train`, `y_train` and
    `diag_reg`, evaluates it on `x_test[None]` (the input with a leading axis
    added) and returns

        -sum(logsoftmax(pred.mean) * y) - lamb * sum(diag(pred.covariance))

    i.e. a cross-entropy term on the predictive mean minus `lamb` (module
    level) times the summed diagonal of the predictive covariance. Despite
    the "mse" in the name, no squared error is computed here.

    `ntk_train_train` and `t` are accepted but unused: the predictor is always
    called with `None` as its first argument.
    NOTE(review): presumably `None` means the t -> infinity solution and the
    trailing `True` enables covariance computation — confirm against the
    neural_tangents predict API.
    NOTE(review): gen_adv_x steps along +sign(grad) of this loss, which lowers
    the variance term (it enters with a minus sign), while the output filename
    says "increase_variance" — confirm the intended sign.
    """

    # ntk_test_train = kernel_fn(x_test[None], x_train, 'ntk')

    predict_fn = nt.predict.gradient_descent_mse_ensemble(kernel_fn, x_train, y_train, diag_reg=diag_reg)

    pred = predict_fn(None, x_test[None], 'ntk', True)

    variance_test = np.diag(pred.covariance)
    
    loss = -np.sum(logsoftmax(pred.mean) * y) - lamb * np.sum(variance_test)
    return loss
    
# Gradient of the loss w.r.t. argument 1 (x_test), vmapped over axis 0 of
# x_test and y (the other five mapped arguments are broadcast), then jitted
# with argument 4 (kernel_fn) static. in_axes lists exactly seven entries, so
# callers must pass seven positional arguments; diag_reg keeps its default.
test_mse_grads_fn = jit(vmap(grad(test_loss_adv_mse, argnums=1), in_axes=(None, 0, None, 0, None, None, None), 
                             out_axes=0), static_argnums=(4,))

In [17]:
# Rebind kernel_fn to a jit-compiled version with positional argument 2
# marked static (the 'ntk' string at the kernel_fn call sites in model_fn).
# NOTE(review): re-running this cell wraps the already-jitted function again;
# batch_kernel_fn above was built from the un-jitted kernel_fn.
kernel_fn = jit(kernel_fn, static_argnums=(2,))

In [18]:
def gen_adv_x(kernel_fn, x_train, x_test, y_test, t=None, train_batch=50):
    """Generate single-step (FGSM) adversarial versions of `x_test`.

    The training set is walked in consecutive chunks of `train_batch` samples.
    For each chunk the input gradients are obtained from
    `fast_gradient_method_batch` and summed. The test inputs are then moved by
    `eps * sign(summed gradients)` and clipped to [0, 1].

    Training samples beyond the last full chunk
    (`x_train.shape[0] // train_batch` chunks) are not used.

    Reads from module scope: `y_train` (sliced with the same indices as
    `x_train`), `eps`, `model_fn` and `test_mse_grads_fn`.

    Args:
        kernel_fn: kernel function forwarded to the gradient computation.
        x_train: training inputs.
        x_test: test inputs to perturb.
        y_test: labels of `x_test`.
        t: time argument forwarded to the gradient computation.
        train_batch: number of training samples per chunk.

    Returns:
        The perturbed, clipped copy of `x_test`.
    """
    def chunk_grads(chunk_idx):
        # Gradients w.r.t. x_test using only this chunk of the training set.
        window = slice(chunk_idx * train_batch, (chunk_idx + 1) * train_batch)
        return fast_gradient_method_batch(
            model_fn=model_fn, kernel_fn=kernel_fn, obj_fn='untargeted',
            grads_fn=test_mse_grads_fn, x_train=x_train[window], y_train=y_train[window],
            x_test=x_test, y=y_test, t=t, eps=eps, clip_min=0, clip_max=1,
            ntk_train_train=None)

    n_chunks = x_train.shape[0] // train_batch
    # sum() starts from 0, so with no chunks the total is 0.
    total_grads = sum(chunk_grads(i) for i in range(n_chunks))

    stepped = x_test + eps * np.sign(total_grads)
    return np.clip(stepped, a_min=0, a_max=1)

In [19]:
from tqdm import tqdm

In [20]:
# Generate FGSM adversarial examples for the first `test_size` test samples.
# Results are collected per time value: adv_x_FGSM[t] is a list with one
# array per test batch. adv_x_IFGSM_100 is initialised but never filled in
# the visible cells (its append below is commented out).
# NOTE(review): the loop variable `t` stays bound after the loop and is read
# by the save cell that follows.
adv_x_FGSM, adv_x_IFGSM_100 = {}, {}
for t in time:
    adv_x_FGSM[t] = []
    adv_x_IFGSM_100[t] = []
    # print("generating time:", t)
    
    # One gen_adv_x call per slice of test_batch_size test samples.
    for batch_id in tqdm(range(test_size//test_batch_size)):
        fgsm = gen_adv_x(kernel_fn,
                         x_train,
                         x_test[batch_id*test_batch_size:(batch_id+1)*test_batch_size], 
                         y_test[batch_id*test_batch_size:(batch_id+1)*test_batch_size],
                         t)
        
        adv_x_FGSM[t].append(fgsm)
        # adv_x_IFGSM_100[t].append(ifgsm)

100%|██████████| 100/100 [22:54<00:00, 13.74s/it]


In [21]:
# Save the generated adversarial examples, one file per entry of `time`, so
# each filename's time tag matches the data stored in it. Previously the name
# used time[0] while the data came from the last loop value of `t`, which
# mislabels the file as soon as `time` has more than one entry. With a
# single-entry `time` the output is unchanged.
for t in time:
    onp.save('./batch_NTK_simple_imagenet_increase_variance_lambda=%d_time=%d.npy'%(lamb, t), 
             onp.concatenate(adv_x_FGSM[t]))