# Import required modules

In [None]:
import tensorflow as tf
from keras.backend.tensorflow_backend import set_session
from keras import backend as K
from keras.layers import LSTM
from keras.models import Sequential, Model
from keras.layers import *
from keras.optimizers import RMSprop
from keras.regularizers import l2
from generator import Generator
from models import *
import h5py
import numpy as np
import sys
import os

# Set Tensorflow backend to avoid full GPU pre-loading

In [None]:
# Let TensorFlow grow its GPU allocation on demand instead of pre-loading
# the whole GPU, then register the resulting session with Keras.
config = tf.ConfigProto()
config.gpu_options.allow_growth = True
session = tf.Session(config=config)
set_session(session)

# Load data generator

This is required to set some variables during testing.

In [None]:
# Function-call form of print: identical output on Python 2 and also valid
# on Python 3 (the bulk-testing cell already uses this form).
print("Loading Generator")
# The dataset location is hard-coded on purpose: the notebook is mostly run
# from one account, and wiring up sys.argv with defaults was not worth it.
gen = Generator(dataset_directory='../data')

# Testing

## Load models required for training

In [None]:
from IPython.display import Image
from keras.preprocessing import image
from keras.applications.inception_v3 import InceptionV3
from keras.applications.inception_v3 import preprocess_input

# Spatial size images are resized to before being fed to the CNN.
target_size = (299, 299)

# ImageNet-pretrained InceptionV3; its 'mixed10' layer output is used as the
# image feature map.
base_model = InceptionV3(weights='imagenet', include_top=True, input_shape=(299, 299, 3))
# Use the Keras 2 keyword `inputs` (not the legacy `input`) so the call
# matches `outputs` and the other Model(...) construction in this notebook.
img_model = Model(
        inputs=base_model.input,
        outputs=[base_model.get_layer('mixed10').output])
print(img_model.output_shape)
# Feature shape without the leading batch dimension.
output_shape = img_model.output_shape[1:]

## Load previously saved model

In [None]:
from keras.models import load_model

# Saved captioning model to restore (judging by the file name, the attention
# model after epoch 120 on Flickr).
saved_model_path = '../data/models/attention_flickr_epoch_120.h5'
model = load_model(saved_model_path)

# Model Visualization

## Model summary (text based)

In [None]:
# Print Keras' text summary of the model loaded above.
model.summary()

## Model summary (flow chart)

In [None]:
# TODO: add code here that renders the model as a flow chart.

## Testing images individually
The following code generates a caption for the image listed in `image_filenames` (which should be an image from the validation dataset that was not used for training) and displays the caption together with the image.

In [None]:
image_filenames = ['3071676551_a65741e372.jpg']
dataset_directory = '../data/flicker8k'

from keras.backend.tensorflow_backend import get_session

# Folder (with trailing slash) that holds the raw image files.
image_directory = dataset_directory + '/Flickr8k_Dataset/'

# Load every image at the CNN input size, stack them into one batch and run
# the InceptionV3 preprocessing once over the whole batch.
raw_images = [
    image.img_to_array(image.load_img(image_directory + img_name, target_size=target_size))
    for img_name in image_filenames
]
preprocessed_images = preprocess_input(np.asarray(raw_images))
img_features = img_model.predict(preprocessed_images)

# Greedy decoding: start from '<start>' and repeatedly feed the tokens
# generated so far back into the model.
# NOTE(review): text_in has a batch size of one, so this decodes a single
# caption; confirm behaviour before listing more than one image above.
text_in = np.zeros((1, gen.max_token_len))
text_in[0][0] = gen.token_to_id['<start>']

predictions = []
for arg in range(gen.max_token_len - 1):
    pred = model.predict([img_features, text_in])
    # Highest-scoring token id at step `arg`.
    tok = np.argmax(pred[0][arg])
    word = gen.id_to_token[tok]
    # The chosen token becomes the model input for the next step.
    text_in[0][arg + 1] = tok
    if word == '<end>':
        break
    predictions.append(word)
predictions.append('.')
print(' '.join(predictions))
# Last expression of the cell, so the notebook displays the captioned image.
Image(filename=image_directory + image_filenames[0])

## Visualize the attention map for the above test

Run this only after running the above block

In [None]:
from matplotlib.pyplot import imshow
from matplotlib.pyplot import imread
from matplotlib import pyplot as plt
% matplotlib inline


for i in range(1, len(predictions) - 1):
    text_in[0][i] = gen.token_to_id[predictions[i]]

layer_name = 'time_distributed_6'
intermediate_layer_model = Model(inputs=model.input,
                                 outputs=model.get_layer(layer_name).output)

pred = intermediate_layer_model.predict([img_features, text_in])

plt.figure(figsize=(20,10))
columns = 5
I = plt.imread(dataset_directory + '/Flickr8k_Dataset/' + image_filenames[0])
print predictions[0]
for i in range(len(predictions) - 2):
    plt.subplot(len(predictions) / columns + 1, columns, i + 1)
    plt.imshow(I)
    att = pred[0,i,:].reshape((8,8), order = 'A')
    plt.imshow(att, alpha = 0.7, interpolation='bilinear', 
               cmap='gray', extent=[0, I.shape[1], I.shape[0], 0])
    plt.title(predictions[i+1])
    plt.axis('off')

