In [1]:
import numpy as onp
import tensorflow as tf

from neural_tangents import stax

from functools import partial
from tqdm import tqdm
# Attacking
from cleverhans.utils import clip_eta, one_hot

# Plotting
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('pdf', 'svg')
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable

from utils import *

sns.set_style('white')

# Default matplotlib colour cycle, kept so later plots can pick consistent colours.
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# diag_reg: scalar strength of the diagonal regularization applied to
# `k_train_train`, i.e. `k_train_train + diag_reg * I` is what gets passed to the
# Cholesky factorization / eigendecomposition.
diag_reg = 1e-5

import os
# JAX/XLA GPU memory options (see the JAX "GPU memory allocation" docs: the
# 'platform' allocator allocates on demand instead of preallocating).
# NOTE(review): JAX reads these when its backend initializes and jax is already
# imported via neural_tangents above -- confirm no JAX computation runs earlier.
os.environ.update({
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "1.0",
    "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
})
# Make Keras build float64 layers/weights by default.
tf.keras.backend.set_floatx('float64')

In [2]:
# `device`: index into the list of physical GPUs selected below.
# `tick`: width multiplier (one of 1, 2, 4, ..., 32) whose trained weights are analysed.
device, tick = 1, 32

In [3]:
# Candidate values for `tick` (width multipliers): [1, 2, 4, 8, 16, 32]

In [4]:
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # Restrict TensorFlow to the single GPU selected by `device` (an index into
    # `gpus`) -- not necessarily the first one.
    if not -len(gpus) <= device < len(gpus):
        raise ValueError(
            "device=%d is out of range: only %d GPU(s) found" % (device, len(gpus)))
    try:
        tf.config.experimental.set_visible_devices(gpus[device], 'GPU')
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
    except RuntimeError as e:
        # Visible devices must be set before GPUs have been initialized
        print(e)

2 Physical GPUs, 1 Logical GPU


# Data: load, shuffle, and split the dataset

In [5]:
DATASET = 'mnist'
class_num   = 10

train_size = 10000
valid_size = 1024
test_size  = 10000

# Input shape (height, width, channels) for each supported dataset.
IMAGE_SHAPES = {
    'mnist': (28, 28, 1),
    'cifar10': (32, 32, 3),
}
if DATASET not in IMAGE_SHAPES:
    # Fail here rather than with an obscure reshape error further down.
    raise ValueError("Unsupported DATASET %r; expected one of %s"
                     % (DATASET, sorted(IMAGE_SHAPES)))
image_shape = IMAGE_SHAPES[DATASET]

In [6]:
# `get_dataset` is presumably provided by `from utils import *`; its four
# results are unpacked as (x_train, y_train, x_test, y_test) and converted to
# NumPy arrays. Images are reshaped to `image_shape` further below.
x_train_all, y_train_all, x_test_all, y_test_all = map(
    onp.array,
    get_dataset(DATASET, None, None, do_flatten_and_normalize=False),
)

































In [7]:
# shuffle
# Shuffle the full training pool once with a fixed seed (presumably making the
# train/valid split below reproducible).
# NOTE(review): `shaffle` (sic) is not defined in this notebook; presumably it
# comes from `from utils import *`.
seed = 0
x_train_all, y_train_all = shaffle(
    x_train_all, y_train_all, seed
)

In [8]:
# down sample
# Carve non-overlapping train and validation splits out of the shuffled
# training pool, and keep the first `test_size` test examples.
train_idx = slice(None, train_size)
valid_idx = slice(train_size, train_size + valid_size)
test_idx = slice(None, test_size)

x_train, y_train = x_train_all[train_idx], y_train_all[train_idx]
x_valid, y_valid = x_train_all[valid_idx], y_train_all[valid_idx]
x_test, y_test = x_test_all[test_idx], y_test_all[test_idx]

In [9]:
# Restore the (N, *image_shape) layout expected by the Keras conv layers.
x_train, x_valid, x_test = (
    split.reshape((-1, *image_shape)) for split in (x_train, x_valid, x_test)
)

# Finite-width CNN: input Hessians of the trained network

In [10]:
# Short alias for the Keras layer namespace used to build the CNN below.
layers = tf.keras.layers

In [11]:
# Mini-batch size for the tf.data pipelines, and the base conv width
# (each conv layer gets `width * multi` channels).
batch, width = 64, 32

In [12]:
def _shuffled_batches(xs, ys, buffer_size):
    """Build a shuffled, batched tf.data pipeline over paired (xs, ys) arrays."""
    pipeline = tf.data.Dataset.from_tensor_slices((xs, ys))
    return pipeline.shuffle(buffer_size).batch(batch)


