In [1]:
import numpy as onp
import tensorflow as tf

from neural_tangents import stax

from functools import partial
from tqdm import tqdm
# Attacking
from cleverhans.utils import clip_eta, one_hot

# Plotting
from IPython.display import set_matplotlib_formats
set_matplotlib_formats('pdf', 'svg')
import matplotlib.pyplot as plt
import seaborn as sns
from mpl_toolkits.axes_grid1 import make_axes_locatable

from utils import *

# Plot styling: white seaborn background; `colors` holds matplotlib's default
# property-cycle colours for use in later figures.
sns.set_style(style='white')
colors = plt.rcParams['axes.prop_cycle'].by_key()['color']

# diag_reg: strength of the diagonal regularization for `k_train_train`, i.e.
# `k_train_train + diag_reg * I` is used during Cholesky factorization or
# eigendecomposition.
diag_reg = 1e-5

import os

# XLA client GPU-memory options.
# NOTE(review): these are set after the library imports above; confirm the
# XLA backend has not already been initialized when they are read.
os.environ.update({
    "XLA_PYTHON_CLIENT_MEM_FRACTION": "1.0",
    "XLA_PYTHON_CLIENT_ALLOCATOR": "platform",
})

# Make Keras use float64 as its default float type.
tf.keras.backend.set_floatx('float64')

In [2]:
# `device`: index into the list of physical GPUs made visible to TensorFlow.
# `tick`: width multiplier for the finite-width conv net (becomes `multi`).
device, tick = 3, 8

In [3]:
# Width multipliers swept across runs; set `tick` to one of: [1, 2, 4, 8, 16, 32]

In [4]:
# Make only the GPU selected by `device` visible to TensorFlow.
gpus = tf.config.experimental.list_physical_devices('GPU')
if gpus:
    # Restrict TensorFlow to the single GPU whose index is `device`
    # (an index into the list of physical GPUs, not necessarily the first one).
    if not 0 <= device < len(gpus):
        raise IndexError(
            "device=%d is out of range: only %d physical GPU(s) were found"
            % (device, len(gpus)))
    try:
        tf.config.experimental.set_visible_devices(gpus[device], 'GPU')
        logical_gpus = tf.config.experimental.list_logical_devices('GPU')
        print(len(gpus), "Physical GPUs,", len(logical_gpus), "Logical GPU")
    except RuntimeError as e:
        # Visible devices must be set before GPUs have been initialized
        print(e)

4 Physical GPUs, 1 Logical GPU


# data

In [5]:
# Dataset choice and split sizes.
DATASET = 'mnist'
class_num   = 10   # both supported datasets have 10 classes

train_size = 10000
valid_size = 1024
test_size  = 10000

# Per-image shape (height, width, channels); fail immediately on an unknown
# dataset instead of leaving `image_shape` as None and breaking later reshapes.
if DATASET == 'mnist':
    image_shape = (28, 28, 1)
elif DATASET == 'cifar10':
    image_shape = (32, 32, 3)
else:
    raise ValueError("unsupported DATASET %r; expected 'mnist' or 'cifar10'" % (DATASET,))

In [6]:
# Load the raw splits (not flattened / normalized) and convert each to a NumPy array.
# NOTE(review): `get_dataset` is not defined in this notebook; presumably it
# comes from `from utils import *`.
x_train_all, y_train_all, x_test_all, y_test_all = map(
    onp.array,
    get_dataset(DATASET, None, None, do_flatten_and_normalize=False),
)

































In [7]:
# shuffle
seed = 0  # fixed seed forwarded to `shaffle` (defined outside this notebook)
x_train_all, y_train_all = shaffle(
    x_train_all,
    y_train_all,
    seed,
)

In [8]:
# down sample
# Down-sample: the first `train_size` shuffled examples train, the next
# `valid_size` validate, and the test set is truncated to `test_size`.
# Fail loudly rather than silently returning shorter splits when a size is too large.
if len(x_train_all) < train_size + valid_size:
    raise ValueError(
        "train_size + valid_size = %d exceeds the %d available training examples"
        % (train_size + valid_size, len(x_train_all)))
if len(x_test_all) < test_size:
    raise ValueError(
        "test_size = %d exceeds the %d available test examples"
        % (test_size, len(x_test_all)))

x_train = x_train_all[:train_size]
y_train = y_train_all[:train_size]

x_valid = x_train_all[train_size:train_size + valid_size]
y_valid = y_train_all[train_size:train_size + valid_size]

