In [1]:
# Modules and data imports

from tensorflow.examples.tutorials.mnist import input_data
import tensorflow as tf
import os
import copy
import math
import random
import numpy as np

#########
# Load MNIST from the MNIST_data/ directory, with labels one-hot encoded.
mnist = input_data.read_data_sets(train_dir='MNIST_data', one_hot=True)
# An InteractiveSession registers itself as the default session, which is what
# lets later cells call tensor.eval(...) without passing a session explicitly.
sess = tf.InteractiveSession()

Extracting MNIST_data/train-images-idx3-ubyte.gz
Extracting MNIST_data/train-labels-idx1-ubyte.gz
Extracting MNIST_data/t10k-images-idx3-ubyte.gz
Extracting MNIST_data/t10k-labels-idx1-ubyte.gz


In [2]:
# Presumably the dmg_size values to sweep with damage_network: 0.0, 0.1, ..., 1.0.
# (i / 10.0 is correctly rounded, so each entry equals the literal 0.1, 0.2, ...)
damage_sizes = [step / 10.0 for step in range(11)]
# Presumably the number of random trials per damage size (unused in the cells shown).
number_of_trials = 10

In [3]:
# Helper functions taken from Google's TensorFlow MNIST tutorial.

def weight_variable(shape):
    """Create a trainable weight Variable of `shape`, drawn from a truncated normal (stddev 0.1)."""
    return tf.Variable(tf.truncated_normal(shape, stddev=0.1))

def bias_variable(shape):
    """Create a trainable bias Variable of `shape` with every entry set to 0.1."""
    return tf.Variable(tf.constant(0.1, shape=shape))

def conv2d(x, W):
    """Convolve `x` with filter `W` using a stride of 1 in every dimension and 'SAME' padding."""
    unit_stride = [1, 1, 1, 1]
    return tf.nn.conv2d(x, W, strides=unit_stride, padding='SAME')

def max_pool_2x2(x):
    """Max-pool `x` with a 2x2 window and stride 2 ('SAME' padding).

    The window is 1 on the first and last axes, so pooling happens over the
    two middle (spatial) axes, assuming TensorFlow's default NHWC layout.
    """
    window = [1, 2, 2, 1]
    return tf.nn.max_pool(x, ksize=window, strides=window, padding='SAME')

In [61]:
# Helper functions for damaging and percentile-filtering the network's weight matrices.

def damage_network(network_matrices, dmg_size, alpha):
    """Return a damaged copy of `network_matrices`.

    floor(dmg_size * total_weight_count) weights, picked at random among the
    currently non-zero ones, are overwritten with `alpha`. The matrices are
    flattened into a fresh vector first, so the inputs are left untouched.
    Returns a list of arrays with the same shapes as the inputs.
    """
    shapes = get_matrix_shapes(network_matrices)
    flat = vectorize_network(network_matrices)
    flat[get_damage_indices(flat, dmg_size)] = alpha
    return reshape_matrices(flat, shapes)

def filter_network(network_matrices, percentile_window, alpha, filter_type, verbose=True):
    """Return a copy of the network with percentile-selected weights set to `alpha`.

    Args:
        network_matrices: list of weight arrays; they are not modified.
        percentile_window: width of the window, in percentiles.
        alpha: replacement value for the selected weights.
        filter_type: "inside" replaces weights between the (50 - w)th and
            (50 + w)th percentiles; "outside" replaces weights below the w-th
            or above the (100 - w)th percentile.
        verbose: when True (the default, matching the original behaviour),
            print the total weight count and then the non-zero count after
            filtering.

    Returns:
        A list of arrays with the same shapes as `network_matrices`.

    Raises:
        ValueError: if `filter_type` is not "inside" or "outside". Previously
            this case crashed with UnboundLocalError.
    """
    if filter_type == "inside":
        filter_fn = filter_vector_in
    elif filter_type == "outside":
        filter_fn = filter_vector_out
    else:
        raise ValueError(
            "filter_type must be 'inside' or 'outside', got %r" % (filter_type,))

    matrix_shapes = get_matrix_shapes(network_matrices)
    matrices_as_vector = vectorize_network(network_matrices)
    return_vector = filter_fn(matrices_as_vector, percentile_window, alpha)
    if verbose:
        print(len(matrices_as_vector))
        print(np.count_nonzero(return_vector))
    return reshape_matrices(return_vector, matrix_shapes)
    
def get_matrix_shapes(network_matrices):
    """Return each matrix's shape as a list of ints, in input order."""
    return [list(m.shape) for m in network_matrices]

def vectorize_network(network_matrices):
    """Flatten every matrix (row-major) and concatenate them into one 1-D vector.

    The result is a newly allocated array (np.concatenate always copies), so
    callers may modify it in place without touching `network_matrices`.
    Like the original np.append-based version, the result is float64 and an
    empty input yields an empty vector.
    """
    # Seeding with a float64 empty array preserves the original float64 result
    # dtype and stops np.concatenate from raising on an empty input list.
    pieces = [np.empty(0)]
    pieces.extend(np.ravel(matrix) for matrix in network_matrices)
    # One concatenate instead of np.append in a loop: the original re-copied
    # the whole accumulated vector for every matrix.
    return np.concatenate(pieces)

