In [1]:
import tensorflow as tf
from tensorflow import keras
import numpy as np

In [2]:
# NOTE: despite the variable name, this is keras.datasets.mnist,
# not keras.datasets.fashion_mnist.
fashion_mnist = keras.datasets.mnist

train_split, test_split = fashion_mnist.load_data()
train_images, train_labels = train_split
test_images, test_labels = test_split

train_shape = train_images.shape
test_shape = test_images.shape

# Number of distinct label values present in the training labels.
y_shape = len(set(train_labels))

# Flatten every image into one row vector and divide by 255.0.
# The test images are reshaped to the training feature width, so a size
# mismatch between the two splits raises here.
num_pixels = train_shape[1] * train_shape[2]
train_images_reshaped = train_images.reshape(train_shape[0], num_pixels) / 255.0
test_images_reshaped = test_images.reshape(test_shape[0], num_pixels) / 255.0

In [3]:
# Width of each hidden layer, listed from the input side to the output side.
hidden_sizes = [
    1000,
    1000,
    500,
    200,
]

In [4]:
# Build a fresh TF1 graph: a bias-free ReLU MLP over the flattened images,
# a softmax cross-entropy loss trained with Adam, an accuracy op, and
# per-layer ops that locate each hidden layer's smallest-magnitude weights.
tf.reset_default_graph()

with tf.device("/GPU:0"):
    input_units = train_shape[1] * train_shape[2]
    # Width feeding the output layer; falls back to the input width when
    # hidden_sizes is empty so the graph can still be built.
    last_hidden_units = hidden_sizes[-1] if hidden_sizes else input_units

    x_placeholder = tf.placeholder(tf.float32, shape=(None, input_units))
    # `(None)` is just `None` (parentheses alone do not make a tuple), which
    # left the label shape completely unspecified. `(None,)` declares a 1-D
    # batch of class ids.
    y_placeholder = tf.placeholder(tf.int64, shape=(None,))
    # NOTE(review): tf.truncated_normal is called without `stddev`, so every
    # weight matrix uses the library default. This is presumably why the
    # recorded early-epoch loss is so large; consider a fan-in-scaled
    # initialisation (and retuning the learning rate) - TODO confirm.
    output_layer_weights = tf.Variable(tf.truncated_normal([last_hidden_units, y_shape]))

    weights_matrices = []
    # layer_outputs[i] is the activation feeding hidden layer i; index 0 is the input.
    layer_outputs = [x_placeholder]
    weights_magnitudes_matrices = []
    top_indices_list = []

    for layer, layer_units in enumerate(hidden_sizes):
        # Fan-in is the width of the previous activation (the input for layer 0).
        prev_units = int(layer_outputs[-1].shape[1])

        print(layer+1, (prev_units,layer_units))
        weights_matrix_l = tf.Variable(tf.truncated_normal([prev_units, layer_units]))
        layer_output = tf.nn.relu(tf.matmul(layer_outputs[-1], weights_matrix_l))

        weights_magnitudes = tf.abs(weights_matrix_l)
        # top_k returns the largest entries; negating the magnitudes therefore
        # selects the 5 weights with the SMALLEST absolute value.
        top_values, top_indices = tf.nn.top_k(tf.reshape(-weights_magnitudes, (-1,)), 5)
        # Convert flat (row-major) positions back into (row, column) pairs.
        top_indices = tf.stack((top_indices // layer_units, top_indices % layer_units), -1)

        weights_matrices.append(weights_matrix_l)
        layer_outputs.append(layer_output)
        weights_magnitudes_matrices.append(weights_magnitudes)
        top_indices_list.append(top_indices)

    # Use the recorded activations instead of the loop variable that leaks
    # out of the `for` (it is undefined when there are no hidden layers).
    last_layer_output = tf.matmul(layer_outputs[-1], output_layer_weights)

    print(len(hidden_sizes) + 1, (last_hidden_units, y_shape))

    loss = tf.reduce_mean(
        tf.nn.sparse_softmax_cross_entropy_with_logits(labels=y_placeholder, logits=last_layer_output))
    optimizer_step = tf.train.AdamOptimizer(0.01).minimize(loss)

    # Predicted class per example and the fraction that match the labels.
    last_layer_hard_output = tf.argmax(last_layer_output, axis=1)
    accuracy = tf.reduce_mean(tf.cast(tf.equal(last_layer_hard_output, y_placeholder), tf.float32))

    init = tf.global_variables_initializer()

1 (784, 1000)
2 (1000, 1000)
3 (1000, 500)
4 (500, 200)
5 (200, 10)


In [5]:
# Training-loop settings. The training cell feeds the whole training set on
# every iteration, so one "epoch" here is a single full-batch optimizer step.
EPOCHS = 100       # number of training iterations
PRINT_EVERY = 10   # report loss/accuracy every this many iterations

In [6]:
# Soft placement lets ops pinned to "/GPU:0" fall back to another device
# when that device is unavailable.
config = tf.ConfigProto(allow_soft_placement=True)
sess = tf.InteractiveSession(config=config)

sess.run(init)
# The entire training set is fed at once: every loop iteration below is one
# full-batch gradient step.
feed_dict = {x_placeholder: train_images_reshaped, y_placeholder: train_labels}

train_fetches = [optimizer_step, loss, last_layer_output, last_layer_hard_output, accuracy]
for epoch in range(1, EPOCHS + 1):
    _, epoch_loss, epoch_output, epoch_hard_output, epoch_acc = sess.run(train_fetches, feed_dict)

    if epoch % PRINT_EVERY == 0:
        print("Epoch :{} loss:{}  epoch acc:{}".format(epoch, epoch_loss, epoch_acc))

# Snapshot the trained hidden-layer weight matrices as NumPy arrays.
# Variables do not depend on the placeholders, so no feed_dict is needed
# (feeding the full training set here was wasted work), and the list of
# variables can be fetched directly.
trained_weights = sess.run(weights_matrices)


Epoch :10 loss:100499.8515625  epoch acc:0.6105999946594238
Epoch :20 loss:18595.79296875  epoch acc:0.8187333345413208
Epoch :30 loss:10010.990234375  epoch acc:0.8848999738693237
Epoch :40 loss:7122.7490234375  epoch acc:0.901283323764801
Epoch :50 loss:4923.2392578125  epoch acc:0.9178500175476074
Epoch :60 loss:3573.817626953125  epoch acc:0.927566647529602
Epoch :70 loss:2728.211181640625  epoch acc:0.9352333545684814
Epoch :80 loss:2100.83740234375  epoch acc:0.9449666738510132
Epoch :90 loss:1627.613037109375  epoch acc:0.9542666673660278
Epoch :100 loss:1261.8719482421875  epoch acc:0.9607833623886108


In [7]:
# Evaluate the accuracy op on the test split using the existing session.
test_feed_dict = {
    x_placeholder: test_images_reshaped,
    y_placeholder: test_labels,
}
test_acc = sess.run(accuracy, feed_dict=test_feed_dict)

print("Test acc", test_acc)


Test acc 0.9327


In [9]:
# (row, column) positions of the 5 smallest-magnitude weights in the first
# hidden layer. The op depends only on the weight variable, so no feed_dict
# is required.
sess.run(top_indices_list[0])


array([[247, 901],
       [736,  28],
       [363, 637],
       [ 52, 256],
       [623, 556]], dtype=int32)