## For generating test results in bulk
### Preprocessing
In process_images.py change the following as per requirements:
- dataset_directory --> Should be the path to the extracted dataset directory (like ~repo/data/COCO/extracted).
- img_list_file = 'val2014'--> Make sure to use val2014. Captions are not available for the test files in the dataset we have.
- save_name --> Name of the image features file to save (like test_features.h5).
- images_per_step --> Use a lower number for systems with less RAM (this is your physical RAM, not your GPU memory).
- batch_size --> Use a lower number for systems with less GPU memory (or less RAM if running on CPU).

In [None]:
from IPython.display import Image
import time
import progressbar

# Read the whole test-captions file and split it on newlines, one entry per
# line (a trailing newline yields a final empty entry).
captions_path = '../data/flicker8k/preprocessed/test_captions.txt'
with open(captions_path) as captions_file:
    raw_captions = captions_file.read()
captions = raw_captions.split('\n')
    
class Caption:
    """Reference captions and the generated caption for one image."""

    def __init__(self, name):
        """Create an entry for image `name` with five empty caption slots."""
        self.name = name
        # One slot per caption index (0-4); filled in through add().
        self.captions = [''] * 5
        # Caption produced by the model; empty until it is assigned.
        self.result = ''

    def add(self, caption_number, caption):
        """Store `caption` in the slot at index `caption_number`."""
        slots = self.captions
        slots[caption_number] = caption
        
import pickle

# Maps image name -> Caption object (reference captions + generated caption).
test_results = {}

features_path = '../data/flicker8k/preprocessed/test_features.h5'
results_path = '../data/flicker8k/preprocessed/test_results.p'

print('Preprocessing results:')
bar = progressbar.ProgressBar(
        term_width = 56,
        max_value = len(captions),
        widgets = [
            progressbar.Counter(format='%(value)04d/%(max_value)d'),
            progressbar.Bar('=', '[', ']', '.'),
            ' ',
            progressbar.ETA()
            ])

# Open the feature file once for the whole run; the context manager closes
# it even if captioning fails part-way through.
with h5py.File(features_path, 'r') as feature_dataset:
    for pos, caption in enumerate(captions):
        # Skip blank or truncated lines.
        if len(caption) < 5:
            continue

        # Line layout: "<image name>#<caption index>\t<caption text>".
        fields = caption.split('\t')
        name_and_index = fields[0].split('#')
        img_name = name_and_index[0]
        caption_number = int(name_and_index[1])
        caption = fields[1].lower()

        if pos % 25 == 0:
            bar.update(pos)

        # Image already captioned: only record this additional reference
        # caption. An explicit membership test (rather than catching every
        # exception) keeps real errors visible.
        if img_name in test_results:
            test_results[img_name].add(caption_number, caption)
            continue

        # First time this image is seen: generate its caption from the
        # stored CNN features, wrapped in a batch of one.
        features = np.array([feature_dataset[img_name]['cnn_features'][:]])

        # Greedy decoding, seeded with the start token looked up from the
        # generator (same lookup as the single-image cell) instead of a
        # hard-coded id.
        text_in = np.zeros((1, gen.max_token_len))
        text_in[0][0] = gen.token_to_id['<start>']

        arr = []
        for arg in range(gen.max_token_len - 1):
            pred = model.predict([features, text_in])
            tok = np.argmax(pred[0][arg])
            word = gen.id_to_token[tok]
            text_in[0][arg + 1] = tok
            if word == '<end>':
                break
            arr.append(word)
        arr.append('.')

        cap_obj = Caption(img_name)
        cap_obj.add(caption_number, caption)
        cap_obj.result = ' '.join(arr)
        test_results[img_name] = cap_obj

        # Checkpoint after every newly captioned image so an interrupted run
        # keeps its progress; `with` closes the file each time.
        with open(results_path, 'wb') as results_file:
            pickle.dump(test_results, results_file)
bar.update(len(captions))

# Sample code for importing tested results

In [None]:
import pickle

class Caption:
    """Reference captions and the generated caption for one image.

    Defined before unpickling so that pickle can find a class named
    `Caption` when it rebuilds the stored objects.
    """

    def __init__(self, name):
        """Create an entry for image `name` with five empty caption slots."""
        self.name = name
        # One slot per caption index (0-4); filled in through add().
        self.captions = [''] * 5
        # Generated caption; empty until it is assigned.
        self.result = ''

    def add(self, caption_number, caption):
        """Store `caption` in the slot at index `caption_number`."""
        slots = self.captions
        slots[caption_number] = caption
        
# NOTE: unpickling can execute arbitrary code -- only load result files that
# you generated yourself.
# The context manager closes the file as soon as loading is done.
with open('../data/flicker8k/preprocessed/test_results.p', 'rb') as results_file:
    test_results = pickle.load(results_file)

# Walk over every image: `ground_truth` is the list of reference captions and
# `result` is the caption generated by the model.
for img_name in test_results:
    ground_truth = test_results[img_name].captions
    result = test_results[img_name].result