def get_damage_indices(matrices_as_vector, dmg_size, rng=None):
    """Pick distinct positions of non-zero elements to damage.

    Args:
        matrices_as_vector: 1-D array of all network weights.
        dmg_size: fraction of ALL elements to damage; the requested count is
            floor(dmg_size * len(matrices_as_vector)).
        rng: optional object with a `random.sample`-compatible `sample` method,
            e.g. random.Random(seed) for reproducible runs. Defaults to the
            global `random` module.

    Returns:
        An integer array of indices into `matrices_as_vector`, sampled without
        replacement from the non-zero positions. If fewer non-zero elements
        exist than requested, all of them are returned.
    """
    sampler = random if rng is None else rng
    nonzero_positions = np.flatnonzero(matrices_as_vector)
    requested = int(math.floor(dmg_size * len(matrices_as_vector)))
    # The count is a fraction of all elements, but only non-zero elements are
    # candidates. Cap it so random.sample cannot ask for more items than exist;
    # previously that raised ValueError once some weights were already zero.
    num_elements_to_damage = min(requested, len(nonzero_positions))
    chosen = sampler.sample(range(len(nonzero_positions)), num_elements_to_damage)
    return nonzero_positions[chosen]

def alpha_zero():
    """Return the replacement value for damaged weights: always the integer 0."""
    alpha = 0
    return alpha

def reshape_matrices(matrix_as_vector, matrix_shapes):
    """Split a flat vector back into matrices of the given shapes.

    The vector is consumed front to back, one shape at a time; any elements
    beyond the total size of `matrix_shapes` are ignored. The results come
    from slicing plus np.reshape, so they may share memory with
    `matrix_as_vector`.
    """
    # sizes[0] is a 0 sentinel; sizes[k + 1] is the element count of shape k.
    sizes = get_vector_lengths(matrix_shapes)
    result = []
    start = 0
    for idx, target_shape in enumerate(matrix_shapes):
        stop = start + sizes[idx + 1]
        result.append(np.reshape(matrix_as_vector[start:stop], target_shape))
        start = stop
    return result

def get_vector_lengths(matrix_shapes):
    """Return [0, n_1, ..., n_k], where n_i is the element count of shape i.

    The leading 0 is a sentinel: reshape_matrices sums prefixes of this list
    to find each matrix's start and end offsets in the flat vector.
    """
    # int() makes an empty shape [] (a scalar) give 1 instead of numpy's
    # float 1.0, which could not be used as a slice bound.
    return [0] + [int(np.prod(shape)) for shape in matrix_shapes]

def filter_vector_in(matrices_as_vector, percentile_window, alpha):
    """Replace values near the median with `alpha`.

    A value v is replaced when P(50 - w) <= v <= P(50 + w) (inclusive), where
    P is np.percentile over the whole vector and w = percentile_window.
    np.percentile raises ValueError if 50 +/- w falls outside [0, 100].

    Expects a 1-D NumPy array (as produced by vectorize_network). The array
    is modified in place and also returned.
    """
    upper_perc = np.percentile(matrices_as_vector, 50 + percentile_window)
    lower_perc = np.percentile(matrices_as_vector, 50 - percentile_window)
    # A single vectorized mask replaces the per-element Python loop, with the
    # same inclusive bounds; the thresholds are computed before any mutation.
    in_window = (matrices_as_vector >= lower_perc) & (matrices_as_vector <= upper_perc)
    matrices_as_vector[in_window] = alpha
    return matrices_as_vector

def filter_vector_out(matrices_as_vector, percentile_window, alpha):
    """Replace values in both tails of the distribution with `alpha`.

    A value v is replaced when v > P(100 - w) or v < P(w) (strict), where P
    is np.percentile over the whole vector and w = percentile_window. Values
    exactly at those percentiles are kept. np.percentile raises ValueError if
    w falls outside [0, 100].

    Expects a 1-D NumPy array (as produced by vectorize_network). The array
    is modified in place and also returned.
    """
    upper_perc = np.percentile(matrices_as_vector, 100 - percentile_window)
    lower_perc = np.percentile(matrices_as_vector, percentile_window)
    # A single vectorized mask replaces the per-element Python loop, with the
    # same strict bounds; the thresholds are computed before any mutation.
    in_tails = (matrices_as_vector > upper_perc) | (matrices_as_vector < lower_perc)
    matrices_as_vector[in_tails] = alpha
    return matrices_as_vector

