In [1]:
"""
1. Data generator
    a. Loads vocab
    c. Loads image features
    d. provide data for training.
2. Build image caption model
3. Trains the model
"""

import os
import sys
import tensorflow as tf
from tensorflow import gfile
from tensorflow import logging
import pprint
import _pickle as cPickle
import numpy as np
import math
import random


# Input locations: caption file, directory of pickled image features, vocabulary file.
input_description_file = './data/results.token'
# NOTE(review): the recorded output of the data-provider cell lists pickles under
# './data/feature_extraction_inception_v3/', not this directory -- confirm the path.
input_img_feature_dir = './data/download_inpcetion_v3_features/'
input_vocab_file = './data/vocab.txt'
# Receives the TensorBoard event files and the model checkpoints written during training.
output_dir = './data/local_run'

# Create the output directory on first run.
if not gfile.Exists(output_dir):
    gfile.MakeDirs(output_dir)
    
def get_default_params():
    """Return the default hyper-parameters shared by every cell of the notebook.

    Returns:
        A `tf.contrib.training.HParams` built from the values below.
    """
    defaults = {
        # vocabulary: words with fewer occurrences than this are dropped
        'num_vocab_word_threshold': 3,
        # model size
        'num_embedding_nodes': 32,
        'num_timesteps': 10,
        'num_lstm_nodes': [64, 64],
        'num_lstm_layers': 2,
        'num_fc_nodes': 32,
        'batch_size': 80,
        'cell_type': 'lstm',
        # optimisation
        'clip_lstm_grads': 1,    # gradient clipping: bound handed to tf.clip_by_global_norm
        'learning_rate': 0.001,
        'keep_prob': 0.8,
        # bookkeeping
        'log_frequent': 10,      # log every N steps (previously 100)
        'save_frequent': 100,    # checkpoint every N steps (previously 1000)
    }
    return tf.contrib.training.HParams(**defaults)
# Notebook-wide hyper-parameters; later cells read hps.* directly.
hps = get_default_params()


For more information, please see:
  * https://github.com/tensorflow/community/blob/master/rfcs/20180907-contrib-sunset.md
  * https://github.com/tensorflow/addons
If you depend on functionality not listed there, please file an issue.



In [2]:
# Load the vocabulary
class Vocab:
    """Two-way word <-> id mapping read from a tab-separated vocabulary file.

    Every line of the file is ``word<TAB>occurrence``. Words whose occurrence
    is below ``word_num_threshold`` are skipped; the remaining words receive
    consecutive ids in file order. The word ``<UNK>`` supplies the unknown-word
    id and the word ``.`` supplies the end-of-sentence id; both stay at -1
    when the file does not contain them.
    """

    def __init__(self, filename, word_num_threshold):
        self._word_num_threshold = word_num_threshold
        self._unk = -1
        self._eos = -1
        self._word_to_id = {}
        self._id_to_word = {}
        self._read_dict(filename)

    def _read_dict(self, filename):
        """Fill both lookup tables from `filename`; raises Exception on a repeated word."""
        with gfile.GFile(filename, 'r') as vocab_file:
            entries = vocab_file.readlines()

        for entry in entries:
            token, count = entry.strip('\r\n').split('\t')
            if int(count) < self._word_num_threshold:
                continue
            # ids are dense: the next id is the number of words kept so far
            next_id = len(self._id_to_word)
            if token == '<UNK>':
                self._unk = next_id
            elif token == '.':
                self._eos = next_id
            if token in self._word_to_id or next_id in self._id_to_word:
                raise Exception('duplicate words in vocab')
            self._word_to_id[token] = next_id
            self._id_to_word[next_id] = token

    @property
    def unk(self):
        """Id of the unknown-word token."""
        return self._unk

    @property
    def eos(self):
        """Id of the end-of-sentence token."""
        return self._eos

    def word_to_id(self, word):
        """Id of `word`, or the unknown-word id when it is not in the vocabulary."""
        if word in self._word_to_id:
            return self._word_to_id[word]
        return self._unk

    def id_to_word(self, word_id):
        """Word for `word_id`, or '<UNK>' when the id is not in the vocabulary."""
        if word_id in self._id_to_word:
            return self._id_to_word[word_id]
        return '<UNK>'

    def size(self):
        """Number of words kept in the vocabulary."""
        return len(self._id_to_word)

    # Turn a sentence into the list of its word ids
    def encode(self, sentence):
        return list(map(self.word_to_id, sentence.split(' ')))

    # Turn a list of word ids back into a sentence
    def decode(self, sentence_id):
        return ' '.join([self.id_to_word(one_id) for one_id in sentence_id])
        
