# Flat Image Net - Basic Graph


In [1]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [2]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import os
import shutil
from utils.data import init_dir

## Load data

In [3]:
data = input_data.read_data_sets('/data/fashion/', one_hot=True)
class_id2class_name_mapping = {
    0: 'T-shirt/top',
    1: 'Trouser',
    2: 'Pullover',
    3: 'Dress',
    4: 'Coat',
    5: 'Sandal',
    6: 'Shirt',
    7: 'Sneaker',
    8: 'Bag',
    9: 'Ankle boot'}

Extracting /data/fashion/train-images-idx3-ubyte.gz
Extracting /data/fashion/train-labels-idx1-ubyte.gz
Extracting /data/fashion/t10k-images-idx3-ubyte.gz
Extracting /data/fashion/t10k-labels-idx1-ubyte.gz


## Basic summary functions

In [4]:
def variable_summaries(name, var):
    with tf.name_scope(name):        
        tf.summary.histogram('histogram', var)
        ###########################################################
        # Complete scalar summaries for min, max, stddev and mean #
        ###########################################################
        
def img_summaries(name, var):
    with tf.name_scope(name):
        tf.summary.image(name, var)

## Custom layer functions

In [5]:
def simple_layer(name, input_data, shape, activation='linear'):
    w_name = 'w_' + name
    b_name = 'b_' + name
    
    w = tf.get_variable('w', initializer=tf.truncated_normal(shape, stddev=0.1))
    bias = tf.get_variable(b_name, initializer=tf.constant_initializer(0.1), shape=shape[1])
    
    #####################################################
    # Complete xavier and variance_scaling initializers #
    #####################################################
    
    output_data = tf.matmul(input_data, w) + bias
    
    #################################################
    # Complete sigmoid, linear and relu activations #
    #################################################
    
    return output_data

## Build Net Graph

In [6]:
graph = tf.Graph()
with graph.as_default():
    with tf.name_scope('flat_image_net_inputs'):
        images = tf.placeholder(tf.float32, shape=[None, 784], name='images')
        labels = tf.placeholder(tf.float32, shape=[None, 10], name='labels')
    
    with tf.variable_scope('simple_layer_1'):
        raw_prediction = simple_layer(name='layer1', input_data=images, shape=[784, 10])
            
    with tf.name_scope('prediction'):
        prediction = raw_prediction
    
    with tf.name_scope('loss'):
        #####################
        # Fix loss function #
        #####################
        loss_vector = raw_prediction - labels
        loss = tf.reduce_mean(loss_vector)
        variable_summaries('loss_summary', loss_vector)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(prediction,1), tf.argmax(labels,1))
        correct_prediction = tf.cast(correct_prediction, tf.float32)
        accuracy = tf.reduce_mean(correct_prediction)
        variable_summaries('accuracy_summary', correct_prediction)       
        
    with tf.name_scope('training'):
        ####################
        # Change optimizer #
        ####################
        train_step = tf.train.GradientDescentOptimizer(0.5).minimize(loss)
            
    initialize_vars = tf.global_variables_initializer()
    merge_summaries = tf.summary.merge_all()

## Init Model Logging

In [7]:
from utils.data import init_model_logging
base_dir = '/tensorboard_summaries/flat_image_net/'

logging_meta = init_model_logging(base_dir, 'experiment1', graph=graph, remove_existing=True)
######################################
# Inspect code of init_model_logging #
######################################

## Run Net

In [8]:
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True
model_path = logging_meta['model_path']

        
with tf.Session(graph=graph, config=config) as session:
    session.run(initialize_vars)
    for iteration in range(10000):
        _images, _labels = data.train.next_batch(100)
        
        _ = session.run([train_step], feed_dict={images: _images, labels: _labels})
        if iteration % 10 == 0:
            _summary, _accuracy, _loss = session.run([merge_summaries, accuracy, loss],
                                                     feed_dict={images: _images, labels: _labels})
            logging_meta['train_writer'].add_summary(_summary, iteration)
            print("Iteration {}: loss {}, accuracy {}".format(iteration, _loss, _accuracy))
        
        #####################################
        # Add validation section with saver #
        #####################################
            
    _prediction, = session.run([prediction], feed_dict={images: data.validation.images})

