# Imports

In [None]:
import os
import numpy as np
import random
import time
from PIL import Image
import tensorflow as tf
import tensorflow.contrib.slim as slim
from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3
from tensorflow.contrib.slim.python.slim.nets.inception_v3 import inception_v3_arg_scope

# Build graph

In [None]:
# Handle to the default graph. Later cells use it to look up tensors by name
# (e.g. g.get_tensor_by_name) and to write the graph out with a FileWriter.
g = tf.get_default_graph()

## Define preprocessing

In [None]:
with tf.name_scope('Preprocessing'):
    # Batch of RGB images as uint8; height and width are left unspecified so
    # images of any size can be fed.
    input_im = tf.placeholder(tf.uint8, shape=[None, None, None, 3], name='input_im')
    # convert_image_dtype rescales uint8 [0, 255] to float32 [0, 1].
    float_im = tf.image.convert_image_dtype(input_im, dtype=tf.float32)
    resized_im = tf.image.resize_bilinear(float_im, [299, 299], name='resized_im')
    # Map [0, 1] to [-1, 1]. tf.sub / tf.mul were removed in TF 1.0; tf.subtract /
    # tf.multiply are their replacements. The inner op is deliberately left
    # unnamed so it keeps the default name 'Preprocessing/Sub'.
    normalized_im = tf.multiply(tf.subtract(resized_im, 0.5), 2.0, name='normalized_im')

## Load Inception-v3 graph from slim

In [None]:
# Build Inception-v3 in inference mode (is_training=False) on the normalized
# images, under slim's standard Inception-v3 layer defaults. The 1001 outputs
# presumably are the 1000 ILSVRC-2012 classes plus a background class, as in
# the slim checkpoints -- confirm against the checkpoint used below.
with slim.arg_scope(inception_v3_arg_scope()):
    logits, end_points = inception_v3(
        normalized_im,
        num_classes=1001,
        is_training=False,
    )

# Every global variable created under the 'InceptionV3/' scope, i.e. the
# pretrained network's variables, which are restored from the checkpoint later.
slim_variables = list(filter(lambda var: var.name.startswith('InceptionV3/'),
                             tf.global_variables()))

## Create own branch (linear regression from the Inception logits to 300-d caption vectors)

In [None]:
with tf.name_scope('Own'):
    # Regression head: a single linear layer mapping the Inception logits to a
    # 300-d target vector (the caption vectors fed in during training).
    y = tf.placeholder(dtype=tf.float32, shape=[None, 300], name='y')
    n_examples = tf.to_float(tf.shape(y, name='batch_size')[0])

    # Input features: the squeezed 1001-way logits of the pretrained network.
    inception_logits = g.get_tensor_by_name('InceptionV3/Logits/SpatialSqueeze:0')
    n_features = int(inception_logits.get_shape()[-1])

    # Near-zero initial weights; the bias is drawn from a unit normal.
    w = tf.Variable(tf.random_normal([n_features, 300], stddev=1e-5), name='weights')
    b = tf.Variable(tf.random_normal([1, 300]), name='bias')
    y_pred = tf.add(tf.matmul(inception_logits, w), b, name='y_pred')

    # Summed squared error over all outputs, averaged over the examples in the batch.
    squared_error = tf.reduce_sum(tf.square(y - y_pred))
    cost = tf.divide(squared_error, n_examples, name='cost')
    cost_summary = tf.summary.scalar('cost', cost)

## Define training

In [None]:
with tf.name_scope('Optimization'):
    # Learning rate starts at 1e-5 and is multiplied by 0.95 every 200 steps
    # of global_step (staircase decay).
    global_step = tf.Variable(0, trainable=False, name='global_step')
    init_learning_rate = tf.constant(1e-5, name='init_learning_rate')
    learning_rate = tf.train.exponential_decay(
        init_learning_rate,
        global_step,
        decay_steps=200,
        decay_rate=0.95,
        staircase=True,
        name='learning_rate',
    )

    optimizer = tf.train.RMSPropOptimizer(learning_rate=learning_rate, momentum=0.9, name='Optimizer')

    # Train the own regression layer (w, b) plus selected Inception variables:
    # the logits layer's weights/biases and Mixed_7c's weights and 'beta'
    # variables (presumably batch-norm offsets). Each rule is (scope prefix,
    # accepted name suffixes); the original variable order is kept.
    fine_tune_rules = (
        ('InceptionV3/Logits', ('weights:0', 'biases:0')),
        ('InceptionV3/Mixed_7c', ('weights:0', 'beta:0')),
    )
    variables_to_train = [w, b] + [
        var
        for prefix, suffixes in fine_tune_rules
        for var in tf.trainable_variables()
        if var.name.startswith(prefix) and var.name.endswith(suffixes)
    ]

    train_step = optimizer.minimize(
        cost, var_list=variables_to_train, global_step=global_step, name='train_step')

## Save graph

In [None]:
# Name this run by its start time (minute resolution); outputs go under ./runs/<timestamp>/.
timestamp = time.strftime('%Y%m%d-%H%M', time.localtime())
# Write the graph for TensorBoard, then close the writer so the event file is
# flushed to disk now instead of whenever the background flush happens to run.
graph_writer = tf.summary.FileWriter('./runs/' + timestamp + '/graph/', graph=g)
graph_writer.close()

# Read data