# Smoke test: build the vocabulary and round-trip one sentence through encode/decode.
vocab = Vocab(input_vocab_file, hps.num_vocab_word_threshold)
vocab_size = vocab.size()  # passed to get_train_model to size the embedding table and the logits
encode_sentence = vocab.encode('i have a dream')
print(encode_sentence)
print(vocab.decode(encode_sentence))

[3835, 389, 1, 0]
i have a <UNK>


In [3]:
# Parse the image description file; returns {img_name: [description, ...]}
def parse_token_file(token_file):
    """Group the descriptions found in `token_file` by image name.

    Every line is split on a tab into an image id and a description; the image
    id is split on '#' and its first part is used as the image name.

    Returns:
        dict mapping image name -> list of description strings in file order.
    """
    with gfile.GFile(token_file, 'r') as token_fh:
        records = token_fh.readlines()

    captions_by_image = {}
    for record in records:
        img_id, description = record.strip('\r\n').split('\t')
        image_name, _caption_index = img_id.split('#')
        captions_by_image.setdefault(image_name, []).append(description)
    return captions_by_image

def convert_token_to_id(img_name_to_tokens, vocab):
    """Encode every description of every image as a list of word ids.

    Args:
        img_name_to_tokens: dict of image name -> list of description strings.
        vocab: object whose ``encode(sentence)`` returns the list of word ids.

    Returns:
        dict of image name -> list of id lists, one per description.
    """
    return {
        img_name: [vocab.encode(description) for description in descriptions]
        for img_name, descriptions in img_name_to_tokens.items()
    }

# Parse the captions, then encode them with the vocabulary.
img_name_to_tokens = parse_token_file(input_description_file)
img_name_to_tokens_id = convert_token_to_id(img_name_to_tokens, vocab)

# For the raw and then the encoded mapping: log its size and show one sample image.
for captions_by_image in (img_name_to_tokens, img_name_to_tokens_id):
    logging.info("num of all images: %d" % len(captions_by_image))
    pprint.pprint(captions_by_image['1000268201.jpg'])

INFO:tensorflow:num of all images: 31783
['A child in a pink dress is climbing up a set of stairs in an entry way .',
 'A little girl in a pink dress going into a wooden cabin .',
 'A little girl climbing the stairs to her playhouse .',
 'A little girl climbing into a wooden playhouse .',
 'A girl going into a wooden building .']
INFO:tensorflow:num of all images: 31783
[[3, 52, 4, 1, 91, 117, 8, 247, 49, 1, 366, 10, 414, 4, 27, 5350, 670, 2],
 [3, 60, 30, 4, 1, 91, 117, 356, 71, 1, 227, 3610, 2],
 [3, 60, 30, 247, 5, 414, 15, 40, 3834, 2],
 [3, 60, 30, 247, 71, 1, 227, 3834, 2],
 [3, 30, 356, 71, 1, 227, 78, 2]]


In [4]:
class StrToBytes:
    """Adapter that makes a text-mode file object look like a bytes-mode one.

    `read` and `readline` return the wrapped object's text encoded with
    ``str.encode()`` (UTF-8 by default), for consumers that require bytes.
    """
    def __init__(self, fileobj):
        self.fileobj = fileobj

    def read(self, size=-1):
        """Read up to `size` characters (everything when negative) and return them as bytes."""
        return self.fileobj.read(size).encode()

    def readline(self, size=-1):
        """Read one line (at most `size` characters when given) and return it as bytes."""
        # Not every file object accepts a size argument on readline(); forwarding the
        # default to such an object fails with
        # "TypeError: readline() takes 1 positional argument but 2 were given".
        if size is None or size < 0:
            return self.fileobj.readline().encode()
        return self.fileobj.readline(size).encode()

