In [1]:
import math
import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt
%matplotlib inline

import datasets

np.random.seed(0)
tf.set_random_seed(0)

In [3]:
train_x.shape

(58606, 3072)

In [2]:
train_set, test_set = datasets.load_SVHN()


#split validation set
train_set, val_set = datasets.split_validation(train_set, percentage=0.2)

train_x, train_y = train_set
val_x, val_y = val_set



n_samples = len(train_x)
n_val = len(val_x)
print(n_samples)
print(n_val)

loading SVHN training images...
loading SVHN test images...
58606
14651


In [4]:
# helper functions
def xavier_init(fan_in, fan_out, constant=1): 
    """ Xavier initialization of network weights"""
    low = -constant*np.sqrt(6.0/(fan_in + fan_out)) 
    high = constant*np.sqrt(6.0/(fan_in + fan_out))
    return tf.random_uniform((fan_in, fan_out), 
                             minval=low, maxval=high, 
                             dtype=tf.float32)

sigmoid = tf.nn.sigmoid
elu = tf.nn.elu

def bias(size, zero=False):
    if zero:
        return tf.Variable(tf.zeros([size], dtype=tf.float32))
    else:
        return tf.Variable(tf.random_normal([size], stddev=0.1))

def conv(tensor, kernel_dims):
    ksize, n_in, n_out = kernel_dims
    kernels = tf.Variable(tf.random_normal([ksize, ksize, n_in, n_out],  stddev=0.1))
    return tf.nn.conv2d(tensor,kernels, strides=[1, 2, 2, 1], padding='SAME') + bias(n_out)

def deconv(tensor, kernel_dims, out_dim, stride=[1,2,2,1]):
    ksize, n_out, n_in = kernel_dims
    kernels = tf.Variable(tf.random_normal([ksize, ksize, n_out, n_in], stddev=0.1))
    out_dim = tf.stack(out_dim)
    return tf.nn.conv2d_transpose(tensor, kernels, out_dim, strides=stride, padding='SAME') + bias(n_out)

def dense(tensor, in_size, out_size):
    weights = tf.Variable(xavier_init(in_size, out_size))
    return tf.matmul(tensor, weights) + bias(out_size)
    

def clip(tensor, _max=None, _min=None):
    return tf.clip_by_value(tensor, clip_value_min=_min, clip_value_max=_max)

def clippedAdam(loss, weights):
    reconstr_grads = tf.gradients(loss, w)
    
    # Clip gradient
    def ClipIfNotNone(grad):
        if grad is None:
            return grad
        return tf.clip_by_value(grad, -1, 1)
    
    reconstr_grads, _ = tf.clip_by_global_norm(reconstr_grads, 5.)
    reconstr_grads = [ClipIfNotNone(grad) for grad in reconstr_grads]
    
    optimizer = tf.train.AdamOptimizer(learning_rate=0.001)
    return optimizer.apply_gradients(zip(reconstr_grads, w))
    
    

In [5]:
#model
n_input = 32*32*3
n1 = 32
n2 = 64
n3 = 128
n_z = 100
batch_size = 100
ksize = 5

x = tf.placeholder(tf.float32, [None, n_input],'X')
n_ex = tf.shape(x)[0] #number of examples

#attack layer starts out at zero
attack = tf.Variable(tf.zeros([n_input], dtype=tf.float32),name='Attack')
adv_x = clip(x + attack, _min=0, _max=1)

adv_im = tf.reshape(adv_x, shape=[-1, 32, 32, 3])

#encoder
conv_enc1  = elu(conv(adv_im, [ksize, 3, n1]))
conv_enc2  = elu(conv(conv_enc1, [ksize, n1, n2]))
conv_enc3  = elu(conv(conv_enc2, [ksize, n2, n3]))
flat = tf.reshape(conv_enc3, [-1, 4*4*n3])
dens_enc4 = elu(dense(flat, 4*4*n3, 512))

z_mean = dense(dens_enc4, 512, n_z)
z_log_var = dense(dens_enc4, 512, n_z)


#sample the latent variables
eps = tf.random_normal((batch_size, n_z), 0, 1, dtype=tf.float32)
z = z_mean + eps*tf.sqrt(tf.exp(z_log_var)) #sampling z from normal


