In [1]:
import heapq

import numpy as np

from configuration import Config
from utils import CharaterTable, Preprocessor, vetorize_caption
from CaptionModel import CaptionModel

Using TensorFlow backend.


In [2]:
class Caption(object):
    """A caption (finished or still being generated) with its score.

    Instances are ordered and compared for equality purely by ``score``.
    """

    def __init__(self, sentence, state, logprob, score, metadata=None):
        """Stores the caption fields on the instance.

        Args:
          sentence: List of word ids in the caption.
          state: Model state after generating the previous word.
          logprob: Log-probability of the caption.
          score: Score of the caption; the only field used for comparisons.
          metadata: Optional metadata associated with the partial sentence. If
            not None, a list of strings with the same length as 'sentence'.
        """
        self.sentence, self.state = sentence, state
        self.logprob, self.score = logprob, score
        self.metadata = metadata

    def __cmp__(self, other):
        """Three-way comparison by score (used by Python 2 only).

        Returns 0 when the scores are equal, -1 when this caption's score is
        lower, and 1 otherwise.
        """
        assert isinstance(other, Caption)
        if self.score == other.score:
            return 0
        return -1 if self.score < other.score else 1

    # Rich comparisons: Python 3 ignores __cmp__, so ordering and equality
    # are spelled out explicitly.
    def __lt__(self, other):
        """True when this caption's score is strictly lower than other's."""
        assert isinstance(other, Caption)
        is_lower = self.score < other.score
        return is_lower

    def __eq__(self, other):
        """True when both captions have equal scores."""
        assert isinstance(other, Caption)
        same_score = self.score == other.score
        return same_score


class TopN(object):
    """Maintains the top n elements of an incrementally provided set.

    Elements are kept in a min-heap (via ``heapq``) of at most ``n`` items, so
    they must be mutually comparable with ``<``.
    """

    def __init__(self, n):
        """Creates an empty container that retains at most ``n`` elements."""
        self._n = n
        # Heap of retained elements; set to None by extract() until reset().
        self._data = []

    def size(self):
        """Returns the number of elements currently retained."""
        assert self._data is not None
        return len(self._data)

    def push(self, x):
        """Pushes a new element.

        While fewer than ``n`` elements are held, ``x`` is simply added. Once
        full, ``x`` is added and the smallest element is dropped, so only the
        largest ``n`` elements survive.
        """
        assert self._data is not None
        if len(self._data) < self._n:
            heapq.heappush(self._data, x)
        else:
            heapq.heappushpop(self._data, x)

    def extract(self, sort=False):
        """Extracts all elements from the TopN. This is a destructive operation.

        The only method that can be called immediately after extract() is reset().

        Args:
            sort: Whether to return the elements in descending sorted order.

        Returns:
            A list of data; the top n elements provided to the set.
        """
        assert self._data is not None
        data = self._data
        # Invalidate the container so further use trips the assertions above.
        self._data = None
        if sort:
            data.sort(reverse=True)
        return data

    # Defined at class level: it was previously nested inside extract() after
    # the return statement, which left TopN without a usable reset().
    def reset(self):
        """Returns the TopN to an empty state."""
        self._data = []

In [3]:
# Build the configuration and the preprocessor, then a CharaterTable over the
# combined training and validation captions.
config = Config()
data = Preprocessor(config)
ctable = CharaterTable(data.train_captions + data.val_captions)
# Caption length handed to CaptionModel below (and to the disabled
# vectorization calls).
caption_len = 25

# Caption vectorization is currently disabled in this notebook.
# Y_train = vetorize_caption(data.train_captions, ctable, caption_len)
# Y_val = vetorize_caption(data.val_captions, ctable, caption_len)

# Instantiate the caption model from the preprocessor's image length, the
# caption length and the character table's vocabulary size.
caption_model = CaptionModel(
    image_len = data.image_len,
    caption_len = caption_len,
    vocab_size = ctable.vocab_size)

In [4]:
# Build the inference model from a saved weights file with beam_search enabled.
# NOTE(review): the checkpoint path is hard-coded and relative to the working
# directory — confirm the file exists before running this cell.
caption_model.build_inference_model('./checkpoint/weights.014-0.747.hdf5', beam_search=True)

In [5]:
# Inspect the first training example: its image entry and the first caption
# stored for it.
image = data.train_set[0]
labels = data.train_captions[0][0]
# Single-argument print(...) behaves the same on Python 2 and Python 3,
# whereas the bare print statement is a syntax error on Python 3.
print(image.shape)
# Concatenate the caption's elements into one string for display.
print(''.join(labels))

(4096,)
大街的马路上有一个路标指向牌


In [6]:
# Run the image model on a single-example batch (image[None, ...] prepends an
# axis), then insert a new axis at position 1 of the result.
# NOTE(review): presumably the extra axis is what the caption model expects as
# its input layout — confirm against CaptionModel.
image_output = np.expand_dims(caption_model.image_model.predict_on_batch(image[None,...]), axis=1)

In [7]:
# Reset the caption model's states before predicting.
# NOTE(review): presumably clears recurrent state left by earlier calls — confirm.
caption_model.caption_model.reset_states()

In [8]:
# Feed the image model's output to the caption model as a single batch and
# keep the result for inspection.
predict = caption_model.caption_model.predict_on_batch(image_output)

In [15]:
# NOTE(review): the attribute name `ge` looks truncated (possibly an unfinished
# tab-completion); confirm the intended attribute or remove this scratch cell.
state = caption_model.caption_model.ge

In [None]:
# NOTE(review): there are no call parentheses, so this binds the reset_states
# method object to `state` rather than calling it — confirm this is intended.
state = caption_model.caption_model.reset_states

In [None]:
# NOTE(review): there are no call parentheses, so this binds the set_weights
# method object to `state` rather than calling it — confirm this is intended.
state = caption_model.caption_model.set_weights

In [13]:
# Display the current value of `state`.
# NOTE(review): this cell's execution count (13) precedes that of the `state`
# assignment cell numbered 15, so the recorded output may not match the
# assignments shown in file order — re-run from a fresh kernel to confirm.
state

[(<tensorflow.python.ops.variables.Variable at 0x7fbfa4591610>,
  <tf.Tensor 'gru_1/while/Exit_2:0' shape=(1, 128) dtype=float32>)]