class ImageCaptionData:
    """Serve mini-batches of image features, each with one fixed-length caption.

    On construction every file in `img_feature_dir` is unpickled into memory
    and, unless `deterministic` is set, the samples are shuffled. `next_batch`
    then walks through the data and starts over (reshuffling when not
    deterministic) once a full batch no longer fits.
    """
    def __init__(self,
                img_name_to_tokens_id, # image name -> list of captions, each a list of word ids
                img_feature_dir,
                num_timesteps,
                vocab,
                deterministic = False):
        """
        Args:
            img_name_to_tokens_id: dict of image name -> list of id lists.
            img_feature_dir: directory whose files are the pickled features.
            num_timesteps: every caption is truncated or padded to this length.
            vocab: provides `eos`, the id used for padding.
            deterministic: when True the data is never shuffled.
        """
        self._vocab = vocab
        self._img_name_to_tokens_id = img_name_to_tokens_id
        self._num_timesteps = num_timesteps
        self._deterministic = deterministic
        self._indicator = 0  # index of the first sample of the next batch

        self._img_feature_filenames = []
        self._img_feature_data = []

        self._all_img_feature_filepaths = []
        for filename in gfile.ListDirectory(img_feature_dir):
            self._all_img_feature_filepaths.append(
                os.path.join(img_feature_dir, filename))
        pprint.pprint(self._all_img_feature_filepaths)
        # Load the feature files
        self._load_img_feature_pickle()

        if not self._deterministic:
            self._random_shuffle()

    def _load_img_feature_pickle(self):
        """Load image names and features from every pickle file into two aligned arrays."""
        for filepath in self._all_img_feature_filepaths:
            logging.info('loading %s' % filepath)
            with gfile.GFile(filepath, 'rb') as f:
                # NOTE: unpickling can execute arbitrary code -- only load feature
                # files that come from a trusted source.
                filenames, features = cPickle.load(f)
                # filenames is a list, so += concatenates instead of nesting
                self._img_feature_filenames += filenames
                self._img_feature_data.append(features)
        # e.g. [(1000, 1, 1, 2048), (1000, 1, 1, 2048)] -> (2000, 1, 1, 2048)
        self._img_feature_data = np.vstack(self._img_feature_data)
        origin_shape = self._img_feature_data.shape
        # Keep only the first and last axes, i.e. drop the two middle axes of size 1
        self._img_feature_data = np.reshape(self._img_feature_data, (origin_shape[0], origin_shape[3]))
        # Names become an array too so both can be indexed by the same permutation
        self._img_feature_filenames = np.asarray(self._img_feature_filenames)
        print(self._img_feature_data.shape)
        print(self._img_feature_filenames.shape)

    def size(self):
        """Number of images loaded."""
        return len(self._img_feature_filenames)

    def img_feature_size(self):
        """Length of one image feature vector."""
        return self._img_feature_data.shape[1]

    def _random_shuffle(self):
        """Apply one random permutation to features and names so they stay aligned."""
        p = np.random.permutation(self.size())
        self._img_feature_data = self._img_feature_data[p]
        self._img_feature_filenames = self._img_feature_filenames[p]

    def _img_desc(self, batch_filenames):
        """Pick one random caption per image and fit it to `num_timesteps`.

        Returns:
            (sentence_ids, weights): two arrays of shape
            (len(batch_filenames), num_timesteps). `weights` is 1 on real
            tokens and 0 on padding.
        """
        batch_sentence_ids = []
        batch_weights = []
        for filename in batch_filenames:
            token_ids_set = self._img_name_to_tokens_id[filename]
            # Pick one caption at random. Copy it: the list belongs to the shared
            # caption dict, and padding it in place would make the caption look
            # full-length (all weights 1) the next time it is chosen.
            chosen_token_ids = list(random.choice(token_ids_set))
            chosen_token_ids_length = len(chosen_token_ids)

            if chosen_token_ids_length > self._num_timesteps:
                # Truncate
                chosen_token_ids = chosen_token_ids[0:self._num_timesteps]
                weight = [1] * self._num_timesteps
            else:
                # Pad with eos; padded positions get weight 0
                remaining_length = self._num_timesteps - chosen_token_ids_length
                chosen_token_ids += [self._vocab.eos] * remaining_length
                weight = [1] * chosen_token_ids_length + [0] * remaining_length
            batch_sentence_ids.append(chosen_token_ids)
            batch_weights.append(weight)
        batch_sentence_ids = np.asarray(batch_sentence_ids)
        batch_weights = np.asarray(batch_weights)
        return batch_sentence_ids, batch_weights

    def next_batch(self, batch_size):
        """Return the next `batch_size` samples.

        Returns:
            (img_features, sentence_ids, weights, filenames) for the batch.

        Raises:
            AssertionError: if `batch_size` is larger than the dataset.
        """
        end_indicator = self._indicator + batch_size
        if end_indicator > self.size():
            # Not enough samples left: start a new pass over the data
            if not self._deterministic:
                self._random_shuffle()
            self._indicator = 0
            end_indicator = self._indicator + batch_size
        # <= so that a batch covering the whole dataset is accepted
        assert end_indicator <= self.size()

        batch_filenames = self._img_feature_filenames[self._indicator : end_indicator]
        batch_img_features = self._img_feature_data[self._indicator : end_indicator]
        # e.g. with eos as padding: ids [100, 101, 102, eos, eos] -> weights [1, 1, 1, 0, 0]
        batch_sentence_ids, batch_weights = self._img_desc(batch_filenames)
        self._indicator = end_indicator
        return batch_img_features, batch_sentence_ids, batch_weights, batch_filenames
    
