In [1]:
"""
Image caption training notebook.

1. Data generator
    a. Loads the vocabulary.
    b. Loads the image descriptions and converts them to word ids.
    c. Loads the image features.
    d. Provides data for training.
2. Builds the image caption model.
3. Trains the model.
"""

import os
import sys
import tensorflow as tf
from tensorflow import gfile
from tensorflow import logging
import pprint
import _pickle as cPickle
import numpy as np
import math


# Input data locations and the directory that receives training output.
input_description_file = './data/results.token'
# NOTE(review): 'inpcetion' looks like a typo, but it is presumably the real
# on-disk directory name -- verify before renaming.
input_img_feature_dir = './data/download_inpcetion_v3_features/'
input_vocab_file = './data/vocab.txt'
output_dir = './data/local_run'

# gfile.MakeDirs creates all intermediate directories and succeeds when the
# directory already exists, so a separate Exists() check (and its
# check-then-create race) is unnecessary.
gfile.MakeDirs(output_dir)
    
def get_default_params():
    """Return the default hyper-parameters of the image caption model.

    NOTE(review): relies on ``tf.contrib``, which the deprecation warning
    printed by this cell says is not part of TensorFlow 2.0.
    """
    defaults = dict(
        # Vocab words with a count below this are dropped (see Vocab).
        num_vocab_word_threshold=3,
        num_embedding_size=32,
        num_timesteps=10,
        # Units per LSTM layer; presumably len() must equal num_lstm_layers
        # -- TODO confirm against the model-building code.
        num_lstm_nodes=[64, 64],
        num_lstm_layers=2,
        num_fc_nodes=32,
        batch_size=80,
        cell_type='lstm',
        # Gradient clipping threshold (original note: values beyond it are
        # clipped to 1).
        clip_lstm_grads=1,
        learning_rate=0.001,
        keep_prob=0.8,
        # How often to print a log line.
        log_frequent=100,
        save_frequent=1000,
    )
    return tf.contrib.training.HParams(**defaults)


hps = get_default_params()


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



In [9]:
# Load the vocabulary.
class Vocab:
    """Bidirectional word <-> id mapping loaded from a vocabulary file.

    Each non-blank line of the file must be ``word<TAB>count``. Words whose
    count is below ``word_num_threshold`` are dropped; the remaining words get
    consecutive ids (0, 1, 2, ...) in file order.

    Args:
        filename: path of the vocabulary file, read via ``gfile.GFile``.
        word_num_threshold: minimum count for a word to be kept.

    Raises:
        Exception: if a kept word appears more than once.
        ValueError: if a non-blank line is not ``word<TAB>integer``.
    """

    def __init__(self, filename, word_num_threshold):
        self._id_to_word = {}
        self._word_to_id = {}
        # Remain -1 when '<UNK>' / '.' is missing from the filtered vocabulary.
        self._unk = -1
        self._eos = -1
        self._word_num_threshold = word_num_threshold
        self._read_dict(filename)

    def _read_dict(self, filename):
        """Populate the word/id tables from ``filename``; blank lines are skipped."""
        with gfile.GFile(filename, 'r') as f:
            lines = f.readlines()

        for line in lines:
            line = line.strip('\r\n')
            if not line:
                # Tolerate blank lines (e.g. an extra newline at EOF) instead
                # of failing the tuple unpacking below with a ValueError.
                continue
            word, occurrence = line.split('\t')
            occurrence = int(occurrence)
            if occurrence < self._word_num_threshold:
                continue
            # Reject duplicates before touching any state, so _unk/_eos can
            # never point at an id that was not registered. (Ids are always
            # len(table), so they can never collide on their own.)
            if word in self._word_to_id:
                raise Exception('duplicate words in vocab')
            idx = len(self._id_to_word)
            if word == '<UNK>':
                self._unk = idx
            elif word == '.':
                self._eos = idx
            self._id_to_word[idx] = word
            self._word_to_id[word] = idx

    @property
    def unk(self):
        """Id of '<UNK>', or -1 if it is not in the vocabulary."""
        return self._unk

    @property
    def eos(self):
        """Id of '.', the end-of-sentence token; -1 if it is not in the vocabulary."""
        return self._eos

    def word_to_id(self, word):
        """Return the id of ``word``, or ``self.unk`` for unknown words.

        NOTE(review): without a '<UNK>' entry above the threshold this returns
        -1, which is not a valid embedding index -- confirm the vocab file
        always contains '<UNK>'.
        """
        return self._word_to_id.get(word, self._unk)

    def id_to_word(self, word_id):
        """Return the word for ``word_id``, or '<UNK>' for unknown ids."""
        return self._id_to_word.get(word_id, '<UNK>')

    def size(self):
        """Number of words kept in the vocabulary."""
        return len(self._id_to_word)

    def encode(self, sentence):
        """Convert a whitespace-separated sentence into a list of word ids.

        Out-of-vocabulary words map to ``self.unk``. ``str.split()`` (no
        separator) is used so repeated/leading/trailing spaces do not yield
        empty tokens that would be encoded as spurious ``<UNK>`` ids.
        """
        return [self.word_to_id(word) for word in sentence.split()]

    def decode(self, sentence_id):
        """Convert a sequence of word ids back into a space-joined sentence."""
        words = [self.id_to_word(word_id) for word_id in sentence_id]
        return ' '.join(words)
        
