# Flat Image Net - Visualize embedding

In [None]:
%load_ext autoreload
%autoreload 2
%matplotlib inline

In [None]:
import tensorflow as tf
from tensorflow.examples.tutorials.mnist import input_data
import numpy as np
import os
import shutil
import pandas as pd
import matplotlib.pylab as plt
import seaborn as sns
from PIL import Image
from utils.data import init_dir

## Load data

In [None]:
# Fashion-MNIST ships in the same file format as MNIST, so the MNIST reader can
# load it. Passing source_url ensures that, if the files are not already present
# in /data/fashion/, Fashion-MNIST is downloaded rather than the MNIST digits
# (the reader's default download location).
FASHION_MNIST_URL = 'http://fashion-mnist.s3-website.eu-central-1.amazonaws.com/'
data = input_data.read_data_sets('/data/fashion/', one_hot=True, source_url=FASHION_MNIST_URL)
# Height/width used to reshape each flattened 784-value image.
img_shape = (28, 28)
# Class index (position of the 1 in a one-hot label) -> human-readable name.
class_id2class_name_mapping = {
    0: 'T-shirt/top',
    1: 'Trouser',
    2: 'Pullover',
    3: 'Dress',
    4: 'Coat',
    5: 'Sandal',
    6: 'Shirt',
    7: 'Sneaker',
    8: 'Bag',
    9: 'Ankle boot'}

## Build Net Graph

In [None]:
from utils.nn_graph import simple_layer
from utils.nn_visualization import variable_summaries, img_summaries

# One embedding row per validation example, so the embedding lines up row-for-row
# with the sprite image and label metadata that are built from data.validation.
n_embedding_rows = data.validation.num_examples

graph = tf.Graph()
with graph.as_default():
    with tf.name_scope('flat_image_net_inputs'):
        # Flattened 28x28 images and one-hot labels over the 10 classes.
        images = tf.placeholder(tf.float32, shape=[None, 784], name='images')
        labels = tf.placeholder(tf.float32, shape=[None, 10], name='labels')
        # The training cell feeds 0.5 for training steps and 1.0 for evaluation.
        keep_dropout_prob = tf.placeholder(tf.float32, name='keep_dropout_prob')

    with tf.variable_scope('simple_layer_1'):
        hidden = simple_layer(name='layer1', input_data=images, shape=[784, 64], activation='relu')
        # Dropout regularizes the hidden activations. Applied to the output
        # logits, it would randomly zero (and rescale) individual class scores.
        hidden = tf.nn.dropout(hidden, keep_dropout_prob)

    with tf.variable_scope('simple_layer_2'):
        # Unscaled class scores (logits); these are also what the projector shows.
        raw_prediction = simple_layer(name='layer2', input_data=hidden, shape=[64, 10])

    with tf.name_scope('prediction'):
        prediction = tf.nn.softmax(raw_prediction)

    with tf.name_scope('loss'):
        # softmax_cross_entropy_with_logits applies softmax internally, so it must
        # receive the raw logits, not the already-softmaxed `prediction`.
        cross_entropy_vector = tf.nn.softmax_cross_entropy_with_logits(labels=labels, logits=raw_prediction)
        loss = tf.reduce_mean(cross_entropy_vector)
        variable_summaries('loss_summary', cross_entropy_vector)

    with tf.name_scope('accuracy'):
        correct_prediction = tf.equal(tf.argmax(prediction, 1), tf.argmax(labels, 1))
        correct_prediction = tf.cast(correct_prediction, tf.float32)
        accuracy = tf.reduce_mean(correct_prediction)
        variable_summaries('accuracy_summary', correct_prediction)

    with tf.name_scope('training'):
        train_step = tf.train.AdamOptimizer(0.001).minimize(loss)

    with tf.name_scope("embedding_visualization"):
        # Holds the validation-set logits. It is filled by running
        # `embedding_assignment` on data.validation.
        # NOTE(review): presumably persisted by logging_meta['saver'] for the
        # projector — confirm the saver covers this variable.
        embedding = tf.Variable(tf.zeros([n_embedding_rows, 10]), name='valid_embedding')
        embedding_assignment = embedding.assign(raw_prediction)

    initialize_vars = tf.global_variables_initializer()
    merge_summaries = tf.summary.merge_all()

## Init Model Logging

In [None]:
from utils.data import init_model_logging
# TensorBoard output for this notebook goes under base_dir; exp_name presumably
# names this run's subdirectory — confirm in utils.data.init_model_logging.
base_dir = '/tensorboard_summaries/flat_image_net/'
exp_name = 'experiment_visual_embedding'

# Later cells use logging_meta as a dict with 'train_writer', 'valid_writer',
# 'valid_writer_dir', 'saver' and 'model_path' entries.
# NOTE(review): remove_existing=True presumably clears a previous run's logs — confirm.
logging_meta = init_model_logging(base_dir, exp_name, graph=graph, remove_existing=True)

## Data for Embedding Projection

#### Sprite Image