In [None]:
# Images. os.listdir returns entries in arbitrary order; sorting makes the
# training subset and the validation batch used below the same on every run.
train_image_path = '../../../data/train2014/'
train_image_list = sorted(os.listdir(train_image_path))
val_image_path = '../../../data/val2014/'
val_image_list = sorted(os.listdir(val_image_path))

# Caption dictionaries: .item() unwraps the object stored in the saved array
# (looked up by image file name below). NumPy >= 1.16.3 refuses to unpickle
# objects unless allow_pickle=True. Only load trusted files this way, because
# unpickling can execute arbitrary code.
train_dict = np.load('../../../data/word2vec_train.npy', allow_pickle=True).item()
val_dict = np.load('../../../data/word2vec_val.npy', allow_pickle=True).item()

# Class labels, one per line (trailing newlines kept). The context manager closes the file.
with open('../../../data/ilsvrc2012_labels.txt', 'r') as label_file:
    labels = label_file.readlines()

# Train model

In [None]:
# Tried to follow
# https://github.com/tensorflow/models/tree/master/slim
# http://stackoverflow.com/questions/35274457/inceptionv3-and-transfer-learning-with-tensorflow/40709836#40709836

# Training configuration
BATCH_SIZE = 5              # images per training step
MAX_TRAIN_IMAGES = 10000    # only the first this-many training images are used
VAL_SET_SIZE = 10           # size of the fixed validation batch
EVAL_EVERY_N_IMAGES = 25    # validate when the training image index is a multiple of this
INCEPTION_CKPT = '../../../data/inception-2016-08-28/inception_v3.ckpt'
CHECKPOINT_DIR = './runs/' + timestamp + '/checkpoint/'


def load_resized_images(sess, image_dir, file_names):
    """Load image files and resize them with the graph's 'Preprocessing' ops.

    Each file is opened with PIL and converted to 3-channel RGB (this covers
    grayscale as well as e.g. RGBA/CMYK files). It is then run through
    `resized_im` one at a time, because the source images differ in size and
    cannot be stacked before resizing.

    Args:
        sess: Session used to evaluate `resized_im`.
        image_dir: Directory that `file_names` are relative to.
        file_names: Sequence of image file names.

    Returns:
        float32 array of shape (len(file_names), 299, 299, 3) with values in
        [0, 1], ready to be fed into `resized_im`.
    """
    resized = []
    for name in file_names:
        with Image.open(os.path.join(image_dir, name)) as img:
            rgb = np.asarray(img.convert('RGB'))
        resized.append(sess.run(resized_im, {input_im: [rgb]})[0])
    return np.stack(resized, axis=0)


with tf.Session() as sess:

    # Restore the pretrained Inception-v3 weights, then initialise every other
    # global variable the checkpoint did not provide (own branch, global_step,
    # optimizer slots).
    tf.train.Saver(var_list=slim_variables).restore(sess, INCEPTION_CKPT)
    uninitialized_names = {name.decode() for name in sess.run(tf.report_uninitialized_variables())}
    sess.run(tf.variables_initializer(
        var_list=[v for v in tf.global_variables() if v.op.name in uninitialized_names]))

    # Initialize writers
    train_writer = tf.summary.FileWriter('./runs/' + timestamp + '/sums/train/', flush_secs=20)
    val_writer = tf.summary.FileWriter('./runs/' + timestamp + '/sums/val/', flush_secs=20)
    try:
        # The validation batch never changes, so load it once rather than at
        # every evaluation. Use the first caption of each validation image.
        val_names = val_image_list[:VAL_SET_SIZE]
        val_image_batch = load_resized_images(sess, val_image_path, val_names)
        val_caption_batch = np.stack([val_dict[name][0] for name in val_names], axis=0)

        print('TRAINING')
        n_train = min(len(train_image_list), MAX_TRAIN_IMAGES)
        # Full batches only: a trailing partial batch is skipped instead of
        # indexing past the end of the image list.
        for i in range(0, n_train - BATCH_SIZE + 1, BATCH_SIZE):
            batch_names = train_image_list[i:i + BATCH_SIZE]
            image_batch = load_resized_images(sess, train_image_path, batch_names)
            # Choose one of each image's captions at random
            caption_batch = np.stack([random.choice(train_dict[name]) for name in batch_names], axis=0)

            # Feed the resized float images straight into `resized_im`. The
            # `input_im` placeholder is uint8, and float pixel values fed there
            # are cast to integers, which yields (almost) all-zero images.
            train_summary, _ = sess.run([cost_summary, train_step],
                                        feed_dict={resized_im: image_batch, y: caption_batch})
            train_writer.add_summary(train_summary, i)
            print(i, end=' ')

            # Validate every EVAL_EVERY_N_IMAGES training images (with the
            # defaults above, every 5th batch).
            if i % EVAL_EVERY_N_IMAGES == 0:
                val_summary = sess.run(cost_summary,
                                       feed_dict={resized_im: val_image_batch, y: val_caption_batch})
                val_writer.add_summary(val_summary, i)
    finally:
        train_writer.close()
        val_writer.close()

    # Save all variables (pretrained, fine-tuned and own branch). makedirs
    # tolerates a directory left over from an earlier run in the same minute.
    os.makedirs(CHECKPOINT_DIR, exist_ok=True)
    save_path = tf.train.Saver().save(sess, CHECKPOINT_DIR + 'model.ckpt')
    print('\n\nTrained model saved to {}'.format(save_path))