## LSTM Captioning

This is a very basic model: 

*  Take the featurized images (2048d), and tokenised captions
*  Add a (trainable) features -> 50d dense layer
*  Use a 50d GloVe embedding (for the LSTM inputs, non-trainable)
   *  Number of stop-words ≈ 150 (say)
*  50d of hidden units for the LSTM
*  But have a 'pluggable' output transform :
   *   Concat : (256 one-hot - including '0'=mask, '1'={UNK}, '2'={START}, '3'={STOP}, '4'={UseOther})
   *   (a) UseOther + (8192-250 of more one-hot)
   *   (b) UseOther + (50d of same GloVe embedding, for nearest-neighbour)
   *   (c) UseOther + (log2(8192)==13 bits + error correction of index of word)
*  Want to monitor some kind of score over time for test cases   

In [None]:
import os

import numpy as np

import random
import pickle

# Threshold on each image's stored `train_test` value: images below it form the training set.
TRAIN_PCT=0.9

In [None]:
# Load in the captions/corpus/embedding
# NOTE(review): pickle.load can execute arbitrary code - only point this at a trusted cache file.
with open('./data/cache/CAPTIONS_data_Flickr30k_2017-06-07_23-15.pkl', 'rb') as f:
    text_data=pickle.load(f, encoding='iso-8859-1')

"""
text_data ~ dict(
    img_to_captions = img_to_valid_captions,
    
    action_words = action_words, 
    stop_words = stop_words_sorted,
    
    embedding = embedding,
    embedding_word_arr = embedding_word_arr,
    
    img_arr = img_arr_save,
    train_test = np.random.random( (len(img_arr_save),) ),
)"""

# Embedding array and its parallel word list; `dictionary` inverts the list (word -> position).
embedding = text_data['embedding']
vocab_arr = text_data['embedding_word_arr']
dictionary = { w:i for i,w in enumerate(vocab_arr) }

# Training images are those whose stored train_test value falls below TRAIN_PCT.
img_arr_train = [ img for i, img in enumerate(text_data['img_arr']) if text_data['train_test'][i]<TRAIN_PCT ]

print("Loaded captions, corpus and embedding")

In [None]:
# Load in the features
# NOTE(review): pickle.load can execute arbitrary code - only point this at a trusted cache file.
with open('./data/cache/FEATURES_data_Flickr30k_flickr30k-images_2017-06-06_18-07.pkl', 'rb') as f:
    image_data=pickle.load(f, encoding='iso-8859-1')

"""
image_data ~ dict(
    features = features,
    img_arr = img_arr,
)
"""
# Map each image identifier to its position in image_data['img_arr']
#   (presumably also its row in image_data['features'] - confirm against the featurizer).
image_feature_idx = { img:idx for idx, img in enumerate(image_data['img_arr']) }

print("Loaded image features for all images")

In [None]:
# Number of timesteps in every encoded caption.
CAPTION_LEN = 32
# Width of one embedding vector (second axis of the embedding array).
EMBEDDING_DIM = embedding.shape[1]

VOCAB_SIZE = len(vocab_arr)
# Number of bits used when a word index is written out in binary.
LOG2_VOCAB_SIZE = 13  # 1024->10, 8192->13
# Word indices run 0..VOCAB_SIZE-1, so a vocabulary of exactly 2**LOG2_VOCAB_SIZE
#   words still fits in LOG2_VOCAB_SIZE bits: the upper bound is inclusive.
#   At half that size (or fewer) one bit less would have been enough.
if not (2**LOG2_VOCAB_SIZE/2) < VOCAB_SIZE <= 2**LOG2_VOCAB_SIZE:
    print("LOG2_VOCAB_SIZE incorrect")