In [None]:
def get_sprite_img(images, img_shape):
    """Tile a batch of images into one square sprite image for the embedding projector.

    Images are placed in row-major order (left to right, then top to bottom) on a
    ``rows x rows`` grid, where ``rows = ceil(sqrt(len(images)))``. Unused
    trailing cells stay black (zero).

    Args:
        images: sequence/array of images. Each entry is reshaped to ``img_shape``,
            e.g. flattened 784-value rows for 28x28 images.
        img_shape: ``(h, w)`` for grayscale or ``(h, w, channels)`` for colour.

    Returns:
        A ``PIL.Image.Image``. Grayscale input is multiplied by 255 before the
        uint8 conversion, so it is presumably expected in [0, 1]. 3-D input is
        converted to uint8 unscaled, so it is presumably already in [0, 255].

    Raises:
        ValueError: if ``images`` is empty.
    """
    images = np.asarray(images)
    image_count = len(images)
    if image_count == 0:
        raise ValueError('get_sprite_img needs at least one image')
    h, w = img_shape[:2]

    rows = int(np.ceil(np.sqrt(image_count)))
    cols = rows

    if len(img_shape) == 3:
        sprite_img = np.zeros([rows * h, cols * w, img_shape[2]])
    else:
        sprite_img = np.zeros([rows * h, cols * w])

    # Row-major placement: image k goes to grid cell (k // cols, k % cols).
    for image_id, image in enumerate(images):
        row_id, col_id = divmod(image_id, cols)
        sprite_img[row_id * h:(row_id + 1) * h,
                   col_id * w:(col_id + 1) * w] = np.reshape(image, img_shape)

    # Convert once after filling, so a completely full grid (a perfect-square
    # image count) is returned as well.
    if len(img_shape) == 3:
        return Image.fromarray(np.uint8(sprite_img))
    return Image.fromarray(np.uint8(sprite_img * 0xFF))

In [None]:
# Build the sprite image for the validation split; the bare name on the last
# line makes the notebook display it as the cell output.
sprite_img = get_sprite_img(data.validation.images, img_shape)
sprite_img

#### Label Class Names

In [None]:
def get_label_class_names(label_class_onehots, class_id2class_name_mapping):
    """Translate each label row into the name of its highest-valued class.

    Args:
        label_class_onehots: 2-D array-like with one row per example (e.g.
            one-hot labels). The column index of each row's maximum is used as
            the class id; ties resolve to the first maximal column.
        class_id2class_name_mapping: mapping from integer class id to class name.

    Returns:
        List of class names, one per input row, in input order.

    Raises:
        KeyError: if a class id is not present in the mapping.
    """
    # .tolist() yields plain Python ints for the dictionary lookups.
    class_ids = np.argmax(label_class_onehots, axis=1).tolist()
    return [class_id2class_name_mapping[class_id] for class_id in class_ids]

In [None]:
# Class name for every validation example; the bare name on the last line makes
# the notebook display the list as the cell output.
label_class_names = get_label_class_names(data.validation.labels, class_id2class_name_mapping)
label_class_names

## Add Embedding Projection

In [None]:
# Projector inputs, built from the validation split. The training cell fills the
# embedding from data.validation.images, so these line up row-for-row with it.
# They recompute the same values produced by the two preview cells.
sprite_img = get_sprite_img(data.validation.images, img_shape)
label_names = get_label_class_names(data.validation.labels, class_id2class_name_mapping)

In [None]:
from utils.nn_visualization import init_embedding_projector, init_embedding_data

# NOTE(review): presumably writes the sprite image and label metadata into the
# validation writer's directory — confirm in utils.nn_visualization.
init_embedding_data(logging_meta['valid_writer_dir'], sprite_img, label_names)
# NOTE(review): presumably registers `embedding` (img_shape as the sprite
# thumbnail size) in the projector config of the validation writer — confirm.
init_embedding_projector(logging_meta['valid_writer'], embedding, img_shape)

## Run Net

In [None]:
config = tf.ConfigProto(allow_soft_placement=True)
config.gpu_options.allow_growth = True  # allocate GPU memory on demand
model_path = logging_meta['model_path']

n_iterations = 10000
batch_size = 100
train_log_every = 10   # iterations between training-batch summaries
valid_every = 100      # iterations between validation runs / checkpoints

with tf.Session(graph=graph, config=config) as session:
    session.run(initialize_vars)
    for iteration in range(n_iterations):
        is_last_iteration = iteration == n_iterations - 1

        ##################
        # Training phase #
        ##################
        _images, _labels = data.train.next_batch(batch_size)
        session.run(train_step, feed_dict={images: _images, labels: _labels, keep_dropout_prob: 0.5})
        if iteration % train_log_every == 0:
            # Re-evaluate the same batch with dropout disabled for logging.
            _summary, _accuracy, _loss = session.run([merge_summaries, accuracy, loss],
                                                     feed_dict={images: _images,
                                                                labels: _labels,
                                                                keep_dropout_prob: 1.0})
            logging_meta['train_writer'].add_summary(_summary, iteration)
            print("Train Iteration {}: loss {}, accuracy {}".format(iteration, _loss, _accuracy))

        ####################
        # Validation phase #
        ####################
        # Also run on the last iteration so the final weights and embedding are
        # checkpointed. Otherwise the last checkpoint would be the one at the last
        # multiple of valid_every (9900 for 10000 iterations).
        if iteration % valid_every == 0 or is_last_iteration:
            _, _summary, _accuracy, _loss = session.run(
                [embedding_assignment, merge_summaries, accuracy, loss],
                feed_dict={images: data.validation.images,
                           labels: data.validation.labels,
                           keep_dropout_prob: 1.0})
            logging_meta['valid_writer'].add_summary(_summary, iteration)
            logging_meta['saver'].save(session, model_path, iteration)
            print("= Valid Iteration {}: loss {}, accuracy {} =".format(iteration, _loss, _accuracy))

    # Class probabilities for the validation set from the final model.
    _prediction, = session.run([prediction], feed_dict={images: data.validation.images, keep_dropout_prob: 1.0})