# Imports

In [1]:
import os
import numpy as np
import tensorflow as tf
import random
from gensim.models.word2vec import Word2Vec
import gc
import time
import pickle

# Load the pretrained Inception graph

In [2]:
INCEPTION_GRAPH_PATH = '../../../data/inception-2015-12-05/classify_image_graph_def.pb'

# Parse the serialized Inception GraphDef and import it into the default graph.
# name='' means no prefix is added, so tensors keep their original names
# (e.g. 'pool_3/_reshape:0', 'DecodeJpeg/contents:0' used in later cells).
with tf.gfile.GFile(INCEPTION_GRAPH_PATH, 'rb') as f:
    graph_def = tf.GraphDef()
    graph_def.ParseFromString(f.read())
_ = tf.import_graph_def(graph_def, name='')
g = tf.get_default_graph()

## Create own branch: a linear projection from Inception's `pool_3` features to 300-dimensional caption vectors

In [3]:
# Width of the caption vectors the image features are regressed onto.
EMBED_DIM = 300
WEIGHT_STDDEV = 1e-5
LEARNING_RATE = 1e-5
SEED = 0

# Graph-level seed so the random initial values of w and b are reproducible.
# It must be set before the variables' initializer ops are created.
tf.set_random_seed(SEED)

with tf.name_scope('own'):
    # Target caption vector(s). The explicit shape rejects mis-shaped feeds
    # instead of silently broadcasting them in `y - y_pred`.
    y = tf.placeholder(tf.float32, shape=[None, EMBED_DIM])
    # Feature tensor taken from the imported Inception graph.
    x = g.get_tensor_by_name('pool_3/_reshape:0')
    w = tf.Variable(
        tf.random_normal([int(x.get_shape()[-1]), EMBED_DIM], stddev=WEIGHT_STDDEV),
        name='weights')
    # NOTE(review): unlike w, the bias is drawn with stddev 1.0; confirm this is
    # intended (a zero initialisation is the usual choice).
    b = tf.Variable(tf.random_normal([1, EMBED_DIM]), name='bias')
    y_pred = tf.add(tf.matmul(x, w), b, name='y_pred')
    # Sum (not mean) of squared errors over every fed row and dimension.
    cost = tf.reduce_sum(tf.square(y - y_pred), name='cost')
    cost_summary = tf.summary.scalar('cost', cost)
    # minimize() updates the TRAINABLE_VARIABLES collection. import_graph_def does
    # not populate it, so only w and b are trained.
    optimizer = tf.train.GradientDescentOptimizer(learning_rate=LEARNING_RATE).minimize(cost)

# Read data

In [4]:
# Images. The listing is sorted because later cells split it by position
# (image_list[:1000] for training, image_list[-1] as the held-out test image,
# image_list[1000:6000] for embedding), and os.listdir order is arbitrary.
image_path = '../../../data/train2014/'
image_list = sorted(os.listdir(image_path))

# Caption dictionaries: the .npy file holds a 0-d object array wrapping a mapping
# from image filename to its caption vectors; .item() unwraps it.
# allow_pickle=True is required for object arrays on NumPy >= 1.16.3.
# NOTE(review): loading pickled data can execute arbitrary code; only use trusted files.
vec_dict = np.load('../../../data/word2vec_train.npy', allow_pickle=True).item()

# Train model and save image embeddings

In [6]:
# Run configuration (kept in this cell so it is self-contained).
N_TRAIN = 1000       # image_list[:N_TRAIN] is used for training
EMBED_END = 6000     # image_list[N_TRAIN:EMBED_END] is projected and saved
LOG_EVERY = 10       # evaluate the held-out image / print progress every LOG_EVERY steps
SEED = 0
JPEG_INPUT = 'DecodeJpeg/contents:0'   # raw-JPEG input tensor of the imported graph
EMBED_DIM = int(y_pred.get_shape()[-1])  # stays in sync with the branch's output width


def _read_image(path):
    """Return the raw bytes of the file at `path`, closing the handle afterwards."""
    with tf.gfile.GFile(path, 'rb') as f:
        return f.read()


random.seed(SEED)  # makes the per-step caption choice reproducible

with tf.Session(graph=g) as sess:
    # Initialize writers
    timestamp = time.strftime('%Y%m%d-%H%M', time.localtime())
    train_writer = tf.summary.FileWriter('./sums/train/' + timestamp, graph=g, flush_secs=20)
    test_writer = tf.summary.FileWriter('./sums/test/' + timestamp, graph=g, flush_secs=20)

    try:
        # Only the branch's own variables need initialising.
        sess.run([w.initializer, b.initializer])

        # The held-out monitoring example is loop-invariant, so read it once.
        test_image = _read_image(image_path + image_list[-1])
        y_test = vec_dict[image_list[-1]][1].reshape((1, EMBED_DIM))

        # Train
        print('TRAINING')
        for i, im in enumerate(image_list[:N_TRAIN]):
            image = _read_image(image_path + im)
            # Choose one of the image's captions at random.
            y_temp = random.choice(vec_dict[im]).reshape((1, EMBED_DIM))

            # A single run performs the update and evaluates the summary, so the
            # Inception forward pass is computed once per image. The logged cost
            # is the loss of this step's forward pass, i.e. before the update.
            _, train_summary = sess.run(
                [optimizer, cost_summary], feed_dict={JPEG_INPUT: image, y: y_temp})
            train_writer.add_summary(train_summary, i)

            if i % LOG_EVERY == 0:
                test_writer.add_summary(
                    sess.run(cost_summary, {JPEG_INPUT: test_image, y: y_test}), i)
                print(i, end=' ')
        print('\n\n')
    finally:
        # Flush buffered events to disk even if training is interrupted.
        train_writer.close()
        test_writer.close()

    # Save: project the next slice of images with the trained branch.
    print('SAVING NEW DICTIONARY')
    new_vec_dict = {}
    for i, im in enumerate(image_list[N_TRAIN:EMBED_END]):
        new_vec_dict[im] = sess.run(y_pred, {JPEG_INPUT: _read_image(image_path + im)})
        if i % LOG_EVERY == 0:
            print(i, end=' ')

    # Stored as a 0-d object array; load with np.load(..., allow_pickle=True).item().
    np.save('image_space.npy', new_vec_dict)
        

TRAINING
0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 390 400 410 420 430 440 450 460 470 480 490 500 510 520 530 540 550 560 570 580 590 600 610 620 630 640 650 660 670 680 690 700 710 720 730 740 750 760 770 780 790 800 810 820 830 840 850 860 870 880 890 900 910 920 930 940 950 960 970 980 990 


SAVING NEW DICTIONARY
0 10 20 30 40 50 60 70 80 90 100 110 120 130 140 150 160 170 180 190 200 210 220 230 240 250 260 270 280 290 300 310 320 330 340 350 360 370 380 390 400 410 420 430 440 450 460 470 480 490 500 510 520 530 540 550 560 570 580 590 600 610 620 630 640 650 660 670 680 690 700 710 720 730 740 750 760 770 780 790 800 810 820 830 840 850 860 870 880 890 900 910 920 930 940 950 960 970 980 990 1000 1010 1020 1030 1040 1050 1060 1070 1080 1090 1100 1110 1120 1130 1140 1150 1160 1170 1180 1190 1200 1210 1220 1230 1240 1250 1260 1270 1280 1290 1300 1310 1320 1330 1340 1350 1360 137