# Build the data provider from the encoded captions and the pickled image features.
caption_data = ImageCaptionData(
    img_name_to_tokens_id,
    input_img_feature_dir,
    hps.num_timesteps,
    vocab)

img_feature_dim = caption_data.img_feature_size()
caption_data_size = caption_data.size()
logging.info('img_feature_dim: %d' % img_feature_dim)
logging.info('caption_data_size: %d' % caption_data_size)

# Pull one small batch and print each of its four parts as a sanity check.
sample_batch = caption_data.next_batch(5)
batch_img_features, batch_sentence_ids, batch_weights, batch_img_names = sample_batch
for batch_part in sample_batch:
    pprint.pprint(batch_part)

['./data/feature_extraction_inception_v3/image_features-0.pickle',
 './data/feature_extraction_inception_v3/image_features-1.pickle',
 './data/feature_extraction_inception_v3/image_features-10.pickle',
 './data/feature_extraction_inception_v3/image_features-11.pickle',
 './data/feature_extraction_inception_v3/image_features-12.pickle',
 './data/feature_extraction_inception_v3/image_features-13.pickle',
 './data/feature_extraction_inception_v3/image_features-14.pickle',
 './data/feature_extraction_inception_v3/image_features-15.pickle',
 './data/feature_extraction_inception_v3/image_features-16.pickle',
 './data/feature_extraction_inception_v3/image_features-17.pickle',
 './data/feature_extraction_inception_v3/image_features-18.pickle',
 './data/feature_extraction_inception_v3/image_features-19.pickle',
 './data/feature_extraction_inception_v3/image_features-2.pickle',
 './data/feature_extraction_inception_v3/image_features-20.pickle',
 './data/feature_extraction_inception_v3/image_feat

TypeError: readline() takes 1 positional argument but 2 were given

In [None]:
# Return a single recurrent-network cell
def create_rnn_cell(hidden_dim, cell_type):
    """Build one recurrent cell with `hidden_dim` units.

    `cell_type` selects the implementation: 'lstm' gives a BasicLSTMCell
    (tuple state) and 'gru' gives a GRUCell. Any other value raises Exception.
    """
    if cell_type not in ('lstm', 'gru'):
        raise Exception("%s type is not been supported" % cell_type)
    if cell_type == 'gru':
        return tf.contrib.rnn.GRUCell(hidden_dim)
    return tf.contrib.rnn.BasicLSTMCell(hidden_dim, state_is_tuple = True)

# Dropout wrapper
def dropout(cell, keep_prob):
    """Return `cell` wrapped in a DropoutWrapper with `keep_prob` as its output keep probability."""
    wrapper_cls = tf.contrib.rnn.DropoutWrapper
    return wrapper_cls(cell, output_keep_prob=keep_prob)
    
# Build the computation graph
def get_train_model(hps, vocab_size, img_feature_dim):
    """Build the training graph of the image-caption model.

    Args:
        hps: hyper-parameters; the fields read here are num_timesteps,
            batch_size, num_embedding_nodes, num_lstm_nodes, num_lstm_layers,
            cell_type, num_fc_nodes, clip_lstm_grads and learning_rate.
        vocab_size: number of words; sizes the embedding table and the logits.
        img_feature_dim: length of one image feature vector.

    Returns:
        ``((img_feature, sentence, mask, keep_prob),
        (loss, accuracy, train_op), global_step)``: the input placeholders,
        the metric/train ops and the global-step variable.
    """
    num_timesteps = hps.num_timesteps
    batch_size = hps.batch_size


    # define placeholder
    img_feature = tf.placeholder(tf.float32, (batch_size, img_feature_dim))
    sentence = tf.placeholder(tf.int32, (batch_size, num_timesteps))
    mask = tf.placeholder(tf.float32, (batch_size, num_timesteps))
    keep_prob = tf.placeholder(tf.float32, name="keep_prob")

    global_step = tf.Variable(tf.zeros([], tf.int32), name = 'global_step', trainable = False)

    # prediction process:
    #   sentence: [a, b, c, d, e, f]
    #   input: [img, a, b, c, d]
    #   img_feature: [0.4, 0.3, 0.2, 0.6]
    #   predict#1: img_feature -> embedding_img -> lstm -> (a)
    #   predict#2: a -> embedding_word -> lstm -> (b)
    #   predict#3: b -> embedding_word -> lstm -> (c)
    #   .....
    # Usually the image embedding and the word embeddings are merged before predicting.

    # setup embedding layer
    embedding_initializer = tf.random_uniform_initializer(-1.0,1.0)
    with tf.variable_scope('embedding', initializer = embedding_initializer):
        embeddings = tf.get_variable(
            'embeddings',
            [vocab_size, hps.num_embedding_nodes], tf.float32)
        # The last word is never an input, so only the first num_timesteps - 1 are embedded.
        # embed_token_ids: [batch_size, num_timesteps - 1, num_embedding_nodes]
        embed_token_ids = tf.nn.embedding_lookup(
            embeddings,
            sentence[:, 0: num_timesteps - 1])

    # A fully connected layer projects the image feature to the word-embedding dimension
    img_feature_embed_init = tf.uniform_unit_scaling_initializer(factor = 1.0)
    with tf.variable_scope('img_feature_embed', initializer = img_feature_embed_init):
        # img_feature: [batch_size, img_feature_dim]
        # embed_img: [batch_size, num_embedding_nodes]
        # Keras layers do not pick up the variable_scope default initializer,
        # so it is passed explicitly.
        embed_img = tf.keras.layers.Dense(
            hps.num_embedding_nodes,
            kernel_initializer = img_feature_embed_init)(img_feature)
        # embed_img: [batch_size, 1, num_embedding_nodes]
        embed_img = tf.expand_dims(embed_img, 1)
        # Concatenate embed_img and embed_token_ids along axis 1 (time)
        # embed_inputs: [batch_size, num_timesteps, num_embedding_nodes]
        embed_inputs = tf.concat([embed_img, embed_token_ids], axis = 1)

    # setup rnn network
    scale = 1.0 / math.sqrt(hps.num_embedding_nodes + hps.num_lstm_nodes[-1]) / 3.0
    rnn_init = tf.random_uniform_initializer(-scale, scale)
    with tf.variable_scope('lstm_rnn', initializer = rnn_init):
        cells = []
        for i in range(hps.num_lstm_layers):
            cell = create_rnn_cell(hps.num_lstm_nodes[i], hps.cell_type)
            cell = dropout(cell, keep_prob)
            cells.append(cell)
        # Stack the per-layer cells into one multi-layer cell
        cell = tf.contrib.rnn.MultiRNNCell(cells)
        init_state = cell.zero_state(hps.batch_size, tf.float32)
        # rnn_outputs: [batch_size, num_timesteps, hps.num_lstm_nodes[-1]]
        rnn_outputs, _ = tf.nn.dynamic_rnn(cell,
                                          embed_inputs,
                                          initial_state = init_state)

    # setup fully connected layer
    fc_init = tf.uniform_unit_scaling_initializer(factor = 0.1)
    with tf.variable_scope('fc', initializer = fc_init):
        # Merge batch and time so every timestep is classified independently
        rnn_outputs_2d = tf.reshape(rnn_outputs, [-1, hps.num_lstm_nodes[-1]])

        # As above, the initializer has to be handed to the Keras layers explicitly.
        fc1 = tf.keras.layers.Dense(
            hps.num_fc_nodes, kernel_initializer = fc_init, name = 'fc1')(rnn_outputs_2d)
        fc1_dropout = tf.contrib.layers.dropout(fc1, keep_prob)
        fc1_relu = tf.nn.relu(fc1_dropout)
        # Unnormalised scores over the vocabulary
        logits = tf.keras.layers.Dense(
            vocab_size, kernel_initializer = fc_init, name = 'logits')(fc1_relu)


    # calculate loss
    with tf.variable_scope('loss'):
        # Flatten sentence and mask to line up with the rows of logits
        sentence_flatten = tf.reshape(sentence, [-1])
        mask_flatten = tf.reshape(mask, [-1])

        mask_sum = tf.reduce_sum(mask_flatten)

        # Per-position cross entropy
        softmax_loss = tf.nn.sparse_softmax_cross_entropy_with_logits(
            logits = logits,
            labels = sentence_flatten )

        # Zero out the positions whose weight is 0 and average over the rest
        weighted_softmax_loss = tf.multiply(softmax_loss, tf.cast(mask_flatten, tf.float32))
        loss = tf.reduce_sum(weighted_softmax_loss) / mask_sum

        prediction = tf.argmax(logits, 1, output_type=tf.int32)
        correct_prediction = tf.equal(prediction, sentence_flatten)

        weighted_correct_prediction = tf.multiply(tf.cast(correct_prediction, tf.float32), mask_flatten)

        accuracy = tf.reduce_sum(weighted_correct_prediction) / mask_sum
        tf.summary.scalar('loss', loss)

    # define train op
    with tf.variable_scope('train_op'):
        tvars = tf.trainable_variables()
        for var in tvars:
            logging.info('variable name: %s' % var.name)
        # Rescale all gradients together so their global norm stays within clip_lstm_grads
        grads, _ = tf.clip_by_global_norm(
            tf.gradients(loss, tvars), hps.clip_lstm_grads)
        # NOTE(review): Adadelta with a learning rate of 0.001 may converge very slowly --
        # confirm this optimizer choice is intended.
        optimizer = tf.train.AdadeltaOptimizer(hps.learning_rate)
        train_op = optimizer.apply_gradients(
            zip(grads, tvars), global_step = global_step)

    return ((img_feature, sentence, mask, keep_prob),
           (loss, accuracy, train_op),
           global_step)

# Build the graph once; the handles unpacked below are used by the training cell.
placeholders, metrics, global_step = get_train_model(
    hps, vocab_size, img_feature_dim)

img_feature, sentence, mask, keep_prob = placeholders
loss, accuracy, train_op = metrics

# Created after the graph is built so they cover the summaries and variables defined above.
summary_op = tf.summary.merge_all()
init_op = tf.global_variables_initializer()

# Saver used to checkpoint the model (max_to_keep = 10)
saver = tf.train.Saver(max_to_keep = 10)

In [None]:
# Training loop
train_steps = 1000

with tf.Session() as sess:
    sess.run(init_op)
    writer = tf.summary.FileWriter(output_dir, sess.graph)
    try:
        for i in range(train_steps):
            (batch_img_features, batch_sentences_ids, batch_weights, _batch_filenames) = caption_data.next_batch(hps.batch_size)
            # Same order as placeholders: (img_feature, sentence, mask, keep_prob)
            input_vals = (batch_img_features, batch_sentences_ids, batch_weights, hps.keep_prob)
            feed_dict = dict(zip(placeholders, input_vals))
            fetches = [global_step, loss, accuracy, train_op]
            should_log = (i + 1) % hps.log_frequent == 0
            should_save = (i + 1) % hps.save_frequent == 0

            # The summary is only evaluated on logging steps; it is appended last
            # so it can be read back as outputs[-1].
            if should_log:
                fetches.append(summary_op)

            outputs = sess.run(fetches, feed_dict = feed_dict)
            global_step_val, loss_val, accuracy_val = outputs[0:3]
            if should_log:
                summary_str = outputs[-1]
                writer.add_summary(summary_str, global_step_val)
                logging.info('Step: %5d, loss: %3.3f, accu: %3.3f' % (global_step_val, loss_val, accuracy_val))

            if should_save:
                model_save_file = os.path.join(output_dir, 'image_caption')
                logging.info('Step: %5d, model saved' % global_step_val)
                saver.save(sess, model_save_file, global_step = global_step_val)
    finally:
        # Flush and close the event file even if training is interrupted,
        # so the summaries written so far are not lost.
        writer.close()