#generator
from_z = elu(dense(z, n_z, 512))
im_z = tf.reshape(from_z, [-1,4,4,32],name='im_z')
deconv_dec1 = elu(deconv(im_z, [ksize, n3, 32], out_dim=[n_ex, 8, 8, n3]))
deconv_dec2 = elu(deconv(deconv_dec1, [ksize, n2, n3], out_dim=[n_ex, 16, 16, n2]))
deconv_dec3 = elu(deconv(deconv_dec2, [ksize, n1, n2], out_dim=[n_ex, 32, 32, n1]))

out_mu = deconv(deconv_dec3, [ksize, 3, n1], out_dim=[n_ex, 32, 32, 3], stride=[1,1,1,1])
out_log_var = deconv(deconv_dec3, [ksize, 3, n1], out_dim=[n_ex, 32, 32, 3], stride=[1,1,1,1])

reconstr_mean = sigmoid(tf.reshape(out_mu, [-1,32*32*3]))
reconstr_log_var = tf.reshape(out_log_var, [-1,32*32*3])

In [6]:
sess = tf.InteractiveSession()

In [7]:
#Getting model variables and attack variable
all_vars = sess.graph.get_collection('trainable_variables')
w = [var for var in all_vars if 'Attack' not in var.name]

c = - 0.5 * math.log(2 * math.pi)
def log_normal2(x, mean, log_var, eps=1e-8):
    return c - log_var / 2 - tf.pow(x - mean, 2) / (2 * tf.exp(log_var) + eps)

def kl_divergence(mean1, log_var1, mean2, log_var2, eps=1e-8):
    mean_term = 0.5 * (tf.exp(log_var1) + tf.pow(mean1 - mean2, 2)) \
        / (tf.exp(log_var2) + eps)
    return mean_term + 0.5 * log_var2 - 0.5 * log_var1 - 0.5

#loss functions
log_px_given_z = tf.reduce_sum(log_normal2(x, reconstr_mean, reconstr_log_var),axis=1)
kl_term = tf.reduce_sum(kl_divergence(z_mean, z_log_var, 0., 1.), axis=1)
cost = -tf.reduce_mean(-kl_term + log_px_given_z)

#using Adam optimizer
optimizer = clippedAdam(cost, w)


#Adversarial Cost
Target = tf.placeholder(tf.float32, [None, n_z])
C = tf.placeholder(tf.float32,name='C')
target_cost = tf.reduce_mean(tf.square(Target - z_mean)) + C*tf.nn.l2_loss(attack)

target_opt = tf.train.AdamOptimizer(learning_rate=0.001).minimize(target_cost,var_list=attack)

init_attack = tf.variables_initializer([attack])

init = tf.global_variables_initializer()

In [8]:
def partial_fit(X):
    """Train model based on mini-batch of input data.

    Return cost of mini-batch.
    """
    opt, c = sess.run((optimizer, cost), 
                              feed_dict={x: X})
    return c

def get_cost(X):
    c = sess.run((cost), feed_dict={x: X})
    return c


def reconstruct(X):
    """ Use VAE to reconstruct given data. """
    return sess.run(out_mean, feed_dict={x: X})


In [None]:
training_epochs = 100
display_step = 1

sess.run(init)
# Training cycle
val_losses = []
losses = []
for epoch in range(training_epochs):
    avg_cost = 0.
    avg_val = 0.
    # Loop over all batches
    for batch_xs, ys in datasets.batches(train_x, train_y, batch_size=batch_size):
        if len(batch_xs) != batch_size:
            continue
        # Fit training using batch data
        c = partial_fit(batch_xs)
        # Compute average loss
        avg_cost += c / n_samples * batch_size

    #validation set costs
    for batch_xs, ys in datasets.batches(val_x, val_y, batch_size=batch_size):
        if len(batch_xs) != batch_size:
            continue
        # Fit training using batch data
        c = get_cost(batch_xs)
        # Compute average loss
        avg_val += c / n_val * batch_size
    
    val_losses.append(avg_val)
    losses.append(avg_cost)
    # Display logs per epoch step
    if epoch % display_step == 0:
        print ("EPOCH: %04d, COST: %03.5f, VAL_COST: %03.5f" % (epoch+1, avg_cost, avg_val))
print ("EPOCH: %04d, COST: %03.5f, VAL_COST: %03.5f" % (epoch+1, avg_cost, avg_val))
plt.plot(losses)
_= plt.plot(val_losses)

