In [40]:
import tensorflow as tf
from os import walk
import matplotlib.pyplot as plt
import numpy as np

In [72]:
model_dir = '../animal_classifier/models/augmented/augmented_64'
validate_dir = '../animal_classifier/dataset/validate'
eps = .1

image_save_dir = './data_output/images'

animals = ['butterfly', 'cat', 'chicken', 'cow', 'dog', 'elephant', 'horse', 'sheep', 'spider', 'squirrel']

In [16]:
def get_image(animal_index, photo_index):
    directory = validate_dir + '/' + animals[animal_index]
    filenames = next(walk(directory), (None, None, []))[2] #Black magic that returns list of filenames in directory
    return tf.io.read_file(directory + '/' + filenames[photo_index])

In [17]:
def get_perterbations(model, input_image, class_index):
  with tf.GradientTape() as tape:
    tape.watch(input_image)
    prediction = model(input_image)
    loss = tf.keras.losses.MSE(class_index, prediction)

  # Get the gradients of the loss w.r.t to the input image.
  gradient = tape.gradient(loss, input_image)
  # Get the sign of the gradients to create the perturbation
  signed_grad = tf.sign(gradient)
  return signed_grad

In [99]:
def do_the_thing(class_index, photo_index, show=False, save=False):
    image_raw = get_image(class_index, photo_index)
    image = tf.image.decode_image(image_raw, channels=3)
    image = tf.cast(image, tf.float32)
    image = tf.image.resize(image, (256,256))
    image = image[None, ...] #adds that goofy empty dimension
    
    model = tf.keras.models.load_model(model_dir)

    perterbations = get_perterbations(model, image, class_index)

    if show:
        plt.figure()
        plt.imshow(image[0] / 255)
        if save:
            plt.savefig(image_save_dir + '/' + animals[class_index] + '-' + str(photo_index), frameonbool = False)
        # plt.figure()
        plt.imshow(perterbations[0] * 0.5 + 0.5)
        if save:
            plt.savefig(image_save_dir + '/' + animals[class_index] + '-' + str(photo_index) + '-pert')
        # plt.figure()
        plt.imshow(image[0] / 255 - (perterbations[0] * 0.5 + 0.5) * eps)
        if save:
            plt.savefig(image_save_dir + '/' + animals[class_index] + '-' + str(photo_index) + '-adv')

    

    reg_prediction = np.array(model(image)[0])
    reg_predicted_class_index = np.argmax(reg_prediction)

    adv_prediction = np.array(model(image - perterbations * eps)[0])
    adv_predicted_class_index = np.argmax(adv_prediction)


    if adv_prediction[adv_predicted_class_index] > .5 and reg_prediction[reg_predicted_class_index] > .5 and reg_predicted_class_index != adv_predicted_class_index:
        print('-------  {}  -------'.format(str(photo_index)))
        print('Initial Prediction:', animals[reg_predicted_class_index], reg_prediction[reg_predicted_class_index])
        print('Adversarial Prediction:', animals[adv_predicted_class_index], adv_prediction[adv_predicted_class_index])
        print(animals[reg_predicted_class_index], adv_prediction[reg_predicted_class_index])

In [100]:
animal_index = 5
photo_index = 127 #Dog: 154 191 Elephant: 1 28 33 66 127 153

# for i in range(0,200):
#     do_the_thing(animal_index, i)
do_the_thing(animal_index, photo_index, show=False, save=False)

-------  127  -------
Initial Prediction: elephant 0.6367365
Adversarial Prediction: horse 0.54679334
elephant 0.40088448
