In [1]:
import numpy as onp
import jax
import jax.numpy as np

from jax import lax, random
from jax.api import grad, jit, vmap
from jax.config import config
from jax.experimental import optimizers
from jax.experimental.stax import logsoftmax

config.update('jax_enable_x64', True)

from neural_tangents import stax

from functools import partial

# Attacking
from cleverhans.utils import clip_eta, one_hot

# Plotting
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('pdf', 'svg')
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable

from utils import *

sns.set_style(style='white')
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

"""
diag_reg:
    a scalar representing the strength of the diagonal regularization for
    `k_train_train`, i.e. computing `k_train_train + diag_reg * I` during
    Cholesky factorization or eigendecomposition.
"""
diag_reg = 1e-5

import os

In [2]:
DATASET = 'mnist'
class_num   = 10
train_size  = 4096

#mnist cntk - 128


test_size   = 2048
image_shape = None

if DATASET =='mnist':
    image_shape = (28, 28, 1)
elif DATASET == 'cifar':
    image_shape = (32, 32, 3)

In [3]:
x_train_all, y_train_all, x_test_all, y_test_all = tuple(onp.array(x) for x in get_dataset(DATASET, None, None, 
                                                                                  do_flatten_and_normalize=False))

In [4]:
# shuffle
seed = 0
x_train_all, y_train_all = shaffle(x_train_all, y_train_all, seed)

In [5]:
# down sample
x_train = x_train_all[:train_size]
y_train = y_train_all[:train_size]

x_test = x_test_all[:test_size]
y_test = y_test_all[:test_size]

In [6]:
x_train, x_test = x_train.reshape((-1, *image_shape)), x_test.reshape((-1, *image_shape))

In [7]:
device_id = jax.devices()[0]

In [8]:
# move to gpu

# x_train = jax.device_put(x_train, device=device_id)
# y_train = jax.device_put(y_train, device=device_id)

# x_test = jax.device_put(x_test, device=device_id)
# y_test = jax.device_put(y_test, device=device_id)

In [9]:
def accuracy(mean, ys):
    return np.mean(np.argmax(mean, axis=-1) == np.argmax(ys, axis=-1))

In [10]:
def ConvBlock(channels, W_std, b_std, strides=(1,1)):
    return stax.serial(stax.Conv(out_chan=channels, filter_shape=(3,3), strides=strides, padding='SAME',
                                 W_std=W_std, b_std=b_std), 
                       stax.Erf())

def ConvGroup(n, channels, stride, W_std, b_std):
    blocks = []
    for i in range(n):
        blocks += [ConvBlock(channels, W_std, b_std, stride)]
        
    return stax.serial(*blocks)
        
def ConvNet(block_size, k, W_std, b_std, class_num=class_num):
    """
    k: channel multiply
    """
    
    return stax.serial(ConvGroup(block_size, int(16 * k), (1, 1), W_std, b_std),
                       ConvGroup(block_size, int(32 * k), (2, 2), W_std, b_std),
                       ConvGroup(block_size, int(64 * k), (2, 2), W_std, b_std),
                       stax.Flatten(),
                       stax.Dense(class_num))

In [11]:
init_fn, apply_fn, kernel_fn = ConvNet(block_size=4, k=1, W_std=1, b_std=0)

In [12]:
batch_kernel_fn = nt.batch(kernel_fn, batch_size=128, store_on_device=False)

In [15]:
kernel_train_m = batch_kernel_fn(x_train, None, 'ntk')

In [20]:
kernel_test_m = batch_kernel_fn(x_test, x_train, 'ntk')

In [16]:
predict_fn = nt.predict.gradient_descent_mse(kernel_train_m, y_train, diag_reg=diag_reg)

In [28]:
_, y_predict = predict_fn(None, 0.0, 0.0, k_test_train=kernel_test_m)

In [30]:
accuracy(y_predict, y_test)

DeviceArray(0.96826172, dtype=float64)