x_test = x_test_all[:test_size]
y_test = y_test_all[:test_size]

In [9]:
# Restore the per-image layout `image_shape` (channels last) expected by the Conv2D model.
x_train, x_valid, x_test = (
    split.reshape((-1, *image_shape)) for split in (x_train, x_valid, x_test)
)

# finite width

In [10]:
# Short alias for the Keras layers namespace, used to build the conv net below.
layers = tf.keras.layers

In [11]:
# Mini-batch size for the tf.data pipelines, and the base Conv2D channel count
# (each conv layer uses `width * multi` channels).
batch, width = 64, 32

In [12]:
# Batched tf.data pipelines (in the visible notebook they are only consumed by
# the commented-out training cell).
# Each shuffle buffer spans the whole split, so the shuffle is uniform for any
# `train_size` / `valid_size`. It is seeded with `seed` so runs are reproducible.
valid_ds = (
    tf.data.Dataset.from_tensor_slices((x_valid, y_valid))
    .shuffle(buffer_size=len(x_valid), seed=seed)
    .batch(batch)
)

train_ds = (
    tf.data.Dataset.from_tensor_slices((x_train, y_train))
    .shuffle(buffer_size=len(x_train), seed=seed)
    .batch(batch)
)

In [13]:
width_tick = [1 << power for power in range(6)]  # width multipliers 1, 2, 4, ..., 32

In [14]:
# for multi in width_tick:
#     img_input = layers.Input(shape=image_shape)
#     x = layers.Conv2D(width * multi, (3, 3))(img_input)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.ReLU()(x)
#     x = layers.Conv2D(width * multi, (3, 3))(x)
#     x = layers.Flatten()(x)
#     out = layers.Dense(10)(x)
    
#     model = tf.keras.Model(img_input, out)
#     model.compile(optimizer='sgd',
#               loss=tf.keras.losses.CategoricalCrossentropy(from_logits=True),
#               metrics=['accuracy'])
#     model.fit(x=train_ds, validation_data=valid_ds, epochs=15)
#     model.save_weights('./conv_10_weights/conv_10_with_multi_%s'%(str(multi)))

In [15]:
# Cross-entropy on raw logits: the model's final Dense layer has no activation,
# so `from_logits=True` is required.
loss = tf.keras.losses.CategoricalCrossentropy(
    from_logits=True,
)

In [16]:
# Width multiplier for this run: scales the conv channels and selects the
# weight / Hessian file names.
multi = tick

In [17]:
# Nine 3x3 Conv2D layers with `width * multi` channels (Keras default stride and
# padding). A ReLU follows each of the first eight. Then Flatten and a 10-unit
# Dense logits head.
# Layers are created in the same order as in the commented-out training cell,
# so the saved checkpoint lines up with this model.
img_input = layers.Input(shape=image_shape)
x = img_input
for block_idx in range(8):
    x = layers.Conv2D(width * multi, (3, 3))(x)
    x = layers.ReLU()(x)
x = layers.Conv2D(width * multi, (3, 3))(x)
x = layers.Flatten()(x)
out = layers.Dense(10)(x)

model = tf.keras.Model(img_input, out)
model.load_weights('./conv_10_weights/conv_10_with_multi_%s' % (str(multi)))

@tf.function
def get_cross_entropy_hessian(y, x):
    """Hessian of the cross-entropy loss of the global `model` w.r.t. its input.

    `tf.hessians` is not supported in eager mode, hence the `tf.function` wrapper.

    Args:
        y: targets passed as the first argument to `loss` (the callers in this
            notebook pass rows reshaped to `(-1, class_num)`).
        x: input batch fed to `model`.

    Returns:
        The value of `tf.hessians(...)`: a list with one tensor holding the second
        derivatives of the scalar loss with respect to `x`.
    """
    logits = model(x)
    return tf.hessians(loss(y, logits), x)

# Sum, over the test set, of the per-example input Hessians of the cross-entropy
# loss. Each Hessian is flattened to (n_pixels, n_pixels), where n_pixels is the
# number of values per image: 784 for MNIST, 3072 for CIFAR-10. The sum is saved
# as a .npy file per width multiplier.
n_pixels = int(onp.prod(image_shape))

hessians = 0
for x, y in zip(tqdm(x_test), y_test):
    x_tensor = tf.convert_to_tensor(x.reshape(-1, *image_shape))
    y_tensor = tf.convert_to_tensor(y.reshape(-1, class_num))
    hs = get_cross_entropy_hessian(y_tensor, x_tensor)
    hs = tf.reshape(hs, (n_pixels, n_pixels))
    hessians += hs