In [None]:
def caption_to_idx_arr(caption):  # This is actually 1 longer than CAPTION_LEN - need to shift about a bit later
    """Convert a caption string into a fixed-length array of vocabulary indices.

    The caption is lower-cased and split on whitespace.  The result has
    ``CAPTION_LEN+1`` int32 entries laid out as
    ``{START}, word_1 .. word_n, {STOP}, 0, 0, ...`` where words missing from
    ``dictionary`` are replaced by the ``{UNK}`` index and the zero padding is
    the mask value.

    Captions with more than ``CAPTION_LEN-1`` words are truncated so that
    ``{STOP}`` always fits; an empty caption yields just ``{START}, {STOP}``.
    """
    ret = np.zeros( (CAPTION_LEN+1,), dtype='int32')  # {MASK}.idx===0
    unk_idx = dictionary['{UNK}']

    # {START} and {STOP} occupy two of the CAPTION_LEN+1 slots
    words = caption.lower().split()[:CAPTION_LEN-1]

    ret[0] = dictionary['{START}']
    for pos, w in enumerate(words, start=1):
        ret[pos] = dictionary.get(w, unk_idx)
    # Position is derived from the word count, so this also works when there are no words
    ret[len(words)+1] = dictionary['{STOP}']
    return ret

In [None]:
#for j in range(0,10):
#    print(j)
#print(j)    

In [None]:
def caption_training_example():
    """Endlessly generate training samples, one random caption per image.

    Each pass visits every training image once, in a freshly shuffled order,
    and yields ``(image_feature_idx[img], caption_to_idx_arr(caption))`` for a
    caption chosen at random from that image's captions.  The chosen caption
    is printed, as is a marker at the end of every pass.

    Raises:
        ValueError: if there are no training images (the loop would otherwise
            spin forever without yielding anything).
    """
    # Shuffle a private copy so the module-level img_arr_train keeps its order
    img_arr = list(img_arr_train)
    if not img_arr:
        raise ValueError("No training images available")
    while True:
        random.shuffle( img_arr )
        for img in img_arr:
            captions = text_data['img_to_captions'][img]
            caption = random.choice(captions)
            print(caption)
            yield image_feature_idx[ img ], caption_to_idx_arr( caption )
        print("Captions : Looping")
caption_training_example_gen = caption_training_example()

In [None]:
# Draw a single (feature index, caption index array) pair to check the generator works.
next(caption_training_example_gen)

### TensorFlow / Keras imports

In [None]:
import tensorflow.contrib.keras
from tensorflow.contrib.keras.python.keras.utils.np_utils import to_categorical
from tensorflow.contrib.keras.api.keras.losses import cosine_proximity, categorical_crossentropy, mean_squared_error
from tensorflow.contrib.keras.api.keras.activations import softmax, sigmoid
from tensorflow.contrib.keras.api.keras.models import Model


### Create pluggable IO stages for words

In [None]:
class RepresentAs_FullEmbedding():
    """Word representation: the plain embedding vector of every symbol.

    Used as a namespace rather than instantiated (``io = RepresentAs_...``,
    then ``io.width`` / ``io.encode`` / ``io.loss_fn``).  The functions are
    static methods so they behave the same on the class and on an instance.
    """
    # Number of features per timestep produced by encode()
    width = EMBEDDING_DIM

    @staticmethod
    def encode(caption_arr):
        """Return the embedding row for each index in ``caption_arr``."""
        # plain embedding of each symbol
        return embedding[ caption_arr, : ]

    @staticmethod
    def loss_fn(ideal_output, network_output):  # y_true, y_pred
        """Keras-style loss: cosine proximity between target and predicted vectors."""
        return cosine_proximity( ideal_output, network_output )
    
class RepresentAs_FullOneHot():
    """Word representation: a one-hot vector over the whole vocabulary.

    Used as a namespace rather than instantiated (``io = RepresentAs_...``,
    then ``io.width`` / ``io.encode`` / ``io.loss_fn``).  The functions are
    static methods so they behave the same on the class and on an instance.
    """
    # Number of features per timestep produced by encode()
    width = VOCAB_SIZE

    @staticmethod
    def encode(caption_arr):
        """Return the one-hot encoding (``VOCAB_SIZE`` classes) of each index in ``caption_arr``."""
        # Output desired is one-hot of each symbol (cross entropy match whole thing)
        return to_categorical(caption_arr, num_classes=VOCAB_SIZE)

    @staticmethod
    def loss_fn(ideal_output, network_output):  # y_true, y_pred
        """Keras-style loss: cross entropy after a softmax over the last axis.

        The softmax is applied here, so ``network_output`` is treated as
        un-normalised scores.
        """
        smx = softmax(network_output, axis=-1)
        return categorical_crossentropy( ideal_output, smx )