# Smoke test: encode a sentence, then decode the ids back into words.
vocab = Vocab(input_vocab_file, hps.num_vocab_word_threshold)

encode_sentence = vocab.encode('i have a dream')
print(encode_sentence, vocab.decode(encode_sentence), sep='\n')

[3835, 389, 1, 0]
i have a <UNK>


In [11]:
# Parse the image description file into a dict {img_name: [description, ...]}.
def parse_token_file(token_file):
    """Parse an image description (token) file.

    Each non-blank line is expected to look like
    ``<img_name>#<suffix><TAB><description>``; the suffix after the last '#'
    is discarded.

    Args:
        token_file: path of the description file, read via ``gfile.GFile``.

    Returns:
        dict mapping image file name to the list of its descriptions, in file
        order.

    Raises:
        ValueError: if a non-blank line has no tab or its id has no '#'.
    """
    img_name_to_tokens = {}
    with gfile.GFile(token_file, 'r') as f:
        lines = f.readlines()

    for line in lines:
        line = line.strip('\r\n')
        if not line:
            # Skip blank lines instead of failing the unpacking below.
            continue
        # maxsplit=1: only the first tab separates the id from the text.
        img_id, description = line.split('\t', 1)
        # rsplit: a '#' inside the file name must not break parsing.
        image_name, _ = img_id.rsplit('#', 1)
        img_name_to_tokens.setdefault(image_name, []).append(description)
    return img_name_to_tokens

def convert_token_to_id(img_name_to_tokens, vocab):
    """Encode every description of every image into a list of word ids.

    Args:
        img_name_to_tokens: dict of image name -> list of description strings.
        vocab: object exposing ``encode(sentence) -> list of ids``.

    Returns:
        dict of image name -> list of id lists, with the same keys (in the
        same order) and one id list per input description.
    """
    return {
        name: [vocab.encode(text) for text in descriptions]
        for name, descriptions in img_name_to_tokens.items()
    }

img_name_to_tokens = parse_token_file(input_description_file)
img_name_to_tokens_id = convert_token_to_id(img_name_to_tokens, vocab)
# Every image must keep its entry after encoding.
assert img_name_to_tokens.keys() == img_name_to_tokens_id.keys()

# Spot-check one image: its raw descriptions, then their id encodings.
sample_img_name = '1000268201.jpg'
logging.info("num of all images: %d", len(img_name_to_tokens))
pprint.pprint(img_name_to_tokens[sample_img_name])
logging.info("num of images with encoded descriptions: %d", len(img_name_to_tokens_id))
pprint.pprint(img_name_to_tokens_id[sample_img_name])

INFO:tensorflow:num of all images: 31783
['A child in a pink dress is climbing up a set of stairs in an entry way .',
 'A little girl in a pink dress going into a wooden cabin .',
 'A little girl climbing the stairs to her playhouse .',
 'A little girl climbing into a wooden playhouse .',
 'A girl going into a wooden building .']
INFO:tensorflow:num of all images: 31783
[[3, 52, 4, 1, 91, 117, 8, 247, 49, 1, 366, 10, 414, 4, 27, 5350, 670, 2],
 [3, 60, 30, 4, 1, 91, 117, 356, 71, 1, 227, 3610, 2],
 [3, 60, 30, 247, 5, 414, 15, 40, 3834, 2],
 [3, 60, 30, 247, 71, 1, 227, 3834, 2],
 [3, 30, 356, 71, 1, 227, 78, 2]]


In [None]:
class ImageCaptionData:
    """provide data for image caption model."""
    def __init__(self, image)