In [None]:
import tensorflow as tf
import numpy as np
import matplotlib.pyplot as plt
import os

In [None]:
# Location of the saved classifier checkpoint (HDF5).
_discriminator_path = 'models/discriminator.h5'
# True when a checkpoint from an earlier run is present on disk.
MODEL_EXISTS = os.path.isfile(_discriminator_path)
# Set to True to write the model back to disk after the training cell runs.
MODEL_UPDATE = False

In [None]:
def preprocess(image):
	"""Cast ``image`` to float32 and divide by 255.

	For 8-bit pixel data (values 0..255) this rescales the input to [0, 1].
	"""
	as_float = tf.cast(image, tf.float32)
	return as_float / 255.0

In [None]:
# Load the MNIST train/test split (images plus integer class labels).
(X_train, y_train), (X_test, y_test) = tf.keras.datasets.mnist.load_data()

if MODEL_EXISTS:
	# Reuse the classifier saved by a previous run.
	model = tf.keras.models.load_model('models/discriminator.h5')
else:
	# Small dense classifier; the last layer outputs raw logits, which is why
	# the loss below is configured with from_logits=True.
	model = tf.keras.models.Sequential([
		tf.keras.layers.Flatten(input_shape=(28, 28)),
		tf.keras.layers.Dense(128, activation='relu'),
		tf.keras.layers.Dropout(0.2),
		tf.keras.layers.Dense(10),
	])
	model.compile(
		optimizer=tf.keras.optimizers.Adam(0.001),
		loss=tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True),
		metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
	)

# Train when the model is new, or when an update of the saved model was
# requested. A loaded checkpoint that is not being updated is used as-is, so a
# plain "Run All" does not spend five epochs on a result that is thrown away.
if not MODEL_EXISTS or MODEL_UPDATE:
	model.fit(preprocess(X_train), y_train, epochs=5, shuffle=True)
model.evaluate(preprocess(X_test), y_test, verbose=2)

if MODEL_UPDATE:
	# Saving fails if the target directory is missing, so create it first.
	os.makedirs('models', exist_ok=True)
	model.save('models/discriminator.h5')
	MODEL_EXISTS = True

In [None]:
def show_prediction(image, prediction=np.array([]), title=''):
	"""Show an image next to a bar chart of its per-class scores.

	Args:
		image: array-like accepted by ``imshow``.
		prediction: 1-D array with one score per class; the bar chart's
			y axis is fixed to [0, 1].
		title: text placed above both panels.
	"""
	# one bar per entry of ``prediction`` (MNIST classes are 0..9)
	class_ids = np.arange(prediction.shape[0])

	# left panel: the image, without tick marks
	image_ax = plt.subplot(1, 2, 1)
	image_ax.imshow(image)
	image_ax.set_xticks([])
	image_ax.set_yticks([])

	# right panel: scores as a bar graph, one tick per class
	bars_ax = plt.subplot(1, 2, 2)
	bars_ax.bar(class_ids, prediction)
	bars_ax.set_xticks(class_ids)
	bars_ax.set_ylim(0.0, 1.0)

	plt.suptitle(title)
	plt.show()

In [None]:
def step_fgsm(x, epsilon, loss_fn, classifier=None):
	"""Compute one Fast Gradient Sign Method (FGSM) perturbation for ``x``.

	The classifier's own top prediction for ``x`` is used as the label, so the
	returned step points in the direction that increases the loss of the
	currently predicted class.

	Args:
		x: batched input tensor (or array-like convertible to one); the
			gradient is taken with respect to it.
		epsilon: step size that scales the sign of the gradient.
		loss_fn: callable ``loss_fn(y_true, y_pred)`` applied to the one-hot
			label and the classifier output.
		classifier: optional callable mapping ``x`` to class probabilities.
			Defaults to the notebook-level ``softmax_model``.

	Returns:
		``epsilon * sign(dLoss/dx)`` for the first example of the batch, with
		gradients stopped.
	"""
	if classifier is None:
		# Looked up at call time, so this cell may run before the model cell.
		classifier = softmax_model
	# tape.watch only accepts tensors; this is a no-op for tensor input.
	x = tf.convert_to_tensor(x)

	# calculate gradient
	with tf.GradientTape() as tape:
		tape.watch(x)
		# Feed the watched tensor itself. Using any other tensor here would
		# disconnect the loss from ``x`` and make the gradient None.
		prediction = classifier(x)
		# 10 = number of MNIST classes
		one_hot_label = tf.one_hot(tf.argmax(prediction, 1), 10)
		loss = loss_fn(one_hot_label, prediction)

	gradient = tape.gradient(loss, x)
	# Only the first batch element is returned, matching how callers use it.
	signed_step = epsilon * tf.sign(gradient[0])

	return tf.stop_gradient(signed_step)

In [None]:
# FGSM step size and the loss whose gradient drives the attack.
epsilon = 0.01
loss_fn = tf.keras.losses.CategoricalCrossentropy()

# Take one test digit, normalise it and add a leading batch axis.
image = preprocess(X_test[50])[np.newaxis, ...]

# Wrap the classifier with a softmax layer so its outputs can be plotted on a
# [0, 1] scale.
softmax_model = tf.keras.models.Sequential()
softmax_model.add(model)
softmax_model.add(tf.keras.layers.Softmax())

# Shift and scale the signed FGSM step so it is centred on 0.5 for display.
perturbation = 0.5 + 0.5 * step_fgsm(image, epsilon, loss_fn)

In [None]:
# Classify the untouched image and plot it next to its class probabilities.
prediction = np.squeeze(softmax_model(image))
title = 'Prediction of original image'
show_prediction(np.squeeze(image), prediction, title)

# Show the perturbation on its own.
plt.imshow(perturbation)
plt.xticks([])
plt.yticks([])
plt.show()

# Average the image with the display-scaled perturbation.
# NOTE(review): this blend halves the image contrast; a textbook FGSM example
# would add the raw epsilon * sign(gradient) step to the image instead. Confirm
# which is intended.
new_image = (image + perturbation) / 2.0
# Classify the modified image itself. Feeding only the perturbation here would
# produce a bar chart that does not describe the image displayed beside it.
new_prediction = np.squeeze(softmax_model(new_image))
title = 'Prediction of modified image'
show_prediction(np.squeeze(new_image), new_prediction, title)