Iteration 0: loss -4.3117265701293945, accuracy 0.03999999910593033
Iteration 10: loss -54.31332778930664, accuracy 0.05999999865889549
Iteration 20: loss -99.51949310302734, accuracy 0.14000000059604645
Iteration 30: loss -139.197265625, accuracy 0.11999999731779099
Iteration 40: loss -200.31491088867188, accuracy 0.05000000074505806
Iteration 50: loss -229.6427764892578, accuracy 0.12999999523162842
Iteration 60: loss -300.5534973144531, accuracy 0.14000000059604645
Iteration 70: loss -341.6572570800781, accuracy 0.10999999940395355
Iteration 80: loss -374.0713806152344, accuracy 0.10999999940395355
Iteration 90: loss -383.85693359375, accuracy 0.10000000149011612
Iteration 100: loss -488.0667419433594, accuracy 0.10000000149011612
Iteration 110: loss -539.5288696289062, accuracy 0.07999999821186066
Iteration 120: loss -575.18408203125, accuracy 0.009999999776482582
Iteration 130: loss -626.6466064453125, accuracy 0.12999999523162842
Iteration 140: loss -662.1748046875, accuracy 0.15

Iteration 1550: loss -7382.7177734375, accuracy 0.12999999523162842
Iteration 1560: loss -7717.5087890625, accuracy 0.11999999731779099
Iteration 1570: loss -7635.95263671875, accuracy 0.05000000074505806
Iteration 1580: loss -7152.02685546875, accuracy 0.03999999910593033
Iteration 1590: loss -7490.79248046875, accuracy 0.10999999940395355
Iteration 1600: loss -7730.37890625, accuracy 0.14000000059604645
Iteration 1610: loss -7850.544921875, accuracy 0.12999999523162842
Iteration 1620: loss -7620.0400390625, accuracy 0.09000000357627869
Iteration 1630: loss -7875.8310546875, accuracy 0.11999999731779099
Iteration 1640: loss -7611.96044921875, accuracy 0.12999999523162842
Iteration 1650: loss -7661.23681640625, accuracy 0.07999999821186066
Iteration 1660: loss -8337.4306640625, accuracy 0.12999999523162842
Iteration 1670: loss -8174.4013671875, accuracy 0.12999999523162842
Iteration 1680: loss -8300.212890625, accuracy 0.14000000059604645
Iteration 1690: loss -8387.02734375, accuracy 0

Iteration 2910: loss -14267.1220703125, accuracy 0.10000000149011612
Iteration 2920: loss -13860.505859375, accuracy 0.07999999821186066
Iteration 2930: loss -13109.4296875, accuracy 0.12999999523162842
Iteration 2940: loss -13716.53125, accuracy 0.05999999865889549
Iteration 2950: loss -13432.6064453125, accuracy 0.15000000596046448
Iteration 2960: loss -14799.5654296875, accuracy 0.05000000074505806
Iteration 2970: loss -13761.197265625, accuracy 0.07999999821186066
Iteration 2980: loss -14863.63671875, accuracy 0.14000000059604645
Iteration 2990: loss -14983.73828125, accuracy 0.05000000074505806
Iteration 3000: loss -13887.353515625, accuracy 0.1599999964237213
Iteration 3010: loss -13979.0439453125, accuracy 0.14000000059604645
Iteration 3020: loss -15206.4736328125, accuracy 0.12999999523162842
Iteration 3030: loss -14173.0830078125, accuracy 0.09000000357627869
Iteration 3040: loss -14546.89453125, accuracy 0.10000000149011612
Iteration 3050: loss -13793.9697265625, accuracy 0.1

Iteration 4370: loss -21156.87890625, accuracy 0.10999999940395355
Iteration 4380: loss -20921.1953125, accuracy 0.10000000149011612
Iteration 4390: loss -20621.802734375, accuracy 0.07999999821186066
Iteration 4400: loss -21739.935546875, accuracy 0.10000000149011612
Iteration 4410: loss -22307.060546875, accuracy 0.10999999940395355
Iteration 4420: loss -22732.927734375, accuracy 0.15000000596046448
Iteration 4430: loss -20397.900390625, accuracy 0.07999999821186066
Iteration 4440: loss -21387.796875, accuracy 0.10000000149011612
Iteration 4450: loss -19922.064453125, accuracy 0.09000000357627869
Iteration 4460: loss -21500.736328125, accuracy 0.12999999523162842
Iteration 4470: loss -20916.337890625, accuracy 0.05999999865889549
Iteration 4480: loss -22431.0625, accuracy 0.09000000357627869
Iteration 4490: loss -21270.818359375, accuracy 0.019999999552965164
Iteration 4500: loss -20820.46484375, accuracy 0.10000000149011612
Iteration 4510: loss -21350.46484375, accuracy 0.0799999982