In [None]:
# Size of the one-hot region: one column per action word plus one per stop word.
#   NOTE(review): the code below treats word indices < base_width as belonging to this
#   region, which presumably relies on those words coming first in vocab_arr - confirm.
base_width = len(text_data['action_words']) + len(text_data['stop_words'])
MASK_idx = dictionary['{MASK}']   # ==0
# Column that flags "this word is outside the one-hot region"
EXTRA_idx = dictionary['{EXTRA}']

def OneHotBasePlus(arr): # Contains indexing magic
    """Build the ``(CAPTION_LEN, base_width)`` float32 one-hot block for ``arr``.

    ``arr`` is a 1-d array of word indices, one per timestep (row).  A word
    whose index is below ``base_width`` gets a 1.0 in the column equal to its
    index; every other word gets a 1.0 in the ``EXTRA_idx`` column instead.
    Finally, rows whose index is 0 have their ``MASK_idx`` column reset to 0.0.
    """
    # Arrangement should be :: (samples, timesteps, features) - this is one sample
    in_base = arr < base_width
    encoded = np.zeros( (CAPTION_LEN, base_width), dtype='float32')

    # Words outside the one-hot region only raise the {EXTRA} indicator
    encoded[ ~in_base, EXTRA_idx ] = 1.0
    # Words inside the region light up the column matching their own index
    #   (row mask and column list have equal length, so they pair up element-wise)
    encoded[ in_base, arr[in_base] ] = 1.0
    # Undo the previous step for masked (index 0) timesteps
    encoded[ arr==0, MASK_idx ] = 0.0
    return encoded

class RepresentAs_OneHotBasePlusEmbedding():
    """Word representation: one-hot over the base words, followed by the embedding.

    The first ``base_width`` features are the ``OneHotBasePlus`` block, the
    remaining ``EMBEDDING_DIM`` features are the word's embedding vector.
    Used as a namespace rather than instantiated; the functions are static
    methods so they behave the same on the class and on an instance.
    """
    # Number of features per timestep produced by encode()
    width = base_width + EMBEDDING_DIM

    @staticmethod
    def encode(caption_arr):
        """Return a ``(CAPTION_LEN, width)`` array: one-hot block then embedding, per timestep."""
        # Input is onehot for first part, with the embedding included for all words too
        return np.hstack( [ OneHotBasePlus(caption_arr), embedding[caption_arr] ] )

    @staticmethod
    def loss_fn(ideal_output, network_output):  # y_true, y_pred
        """Keras-style loss blending a one-hot and an embedding term.

        Cross entropy is taken over the one-hot region and cosine proximity
        over the embedding region; the two are mixed per timestep by the
        predicted probability of the ``{EXTRA}`` class.
        """
        # One-hot of each action symbol and stop words (cross entropy match these) and
        #   embedding term on the remainder (weighted according to onehot[{EXTRA}]~0...1)

        # Perhaps need this idea https://github.com/fchollet/keras/issues/890:

        # Features live on the LAST axis (see encode), so split there with `...`;
        #   a plain [:base_width] would slice the leading (sample/timestep) axis instead.
        onehot_true = ideal_output[..., :base_width]
        embed_true = ideal_output[..., base_width:]
        embed_pred = network_output[..., base_width:]

        smx = softmax(network_output[..., :base_width], axis=-1)

        # NOTE(review): the weight comes from the prediction, not the target - confirm intended
        is_extra = smx[..., EXTRA_idx]
        one_hot_loss = categorical_crossentropy( onehot_true, smx )
        embedding_loss = cosine_proximity( embed_true, embed_pred )
        return (1.-is_extra)*one_hot_loss + (is_extra)*embedding_loss
    