# onp.save does not create missing directories.
os.makedirs('./conv_10_hessians', exist_ok=True)
onp.save('./conv_10_hessians/conv_10_hessians_with_multi_%s' % (str(multi)), hessians)

100%|██████████| 10000/10000 [10:18:22<00:00,  3.71s/it]


In [46]:
# Sanity check: the same input-Hessian computation on the first training example only.
# Each Hessian is flattened to (n_pixels, n_pixels), derived from `image_shape`
# so the cell also works for non-MNIST datasets.
n_pixels = int(onp.prod(image_shape))

hessians = 0
for x, y in zip(tqdm(x_train[:1]), y_train[:1]):
    x_tensor = tf.convert_to_tensor(x.reshape(-1, *image_shape))
    y_tensor = tf.convert_to_tensor(y.reshape(-1, class_num))
    hs = get_cross_entropy_hessian(y_tensor, x_tensor)
    hs = tf.reshape(hs, (n_pixels, n_pixels))
    hessians += hs

100%|██████████| 1/1 [00:02<00:00,  2.47s/it]


In [35]:
tf.math.reduce_mean(hessians)  # mean entry of the accumulated Hessian matrix

<tf.Tensor: shape=(), dtype=float64, numpy=2.1161763306741252e-10>

In [38]:
# `eigh` reads only one triangle of its input, so symmetrize first. Any
# numerical asymmetry in the Hessian is then averaged in, not silently dropped.
# Eigenvalues come back in ascending order.
hessian_np = onp.asarray(hessians)
eig_val, eig_vec = onp.linalg.eigh(0.5 * (hessian_np + hessian_np.T))

In [40]:
onp.flip(eig_val)  # eigenvalues largest-first (eigh returns them ascending)