def test_network(network_matrices):
    """Print the MNIST test-set accuracy obtained with the given weights fed in.

    network_matrices[0], [1] and [2] are fed in place of W_conv1, W_conv2 and
    W_fc1. The biases are not fed, so they keep their current session values.
    keep_prob is fed as 1.0. Prints the accuracy formatted with "%g" and
    returns None.
    """
    conv1_w, conv2_w, fc1_w = network_matrices[0], network_matrices[1], network_matrices[2]
    feed = {
        x: mnist.test.images,
        y_: mnist.test.labels,
        W_conv1: conv1_w,
        W_conv2: conv2_w,
        W_fc1: fc1_w,
        keep_prob: 1.0,
    }
    score = accuracy.eval(feed_dict=feed)
    print("%g" % score)

In [62]:
# Build the classifier graph: two conv+ReLU+2x2-max-pool stages, then a single
# fully connected layer mapping the flattened features to 10 softmax outputs.
# NOTE(review): the variables below are unnamed, so TensorFlow auto-names them
# in creation order, and tf.train.Saver matches checkpoint entries by name.
# Do not reorder the weight_variable/bias_variable calls (TODO confirm against
# the checkpoint).
# NOTE(review): the In[] counts show this cell (In [62]) ran after the restore
# (In [6]) and extraction (In [7]) cells. The recorded outputs of the last cell
# were therefore presumably produced with freshly initialized (not restored)
# biases. Use Restart & Run All to reproduce.

# Flattened input images (784 = 28*28 values) and 10-class one-hot labels.
x = tf.placeholder("float", shape=[None, 784])
y_ = tf.placeholder("float", shape=[None, 10])

# First conv layer: 5x5 filters, 1 input channel -> 32 feature maps.
W_conv1 = weight_variable([5, 5, 1, 32])
b_conv1 = bias_variable([32])

# Restore the image layout: (batch, 28, 28, 1 channel).
x_image = tf.reshape(x, [-1, 28, 28, 1])

h_conv1 = tf.nn.relu(conv2d(x_image, W_conv1) + b_conv1)
h_pool1 = max_pool_2x2(h_conv1)  # 28x28 -> 14x14

# Second conv layer: 5x5 filters, 32 -> 64 feature maps.
W_conv2 = weight_variable([5, 5, 32, 64])
b_conv2 = bias_variable([64])

h_conv2 = tf.nn.relu(conv2d(h_pool1, W_conv2) + b_conv2)

h_pool2 = max_pool_2x2(h_conv2)  # 14x14 -> 7x7

# Output layer: flattened 7*7*64 features straight to 10 classes (no hidden
# fully connected layer in this variant).
W_fc1 = weight_variable([7 * 7 * 64, 10])
b_fc1 = bias_variable([10])

h_pool2_flat = tf.reshape(h_pool2, [-1, 7*7*64])
# NOTE(review): keep_prob is fed by test_network but no op in this graph uses
# it (there is no dropout layer here).
keep_prob = tf.placeholder("float")
y_conv = tf.nn.softmax(tf.matmul(h_pool2_flat, W_fc1) + b_fc1)

# Accuracy = fraction of rows whose argmax prediction matches the one-hot label.
predicted = tf.argmax(y_conv, 1)
actual = tf.argmax(y_, 1)
correct_prediction = tf.equal(predicted, actual)

accuracy = tf.reduce_mean(tf.cast(correct_prediction, "float"))
# Initializes every variable in the graph. tf.initialize_all_variables is the
# deprecated alias of tf.global_variables_initializer in TF >= 0.12.
sess.run(tf.initialize_all_variables())

In [6]:
# Restore the already-trained network from the latest checkpoint in the current working directory.

saver = tf.train.Saver()
ckpt = tf.train.get_checkpoint_state(os.getcwd())
# get_checkpoint_state returns None when the directory has no "checkpoint"
# file. Fail with a clear message instead of an AttributeError on None.
if ckpt is None or not ckpt.model_checkpoint_path:
    raise IOError("No TensorFlow checkpoint found in %s" % os.getcwd())
saver.restore(sess, ckpt.model_checkpoint_path)

In [7]:
# Evaluate the weight tensors with sess.run to get their values as NumPy arrays,
# and collect all of the weight matrices in a single list.

matrices_to_damage = [np.asarray(sess.run(weights))
                      for weights in (W_conv1, W_conv2, W_fc1)]
    

In [66]:
# Baseline call: dmg_size=0.0 selects no weights, so this yields an undamaged copy.
damage_amount = alpha_zero()
damaged_network = damage_network(matrices_to_damage, dmg_size=0.0, alpha=damage_amount)

# in_: weights between the 49th and 51st percentiles set to 0.
# out: weights below the 1st or above the 99th percentile set to 0.
in_ = filter_network(matrices_to_damage, percentile_window=1, alpha=0, filter_type="inside")
out = filter_network(matrices_to_damage, percentile_window=1, alpha=0, filter_type="outside")
test_network(network_matrices=in_)
test_network(network_matrices=out)

83360
81692
83360
81692
0.9871
0.7797
