## Finding realistic adversarial examples: white-box approach

This script takes as input a classifier and a generative model and looks for 4 realistic adversarial examples.
It is a white-box approach: the inner structure of the networks (weights in particular) is used to compute a gradient.

In [5]:
import time

import matplotlib.pyplot as plt
import numpy as np
import tensorflow as tf
from tensorflow.keras import backend as K
%matplotlib inline

# Source and target digits for the experiment.
# NOTE(review): presumably digit_origin is the digit the generator draws and
# digit_target the class the classifier should be fooled into -- confirm.
digit_origin, digit_target = 8, 3

# Pre-trained generative model, loaded with its saved training configuration.
gan = tf.keras.models.load_model('/Models/gan_digit8_rich.h5')

# Pre-trained classifier; compile=False skips restoring its optimizer/loss.
classifier = tf.keras.models.load_model(
    '/Models/classifier_capacity1_simple.model',
    compile=False,
)

In [6]:
# Chain the generator and the classifier into a single differentiable model:
# latent vector (10 values) -> generated 28x28x1 image -> class scores.
# This is what lets a gradient of the classifier output be taken with respect
# to the generator's latent input.

# The classifier is only queried, never trained, so its weights are frozen.
classifier.trainable = False

combined_networkInput = tf.keras.layers.Input(shape=(10,))
x = gan(combined_networkInput)

# Per-sample image shape expected by the classifier. A Reshape layer leaves
# the batch axis untouched, so the combined model accepts any batch size
# rather than being pinned to exactly one sample.
new_shape = (28, 28, 1)
x = tf.keras.layers.Reshape(new_shape)(x)

combined_networkOutput = classifier(x)
combined_network = tf.keras.models.Model(
    inputs=combined_networkInput,
    outputs=combined_networkOutput,
)
# The combined model is never fitted in this notebook; it is compiled only so
# that it is a fully configured Keras model.
combined_network.compile(loss='binary_crossentropy', optimizer = 'adam')

In [None]:
# Categorical cross-entropy between a label vector and the network's
# prediction; create_adv_pattern differentiates this loss.
loss_object = tf.keras.losses.CategoricalCrossentropy()

# Step size applied to each signed-gradient update of the latent vector.
a = 0.01

# Canvas on which the generated adversarial images are drawn as subplots.
fig = plt.figure(figsize=(2, 2))

def create_adv_pattern(noise,input_label):
    """Return the sign of the loss gradient with respect to ``noise``.

    The latent vector ``noise`` is pushed through ``combined_network``
    (generator followed by classifier), the prediction is compared with
    ``input_label`` using ``loss_object``, and the element-wise sign of
    d(loss)/d(noise) is returned.

    Args:
        noise: Latent input for ``combined_network``.
        input_label: Label passed to ``loss_object`` as the ground truth.

    Returns:
        A tensor shaped like ``noise`` holding -1, 0 or +1 per element.
    """
    noise = tf.convert_to_tensor(noise)
    # tf.gradients only works while building a graph and raises under eager
    # execution; GradientTape works in both modes.
    with tf.GradientTape() as tape:
        # noise is a plain tensor, not a tf.Variable, so the tape has to be
        # told explicitly to track it.
        tape.watch(noise)
        prediction = combined_network(noise)
        loss = loss_object(input_label, prediction)
    grad = tape.gradient(loss, noise)
    return tf.sign(grad)

# Search for adversarial latent vectors by random restarts: draw a latent
# vector, nudge it with signed-gradient steps, and keep it when the classifier
# assigns the target digit a confidence above 0.9.

# Number of examples to collect; the figure uses a 2x2 grid, so 4 fills it.
n_examples = 4
# Number of signed-gradient steps applied to each random starting point.
n_steps = 40

# Width of the classifier output, i.e. the number of classes it scores.
num_classes = int(combined_network.output_shape[-1])
if not (0 <= digit_origin < num_classes and 0 <= digit_target < num_classes):
    raise ValueError(
        "digit_origin=%d / digit_target=%d do not fit a classifier with %d outputs"
        % (digit_origin, digit_target, num_classes)
    )

# One-hot label handed to the loss, shape (1, num_classes).
# NOTE(review): the update below *adds* the signed gradient, i.e. it ascends
# the loss, so labelling with the origin digit drives the prediction away from
# it. Confirm that this untargeted formulation is the intended one.
input_label = tf.one_hot([digit_origin], depth=num_classes)

count = 0

start = time.time()

# NOTE: there is no cap on the number of restarts; the loop runs until
# n_examples latent vectors pass the confidence threshold.
while count < n_examples:
    # Fresh random latent vector for this attempt.
    noise = np.random.normal(0, 1, size=[1, 10])
    noise = tf.cast(noise, tf.float32)

    for _ in range(n_steps):
        pertubations = create_adv_pattern(noise, input_label)
        noise += a * pertubations

    # Classifier confidence for the target digit on the generated image.
    # NOTE(review): the original read index 1 while reporting the confidence
    # "in 3"; indexing by digit_target assumes output i scores digit i.
    confidence = K.eval(combined_network(noise))[0][digit_target]
    if confidence > 0.9:
        generated_image = gan.predict(K.eval(noise))
        print("Example %d: confidence in %d is %f" % (count + 1, digit_target, confidence))
        fig.add_subplot(2, 2, count + 1)
        plt.imshow(generated_image.reshape(1, 28, 28)[0], cmap='gray')
        plt.axis("off")
        count += 1

end = time.time()
print("time: %f" % (end - start))

plt.savefig('adversarial_examples_white_box.png')
plt.show()