class RepresentAs_OneHotBasePlusBinaryIdx():
    """Word representation: one-hot over the base words, followed by 3 copies of the binary index.

    The first ``base_width`` features are the ``OneHotBasePlus`` block; the
    remaining ``3*LOG2_VOCAB_SIZE`` features repeat the word index written as
    ``LOG2_VOCAB_SIZE`` bits (least-significant bit first) three times.
    Used as a namespace rather than instantiated; the functions are static
    methods so they behave the same on the class and on an instance.
    """
    # Number of features per timestep produced by encode()
    width = base_width + 3*LOG2_VOCAB_SIZE
    # Bit masks 1, 2, 4, ... used to pull the individual bits out of an index
    powers_of_two = 2**np.arange(LOG2_VOCAB_SIZE)

    @staticmethod
    def encode(caption_arr):
        """Return a ``(CAPTION_LEN, width)`` array: one-hot block then 3x binary index, per timestep."""
        # Input is onehot for first part, with 3 copies of the binary index of all words afterwards
        #   Idea is from : https://arxiv.org/abs/1704.06918

        # Class attributes are not visible as bare names inside a method, so fetch explicitly
        powers_of_two = RepresentAs_OneHotBasePlusBinaryIdx.powers_of_two

        # Thanks to : https://stackoverflow.com/questions/21918267/
        #         convert-decimal-range-to-numpy-array-with-each-bit-being-an-array-element
        binary = (caption_arr[:, np.newaxis] & powers_of_two) / powers_of_two

        return np.hstack( [ OneHotBasePlus(caption_arr), binary, binary, binary ] )

    @staticmethod
    def loss_fn(ideal_output, network_output):  # y_true, y_pred
        """Keras-style loss blending a one-hot and a binary-index term.

        Cross entropy is taken over the one-hot region and mean squared error
        (after a sigmoid) over the binary region; the two are mixed per
        timestep by the predicted probability of the ``{EXTRA}`` class.
        """
        # Features live on the LAST axis (see encode), so split there with `...`;
        #   a plain [:base_width] would slice the leading (sample/timestep) axis instead.
        onehot_true = ideal_output[..., :base_width]
        binary_true = ideal_output[..., base_width:]

        smx = softmax(network_output[..., :base_width], axis=-1)
        sig = sigmoid(network_output[..., base_width:])

        # NOTE(review): the weight comes from the prediction, not the target - confirm intended
        is_extra = smx[..., EXTRA_idx]
        one_hot_loss = categorical_crossentropy( onehot_true, smx )
        # Alternative considered : categorical_crossentropy( binary_true, sig )
        binary_loss  = mean_squared_error( binary_true, sig )  # reported better in paper
        return (1.-is_extra)*one_hot_loss + (is_extra)*binary_loss
    


In [None]:
# Pick the word representation to use (the class itself, not an instance);
#   trailing comments give the shape of io.encode() for one caption.
#io = RepresentAs_FullEmbedding             # .encode : 32, 50
#io = RepresentAs_FullOneHot                # .encode : 32, 6946
io = RepresentAs_OneHotBasePlusEmbedding   # .encode : 32, 191 
#io = RepresentAs_OneHotBasePlusBinaryIdx   # .encode : 32, 180
io.width

# Encode a sample sentence to inspect the index array
caption_sample = 'The cat sat on the mat .'
caption_sample_idx = caption_to_idx_arr(caption_sample)
caption_sample_idx  # array([   2,    8, 1461, 2496,   11,    8,  998,    5,    3,    0,    0...

# Rows (timesteps) 0..11, used by the commented-out probes below
onehot_start=range(0,12)
x=OneHotBasePlus(caption_sample_idx[:-1])  # [:-1] drops the extra final slot, leaving CAPTION_LEN entries
#x
#x[onehot_start, dictionary['{MASK}']]
#x[onehot_start, dictionary['{START}']]
#x[onehot_start,  EXTRA_idx]
#x[onehot_start, dictionary['{STOP}']]
#x[onehot_start, dictionary['on']]

#x.shape  # 32, 141 
#embedding[caption_sample_idx[:-1]].shape  # 32, 50

# Disabled scratch check of the bit-extraction expression used by the binary-index encoding
if False:
    powers_of_two = 2**np.arange(LOG2_VOCAB_SIZE)
    (caption_sample_idx[:, np.newaxis] & powers_of_two) / powers_of_two

# Shape of the full encoding for the selected representation
io.encode( caption_sample_idx[:-1] ).shape  # [0:6,:]

In [None]:
# See : https://keras.io/layers/core/#masking
#model = Sequential()
#model.add(Masking(mask_value=0., input_shape=(timesteps, features)))
#model.add(LSTM(32))