valid_ds = _shuffled_batches(x_valid, y_valid, 5000)
train_ds = _shuffled_batches(x_train, y_train, 20000)

In [13]:
# Width multipliers explored: 1, 2, 4, 8, 16, 32.
width_tick = [1 << exponent for exponent in range(6)]

In [14]:
# Training loop (disabled, kept for reference): for each multiplier in
# `width_tick`, build the 10-layer CNN, train it with SGD on `train_ds` for 15
# epochs, and save its weights under ./conv_10_weights/ -- the files that the
# model cell below loads. Uncomment only to regenerate those weights.
# for multi in width_tick:
#     img_input = layers.Input(shape=image_shape)
#     x = layers.Conv2D(width * multi, (3, 3))(img_input)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.Flatten()(x)
#     out = layers.Dense(10)(x)
    
#     model = tf.keras.Model(img_input, out)
#     model.compile(optimizer='sgd',
#               loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
#               metrics=['accuracy'])
#     model.fit(x=train_ds, validation_data=valid_ds, epochs=15)
#     model.save_weights('./conv_10_weights/conv_10_with_multi_%s'%(str(multi)))

In [15]:
# Cross-entropy on raw logits (the model's final Dense layer has no softmax).
# CategoricalCrossentropy expects one-hot labels; the Hessian loops reshape
# labels to (-1, class_num) accordingly.
loss = tf.keras.losses.CategoricalCrossentropy(from_logits=True)

In [16]:
# Width multiplier for this run: conv layers get `width * multi` channels, and
# the weight/Hessian file names are suffixed with this value.
multi = tick

In [None]:
# Rebuild the 10-layer CNN exactly as it was trained: 8 Conv2D(3x3)+ReLU stages,
# one more linear Conv2D(3x3), Flatten, and a Dense layer producing 10 logits.
# Layers are created in the same order as in the training code so the saved
# weights map onto them.
channels = width * multi
img_input = layers.Input(shape=image_shape)
x = img_input
for _ in range(8):
    x = layers.Conv2D(channels, (3, 3))(x)
    x = layers.ReLU()(x)
x = layers.Conv2D(channels, (3, 3))(x)
x = layers.Flatten()(x)
out = layers.Dense(10)(x)

model = tf.keras.Model(img_input, out)
model.load_weights('./conv_10_weights/conv_10_with_multi_%s' % multi)

@tf.function
def get_cross_entropy_hessian(y, x):
    """Hessian of the cross-entropy loss of `model` with respect to its input `x`.

    Args:
        y: one-hot labels; the callers below pass shape (1, class_num).
        x: input images; the callers below pass shape (1, *image_shape).

    Returns:
        The result of `tf.hessians(..., x)`: a list holding one Hessian tensor
        (for the single differentiated tensor `x`). Callers reshape it.
    """
    logits = model(x)
    return tf.hessians(loss(y, logits), x)

# Sum the per-example input Hessians of the loss over the test set.
# The recorded progress bar shows ~28 s per example, so a full run takes days.
input_dim = int(onp.prod(image_shape))  # 784 for MNIST, 3072 for CIFAR-10
hessian_path = './conv_10_hessians/conv_10_hessians_with_multi_%s' % multi
# Create the output directory up front: savefig-style writers do not create
# folders, and a missing one would otherwise only fail after the whole loop.
os.makedirs(os.path.dirname(hessian_path), exist_ok=True)

hessians = onp.zeros((input_dim, input_dim))
for x, y in zip(tqdm(x_test), y_test):
    # Feed one example at a time as a batch of size 1.
    x_tensor = tf.convert_to_tensor(x.reshape(-1, *image_shape))
    y_tensor = tf.convert_to_tensor(y.reshape(-1, class_num))
    hs = get_cross_entropy_hessian(y_tensor, x_tensor)
    # The Hessian w.r.t. one image flattens to (input_dim, input_dim); the
    # running sum is kept in float64 on the host.
    hessians += tf.reshape(hs, (input_dim, input_dim)).numpy()

onp.save(hessian_path, hessians)

  6%|▋         | 630/10000 [4:54:56<73:04:16, 28.07s/it]

In [1]:
# Quick check: input Hessian of a single training example.
# Depends on `tqdm` (import cell) and `get_cross_entropy_hessian` (model cell);
# on a fresh kernel it raises NameError, as recorded below, so run those first.
input_dim = int(onp.prod(image_shape))  # 784 for MNIST, 3072 for CIFAR-10
hessians = onp.zeros((input_dim, input_dim))
for x, y in zip(tqdm(x_train[:1]), y_train[:1]):
    x_tensor = tf.convert_to_tensor(x.reshape(-1, *image_shape))
    y_tensor = tf.convert_to_tensor(y.reshape(-1, class_num))
    hs = get_cross_entropy_hessian(y_tensor, x_tensor)
    hessians += tf.reshape(hs, (input_dim, input_dim)).numpy()