Iteration 5880: loss -26180.26171875, accuracy 0.11999999731779099
Iteration 5890: loss -27344.16796875, accuracy 0.07999999821186066
Iteration 5900: loss -26452.677734375, accuracy 0.11999999731779099
Iteration 5910: loss -26210.5859375, accuracy 0.15000000596046448
Iteration 5920: loss -27270.73046875, accuracy 0.10999999940395355
Iteration 5930: loss -27018.25390625, accuracy 0.12999999523162842
Iteration 5940: loss -26414.16015625, accuracy 0.11999999731779099
Iteration 5950: loss -29689.66796875, accuracy 0.14000000059604645
Iteration 5960: loss -28485.63671875, accuracy 0.12999999523162842
Iteration 5970: loss -29803.451171875, accuracy 0.05999999865889549
Iteration 5980: loss -26419.8046875, accuracy 0.10000000149011612
Iteration 5990: loss -28067.474609375, accuracy 0.09000000357627869
Iteration 6000: loss -28026.435546875, accuracy 0.09000000357627869
Iteration 6010: loss -27660.783203125, accuracy 0.15000000596046448
Iteration 6020: loss -25932.427734375, accuracy 0.090000003

Iteration 7280: loss -34766.15234375, accuracy 0.07999999821186066
Iteration 7290: loss -36341.2578125, accuracy 0.12999999523162842
Iteration 7300: loss -35569.0703125, accuracy 0.15000000596046448
Iteration 7310: loss -36628.9453125, accuracy 0.14000000059604645
Iteration 7320: loss -32526.02734375, accuracy 0.09000000357627869
Iteration 7330: loss -35911.30078125, accuracy 0.09000000357627869
Iteration 7340: loss -34462.359375, accuracy 0.05999999865889549
Iteration 7350: loss -31122.390625, accuracy 0.09000000357627869
Iteration 7360: loss -35086.18359375, accuracy 0.10999999940395355
Iteration 7370: loss -35226.66796875, accuracy 0.03999999910593033
Iteration 7380: loss -36285.3046875, accuracy 0.09000000357627869
Iteration 7390: loss -33594.5, accuracy 0.10000000149011612
Iteration 7400: loss -36830.625, accuracy 0.10000000149011612
Iteration 7410: loss -35313.12890625, accuracy 0.10999999940395355
Iteration 7420: loss -35641.44140625, accuracy 0.05999999865889549
Iteration 7430:

Iteration 8610: loss -37884.5, accuracy 0.07999999821186066
Iteration 8620: loss -42918.6015625, accuracy 0.05000000074505806
Iteration 8630: loss -44098.76953125, accuracy 0.11999999731779099
Iteration 8640: loss -40813.0703125, accuracy 0.07999999821186066
Iteration 8650: loss -43719.83984375, accuracy 0.03999999910593033
Iteration 8660: loss -41191.5859375, accuracy 0.14000000059604645
Iteration 8670: loss -39461.46484375, accuracy 0.12999999523162842
Iteration 8680: loss -40633.11328125, accuracy 0.12999999523162842
Iteration 8690: loss -40395.54296875, accuracy 0.05999999865889549
Iteration 8700: loss -40058.0703125, accuracy 0.07999999821186066
Iteration 8710: loss -41010.59375, accuracy 0.12999999523162842
Iteration 8720: loss -38569.546875, accuracy 0.11999999731779099
Iteration 8730: loss -41718.34375, accuracy 0.12999999523162842
Iteration 8740: loss -42471.484375, accuracy 0.05999999865889549
Iteration 8750: loss -43475.796875, accuracy 0.07000000029802322
Iteration 8760: lo

In [9]:
_prediction

array([[-60869.6484375 , -60867.18359375, -60867.23828125, ...,
        -60870.0625    , -60866.8046875 , -60867.671875  ],
       [-67426.75      , -67423.7109375 , -67425.5859375 , ...,
        -67427.921875  , -67424.65625   , -67426.2890625 ],
       [-26674.76367188, -26673.16015625, -26673.80664062, ...,
        -26675.41015625, -26673.44335938, -26673.84179688],
       ..., 
       [-16572.75390625, -16571.66992188, -16571.40234375, ...,
        -16571.68554688, -16571.8046875 , -16572.54101562],
       [-63961.3125    , -63958.421875  , -63958.625     , ...,
        -63962.51953125, -63958.21484375, -63960.3203125 ],
       [-28546.58984375, -28545.76757812, -28545.37109375, ...,
        -28547.55078125, -28545.515625  , -28546.046875  ]], dtype=float32)