array([ 8.26016523e-05,  1.21976424e-05,  5.54193370e-06,  2.56799612e-06,
        1.42420527e-06,  7.20792864e-09,  1.40965978e-09,  5.93886349e-10,
        5.11320142e-16,  2.66343834e-16,  1.93821728e-16,  1.46800790e-16,
        1.15894900e-16,  8.92877028e-17,  8.56220034e-17,  6.39855184e-17,
        6.09232088e-17,  4.97202809e-17,  4.61331388e-17,  3.76553164e-17,
        3.23790510e-17,  3.13244798e-17,  3.03724620e-17,  2.57050842e-17,
        2.33337727e-17,  2.23258702e-17,  2.13995145e-17,  2.02549742e-17,
        1.82971374e-17,  1.74412206e-17,  1.71659888e-17,  1.34252936e-17,
        1.26697242e-17,  1.20188199e-17,  1.11753011e-17,  1.07811134e-17,
        1.05506156e-17,  1.05487139e-17,  9.39240984e-18,  9.14627606e-18,
        8.56487228e-18,  8.36108877e-18,  7.76613225e-18,  7.38376100e-18,
        7.25767927e-18,  7.03302640e-18,  6.74340302e-18,  6.20850763e-18,
        5.96069364e-18,  5.31184402e-18,  5.13702457e-18,  5.08019726e-18,
        4.41532343e-18,  

In [27]:
tf.math.reduce_mean(hessians)  # mean entry of the accumulated Hessian matrix

<tf.Tensor: shape=(), dtype=float64, numpy=2.1426886522773e-11>

In [47]:
# `eigh` reads only one triangle of its input, so symmetrize first. Any
# numerical asymmetry in the Hessian is then averaged in, not silently dropped.
# Eigenvalues come back in ascending order.
hessian_np = onp.asarray(hessians)
eig_val, eig_vec = onp.linalg.eigh(0.5 * (hessian_np + hessian_np.T))

In [48]:
onp.flip(eig_val)  # eigenvalues largest-first (eigh returns them ascending)

array([ 4.86541050e-06,  4.34637530e-07,  2.00199333e-07,  4.55416120e-08,
        1.45985184e-08,  4.95647587e-10,  6.30461698e-11,  2.06195810e-11,
        4.94802232e-16,  2.42134474e-16,  1.38932734e-16,  9.72874218e-17,
        7.58969295e-17,  6.96353123e-17,  5.38906470e-17,  4.98175448e-17,
        3.98871573e-17,  3.69615352e-17,  3.33388241e-17,  2.42065422e-17,
        2.34771033e-17,  2.27610977e-17,  2.14597037e-17,  2.02243318e-17,
        1.99716201e-17,  1.67022723e-17,  1.53973229e-17,  1.51043610e-17,
        1.39582344e-17,  1.35946273e-17,  1.30107905e-17,  1.17665291e-17,
        1.14108146e-17,  1.09678459e-17,  1.00499354e-17,  9.80938466e-18,
        9.20369226e-18,  9.11658649e-18,  8.37193248e-18,  8.07116442e-18,
        7.25149081e-18,  6.95204388e-18,  6.45037286e-18,  6.37257754e-18,
        6.05727611e-18,  5.60902975e-18,  5.33482818e-18,  4.93918914e-18,
        4.72139566e-18,  4.53249314e-18,  4.18681028e-18,  3.95427722e-18,
        3.87740907e-18,  

# infinite width

In [11]:
# Infinite-width (NTK) kernel on the training inputs, followed by its eigendecomposition.
# NOTE(review): `nt`, `kernel_fn`, `x_train_down` and `np` are not defined anywhere
# in the visible notebook (numpy is imported as `onp`, and only `stax` is imported
# from neural_tangents). Presumably they come from `from utils import *` or from
# deleted cells. This cell will fail on a fresh kernel until they exist. TODO confirm.
# Evaluate the kernel in batches of 256 with store_on_device=False (results not kept on the accelerator).
batch_inf_kernel_fn = nt.batch(kernel_fn, batch_size=256, store_on_device=False)

# Second input `None`: kernel of `x_train_down` with itself; 'ntk' requests the NTK.
kernel_train_m = batch_inf_kernel_fn(x_train_down, None, 'ntk')

# Add `diag_reg * I` before eigendecomposition, as described for `diag_reg` in the config cell.
# NOTE(review): assumes `x_train_down` has exactly `train_size` rows so the identity matches.
eigval_inf, eigv_inf = np.linalg.eigh(kernel_train_m + np.eye(train_size)*diag_reg)

  'fit the dataset.' % (batch_size, n2_batch_size))


In [30]:
def plt_samples(arr, attack_type, layer):
    """Display the first 8 images of `arr` in a 2x4 grayscale grid.

    Args:
        arr: indexable collection of at least 8 images. Each one is reshaped to
            `image_shape` and drawn with the colour range fixed to [0, 1].
        attack_type: unused; accepted so the signature matches `save_samples`.
        layer: unused; accepted so the signature matches `save_samples`.
    """
    n_rows, n_cols = 2, 4
    fig, axs = plt.subplots(n_rows, n_cols, figsize=(6, 3), sharex=True)
    # axs.flat walks the grid row by row, i.e. image k goes to row k // 4, column k % 4.
    for flat_idx, axis in enumerate(axs.flat):
        image = arr[flat_idx].reshape(image_shape)
        axis.axis('off')
        axis.xaxis.set_visible(False)
        axis.yaxis.set_visible(False)
        axis.imshow(image, cmap='gray', vmin=0, vmax=1)

    plt.tight_layout()
    plt.show()

In [31]:
def save_samples(arr, attack_type, layer):
    """Plot the first 8 images of `arr` in a 2x4 grid, save it as a PNG, then show it.

    The figure is written at 150 dpi to
    ``./fig-<DATASET>-untargeted/<attack_type>_layer_<layer + 1>.png``.
    The output directory is created if it does not exist yet.

    Args:
        arr: indexable collection of at least 8 images. Each one is reshaped to
            `image_shape` and drawn in grayscale with the colour range fixed to [0, 1].
        attack_type: value used as the file-name prefix.
        layer: integer layer index; written as `layer + 1` in the file name.
    """
    fig, axs = plt.subplots(2, 4, figsize=(6, 3), sharex=True)
    for row, ax in enumerate(axs):
        for idx, a in enumerate(ax):
            img = arr[idx + row*4].reshape(image_shape)
            a.axis('off')
            a.xaxis.set_visible(False)
            a.yaxis.set_visible(False)
            a.imshow(img, cmap='gray', vmin=0, vmax=1)

    plt.tight_layout()

    out_dir = "./fig-%s-untargeted" % DATASET
    # savefig does not create missing directories; avoid FileNotFoundError on first use.
    os.makedirs(out_dir, exist_ok=True)
    fig.savefig("%s/%s_layer_%d.png" % (out_dir, attack_type, layer+1), dpi=150)
    plt.show()
    # Release the figure so repeated calls (e.g. one per layer) don't accumulate open figures.
    plt.close(fig)