NameError: name 'tqdm' is not defined

In [None]:
# Mean over all Hessian entries: a quick check of overall scale.
tf.math.reduce_mean(input_tensor=hessians)

In [None]:
# `eigh` reads only one triangle of its input, and the autodiff Hessian is
# symmetric only up to round-off, so average it with its transpose first.
# Eigenvalues are returned in ascending order.
hessians_np = onp.asarray(hessians)
eig_val, eig_vec = onp.linalg.eigh((hessians_np + hessians_np.T) / 2)

In [None]:
# Eigenvalues in descending order (eigh returns them ascending).
onp.flip(eig_val)

In [None]:
# Mean over all Hessian entries (repeat of the check above).
tf.math.reduce_mean(input_tensor=hessians)

In [None]:
# `eigh` reads only one triangle of its input, and the autodiff Hessian is
# symmetric only up to round-off, so average it with its transpose first.
# Eigenvalues are returned in ascending order.
hessians_np = onp.asarray(hessians)
eig_val, eig_vec = onp.linalg.eigh((hessians_np + hessians_np.T) / 2)

In [None]:
# Eigenvalues in descending order (eigh returns them ascending).
onp.flip(eig_val)

# Infinite width: NTK kernel spectrum

In [None]:
# Infinite-width (NTK) train-train kernel, computed in batches of 256 with
# `store_on_device=False` (results are not kept on the accelerator).
# NOTE(review): `kernel_fn` and `x_train_down` are not defined anywhere in this
# notebook (possibly leftovers from an earlier version, or provided by
# `from utils import *`); define them before running this cell.
import neural_tangents as nt

batch_inf_kernel_fn = nt.batch(kernel_fn, batch_size=256, store_on_device=False)

kernel_train_m = batch_inf_kernel_fn(x_train_down, None, 'ntk')

# Add `diag_reg` to the diagonal before the eigendecomposition. The identity is
# sized from the kernel itself rather than assuming `train_size` rows.
eigval_inf, eigv_inf = onp.linalg.eigh(
    kernel_train_m + onp.eye(kernel_train_m.shape[0]) * diag_reg)

In [None]:
def plt_samples(arr, attack_type, layer):
    """Display the first 8 entries of `arr` as a 2x4 image grid.

    Each entry is reshaped to `image_shape` and drawn with a gray colormap on
    [0, 1], with axes hidden. `attack_type` and `layer` are accepted for
    signature parity with `save_samples` but are unused here.
    """
    fig, axs = plt.subplots(2, 4, figsize=(6, 3), sharex=True)
    # axs.flat walks the grid row by row, matching row-major sample order.
    for flat_idx, panel in enumerate(axs.flat):
        sample = arr[flat_idx].reshape(image_shape)
        panel.axis('off')
        panel.xaxis.set_visible(False)
        panel.yaxis.set_visible(False)
        panel.imshow(sample, cmap='gray', vmin=0, vmax=1)

    plt.tight_layout()
    plt.show()

In [None]:
def save_samples(arr, attack_type, layer):
    """Plot the first 8 entries of `arr` as a 2x4 image grid, save it, and show it.

    The figure is written (150 dpi) to
    ./fig-<DATASET>-untargeted/<attack_type>_layer_<layer + 1>.png, creating the
    directory if it does not exist yet.

    Args:
        arr: indexable whose first 8 entries each reshape to `image_shape`.
        attack_type: inserted into the file name via `%s`.
        layer: integer layer index, presumably 0-based (the file name uses `layer + 1`).
    """
    fig, axs = plt.subplots(2, 4, figsize=(6, 3), sharex=True)
    for row, ax in enumerate(axs):
        for idx, a in enumerate(ax):
            img = arr[idx + row*4].reshape(image_shape)
            a.axis('off')
            a.xaxis.set_visible(False)
            a.yaxis.set_visible(False)
            a.imshow(img, cmap='gray', vmin=0, vmax=1)

    plt.tight_layout()
    path = "./fig-%s-untargeted/%s_layer_%d.png"%(DATASET ,attack_type, layer+1)
    # savefig does not create missing directories.
    os.makedirs(os.path.dirname(path), exist_ok=True)
    plt.savefig(path, dpi=150)
    plt.show()