EPOCH: 0001, COST: 5138379213483.89160, VAL_COST: 19409.82872
EPOCH: 0011, COST: 19206.62940, VAL_COST: 19166.33289
EPOCH: 0021, COST: 19198.95537, VAL_COST: 19144.34112
EPOCH: 0031, COST: 19199.66780, VAL_COST: 19170.33343
EPOCH: 0041, COST: 19198.17456, VAL_COST: 19127.20752


In [None]:
batch = datasets.batches(train_x, train_y, batch_size=batch_size)
x_sample, _ = next(batch)
x_reconstruct = reconstruct(x_sample)
plt.figure(figsize=(8, 12))
for i in range(5):
    plt.subplot(5, 2, 2*i + 1)
    plt.imshow(x_sample[i].reshape(32, 32, 3), vmin=0, vmax=1, cmap="gray")
    plt.title("Test input")
    plt.colorbar()
    plt.subplot(5, 2, 2*i + 2)
    plt.imshow(x_reconstruct[i].reshape(32, 32, 3), vmin=0, vmax=1, cmap="gray")
    plt.title("Reconstruction")
    plt.colorbar()
plt.tight_layout()

# Adversarial Attack

In [None]:
def plot_images(images=[],titles=[],suptitle=''):
    num_colums = 3
    num_imgs = len(images)
    num_rows = 2
    
    plt.figure(figsize=(10,5))
    plt.suptitle(suptitle,fontsize=16)
    for i in range(num_imgs):
        plt.subplot(num_rows, num_colums, i+1)
        plt.imshow(images[i].reshape(32, 32, 3),vmin=0,vmax=1,cmap='gray')
        plt.title(titles[i])
        plt.axis('off')
    plt.show()

def choose_original_target(x):
    n = len(x)
    i_original = np.random.choice(n)
    i_target = np.random.choice(n)
    original, target = x[i_original], x[i_target]
    return original, target

def stack(array, stack_size=100):
    return np.array([array for i in range(stack_size)])

def get_z(X):
    z = sess.run((z_mean), feed_dict={x: X})
    return z

def reset_attack():
    z = sess.run(init_attack)

def fit_attack(feed_dict):
    target_loss, _ = sess.run((target_cost, target_opt), feed_dict=feed_dict)
    return target_loss

def get_vars(X):
    return sess.run((out_mean, attack, adv_x), feed_dict={x: X})
    
def distance(a, b):
    return np.mean(np.linalg.norm(a-b, axis=1))

def norm(a):
    return np.linalg.norm(a)

In [None]:
orig, targ = choose_original_target(train_x)
orig = stack(orig, stack_size=100)
targ = stack(targ, stack_size=100)

reset_attack()
target_z = np.array([get_z(targ)[0]])
orig_recon = reconstruct(orig)
targ_recon = reconstruct(targ)

targ_2_orig = distance(orig_recon, targ_recon)
targ_2_recon = distance(targ, targ_recon)
origrec_2_target = distance(orig_recon, targ)


plot_images(images=[orig[0],targ[0]],titles=['original','target'])

In [None]:
#Vaores de C para explorar
Cs = np.logspace(10, -20, 100, base = 2, dtype = np.float32)

points = []
titles = ['Original','Target','Original Reconstruction',
              'Adversarial image','Attack','Attacked Reconstruction']
for i, c in enumerate(Cs):
    reset_attack()
    feed_dict = {x:orig, Target:target_z, C:c}
    print('%d C: %f'%(i+1, c))
    for i in range(2000):
        loss = fit_attack(feed_dict)
        
        if i % 1000 == 0:
            print(np.mean(loss))
    
    recon, att, ad_img = get_vars(orig)
    
    size_attack = norm(att)
    dist_recon = distance(recon,targ)
    
    point = (size_attack,dist_recon)
    points.append(point)
    images = [orig[0], targ[0], orig_recon[0], ad_img[0], att, recon[0]]
    suptitle='C: %f ' % c
    plot_images(images=images,titles=titles,suptitle=suptitle)


In [None]:
def plot_results(points):
    plt.figure()
    plt.axvline(targ_2_orig, color='cyan', linewidth=2, label="Original - Target")
    plt.axhline(targ_2_recon, color='red', linewidth=2, label="Target rec. - Target")
    plt.axhline(origrec_2_target, color='DarkOrange', linewidth=2, label="Original rec. - Target")
    x,y=list(zip(*points))
    plt.scatter(x,y)
    plt.ylabel("Adversarial rec. - Target")
    plt.xlabel("Distortion")
    plt.legend()
    
plot_results(points)