In [None]:
import tensorflow as tf
from utils.models import *
from main import TextGen
import tqdm
import keras
%pylab inline

def rescale_img(img):
    """Cast ``img`` to float32 and map it linearly via ``(x - 127.5) / 127.5``.

    With 8-bit pixel input this sends 0 to -1.0 and 255 to +1.0.
    """
    return (tf.cast(img, tf.float32) - 127.5) / 127.5

def string_to_onehots(string, chars=128, pad_length=128):
    """Encode a scalar string tensor as a fixed-size one-hot matrix.

    The string is split into single bytes (empty delimiter), each byte is
    mapped to its position in ``[chr(0), ..., chr(chars - 1)]``, and the
    resulting index sequence is truncated or right-padded with zeros to
    exactly ``pad_length`` entries. Padding positions are therefore encoded
    as index 0.

    Args:
        string: scalar string tensor to encode.
        chars: vocabulary size, i.e. the one-hot depth.
        pad_length: number of positions in the output.

    Returns:
        float32 tensor of static shape ``(pad_length, chars)``. Bytes that
        are not in the vocabulary produce an all-zero row.
    """
    split_caption = tf.string_split(tf.expand_dims(string, 0), delimiter='')
    table = tf.contrib.lookup.index_table_from_tensor(
        mapping=[chr(i) for i in range(chars)],
        num_oov_buckets=0,
        default_value=-1)  # out-of-vocabulary bytes are looked up as -1
    indices = table.lookup(split_caption.values)

    # Truncate to pad_length, then right-pad with zeros if the string is shorter.
    n_missing = tf.maximum(0, pad_length - tf.shape(indices)[0])
    pad_inds = tf.pad(indices[:pad_length], [[0, n_missing]])
    pad_inds.set_shape((pad_length,))

    # tf.one_hot yields an all-zero row for out-of-range indices such as -1.
    # The previous identity-matrix gather fails on CPU for those indices and
    # only returns zeros on GPU, so behaviour depended on the device.
    code = tf.one_hot(pad_inds, depth=chars, dtype=tf.float32)
    code.set_shape((pad_length, chars))
    return code

def read_data(record_file='/home/paperspace/data/ms_coco/train.tfrecord'):
    """Build the graph ops that read and decode one example from a TFRecord.

    Each record is parsed as three scalar features: ``image`` (JPEG bytes),
    ``caption`` (string) and ``class`` (int64).

    Args:
        record_file: path of the TFRecord file to read from.

    Returns:
        Tuple ``(decoded_img, onehot_caption, class_)``:
        the image decoded to 3 channels and rescaled by ``rescale_img``,
        the caption encoded by ``string_to_onehots``, and the int64 class.
    """
    reader = tf.TFRecordReader()
    filename_queue = tf.train.string_input_producer([record_file])
    _, serialized_example = reader.read(filename_queue)

    feature = {'image': tf.FixedLenFeature([], tf.string),
               'caption': tf.FixedLenFeature([], tf.string),
               'class': tf.FixedLenFeature([], tf.int64)}
    features = tf.parse_single_example(serialized_example, features=feature)
    img, caption, class_ = features['image'], features['caption'], features['class']

    # Force 3 output channels: without `channels`, a grayscale JPEG decodes to
    # a single channel. RGB inputs are unaffected.
    decoded_img = rescale_img(tf.image.decode_jpeg(img, channels=3))
    onehot_caption = string_to_onehots(caption)
    return decoded_img, onehot_caption, class_

In [None]:
# --- Input pipeline ---------------------------------------------------------
batch_size = 256
capacity = 8   # shuffle-queue capacity, expressed in batches
n_threads = 8  # number of parallel enqueue threads

# NOTE(review): `backend` is not imported explicitly in this notebook;
# presumably it is provided by `from utils.models import *` -- confirm.
backend.set_learning_phase(True)
example = read_data()
# Batching needs a fully defined static image shape. Records are assumed to
# hold 64x64 RGB images -- TODO confirm against the TFRecord writer.
example[0].set_shape((64, 64, 3))
# NOTE(review): min_after_dequeue=0 places no lower bound on the shuffle
# buffer, so mixing may be weak -- confirm this is intended.
img, c, class_ = tf.train.shuffle_batch_join(
    [example]*n_threads, batch_size=batch_size,
    capacity=batch_size*capacity, min_after_dequeue=0)

# --- Model, loss and optimiser ----------------------------------------------
m = TextGen()
c_hat = m.forward_pass(img)
L = tf.reduce_mean(keras.losses.categorical_crossentropy(c, c_hat))
# Only the LSTM weights are passed to the optimiser; m.phi is not updated.
lstm_opt = tf.train.AdamOptimizer(1e-3)\
             .minimize(L, var_list=m.lstm.trainable_weights)

# --- Training loop ----------------------------------------------------------
with tf.Session() as sess:
    sess.run([tf.global_variables_initializer(), tf.tables_initializer()])
    coord = tf.train.Coordinator()
    threads = tf.train.start_queue_runners(sess=sess, coord=coord)
    loss = []
    try:
        # Loaded after the global initializer so the weights are not overwritten.
        m.phi.load_weights('/home/paperspace/.keras/models/mobilenet_1_0_224_tf_no_top.h5')
        for n in tqdm.trange(1000000, disable=False):
            _, l = sess.run([lstm_opt, L])
            if n > 0 and n % 100 == 0:
                # Record the loss every 100 steps and fetch one more batch of
                # targets and predictions for later inspection.
                loss.append(l)
                c_eval, c_hat_eval = sess.run([c, c_hat])
    finally:
        # Stop and join the queue-runner threads before the session closes,
        # including when the cell is interrupted or a step raises.
        coord.request_stop()
        coord.join(threads)

In [None]:
# `loss` holds one value per 100 training steps, the first taken at step 100,
# so plot against the real step number instead of the list index.
steps = range(100, 100 * (len(loss) + 1), 100)
plot(steps, loss)
xlabel('training step')
ylabel('mean categorical cross-entropy')
title('Training loss');

In [None]:
# Prediction for element 10 of the most recently evaluated batch.
# Assumes c_hat shares the (position, character) layout of the one-hot target
# it is compared against in the loss -- TODO confirm against TextGen.forward_pass.
imshow(c_hat_eval[10])
xlabel('character index')
ylabel('caption position')
title('Predicted caption, batch element 10');