# Image Caption Generation
## Introduction
This assignment builds an Image Caption Generator that describes the content of an image by combining a CNN and an RNN. The model is built with TensorFlow and Keras. The dataset is Flickr 8K [5], which contains 8,000 images, each paired with five different captions describing its content.

The model architecture consists of a CNN, which extracts features from the input image and encodes it, and a Recurrent Neural Network (RNN) built from Long Short-Term Memory (LSTM) layers. The key difference from many other captioning models is that the image embedding is fed to the RNN only once, as its first input.
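The one-time image injection can be illustrated with a minimal NumPy sketch: the image feature vector is projected to the word-embedding size and placed as time step 0 of the LSTM input sequence, and the caption's word embeddings follow it. The sizes, the random matrices, and the helper `build_decoder_inputs` are hypothetical stand-ins for learned layers, not the notebook's actual decoder.

```python
import numpy as np

# Hypothetical sizes, chosen only for illustration.
embed_dim = 8        # size shared by the image embedding and word embeddings
feature_dim = 2048   # channel depth of an Inception V3 feature vector
caption_len = 5      # number of caption tokens fed to the decoder

rng = np.random.default_rng(0)

# Random stand-ins for a learned projection and learned word embeddings.
image_projection = rng.normal(size=(feature_dim, embed_dim))
word_embeddings = rng.normal(size=(caption_len, embed_dim))

def build_decoder_inputs(image_features, word_embeddings, projection):
    """Prepend the projected image as time step 0, followed by the words."""
    image_embedding = image_features @ projection  # shape (embed_dim,)
    return np.vstack([image_embedding[np.newaxis, :], word_embeddings])

image_features = rng.normal(size=(feature_dim,))
sequence = build_decoder_inputs(image_features, word_embeddings, image_projection)
print(sequence.shape)  # (6, 8): 1 image step followed by 5 word steps
```

The image appears only in the first row of the sequence; every later step carries a word, so the LSTM must keep the visual information in its state rather than receiving it again at each step.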

## Dependencies

In [1]:
import re
import random

import numpy as np
import tensorflow as tf
import matplotlib.pyplot as plt

from tensorflow import keras
from time import time

from tqdm import tqdm # progress bar
from sklearn.model_selection import train_test_split # train/test split
from nltk.translate.bleu_score import corpus_bleu # BLEU Score

## Dataset
Load the dataset from a local path or from Google Drive.

In [2]:
# Change this path to where the Flickr 8K dataset [5] was downloaded
dataset_path = "E:/Project/Image genearation Captions/Resources"
dataset_images_path = dataset_path + "/Images/" 

Image configuration

In [3]:
img_height = 180
img_width = 180
validation_split = 0.2

### Encoder Model

To extract features from the images, we use Inception V3, a CNN pretrained on ImageNet. The figure below shows the architecture of this network.

![Inception Architecture](https://paperswithcode.com/media/methods/inceptionv3onc--oview_vjAbOfw.png)

In [4]:
# Drop the classification head (include_top=False) and use the output of the last convolutional block as the image features
def get_encoder():
    image_model = tf.keras.applications.InceptionV3(include_top=False, weights='imagenet')
    new_input = image_model.input
    hidden_layer = image_model.layers[-1].output

    image_features_extract_model = tf.keras.Model(new_input, hidden_layer)
    return image_features_extract_model
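Because `include_top=False` removes the fully connected head, the encoder returns a spatial feature map whose size depends on `img_height` and `img_width`, with 2048 channels from the final `mixed10` block. The sketch below traces the spatial side length through the stride-reducing layers of the Keras InceptionV3 definition, assuming default layer settings. The helper names are hypothetical, and the code only does arithmetic; it does not build the model.

```python
def conv_out(size, kernel, stride):
    """Output length of a 'valid' convolution or pooling along one axis."""
    return (size - kernel) // stride + 1

def inception_v3_feature_size(side):
    """Spatial side length of the InceptionV3(include_top=False) output."""
    side = conv_out(side, 3, 2)   # stem conv 3x3, stride 2, 'valid'
    side = conv_out(side, 3, 1)   # stem conv 3x3, 'valid'
    # the third stem conv uses 'same' padding, so the size is unchanged
    side = conv_out(side, 3, 2)   # max pooling 3x3, stride 2
    # the 1x1 conv leaves the size unchanged
    side = conv_out(side, 3, 1)   # conv 3x3, 'valid'
    side = conv_out(side, 3, 2)   # max pooling 3x3, stride 2
    side = conv_out(side, 3, 2)   # grid reduction in block mixed3
    side = conv_out(side, 3, 2)   # grid reduction in block mixed8
    return side

for side in (180, 299):
    s = inception_v3_feature_size(side)
    print(f"{side}x{side} input -> {s}x{s}x2048 features")
# 180x180 input -> 4x4x2048 features
# 299x299 input -> 8x8x2048 features
```

With the 180x180 configuration used here, each image is therefore encoded as a 4x4 grid of 2048-dimensional vectors, smaller than the 8x8 grid produced at Inception V3's native 299x299 resolution.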

### Read captions
Create a dictionary that maps each image filename to a list of its captions.

In [5]:
# Preprocess the caption: collapse whitespace and wrap it in <start> and <end> tokens
def get_preprocessed_caption(caption):
    caption = re.sub(r'\s+', ' ', caption)
    caption = caption.strip()
    caption = "<start> " + caption + " <end>"
    return caption
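A quick usage check of the preprocessing function, applied to a made-up raw caption with irregular spacing and a trailing newline (the input string is hypothetical). Note that the function neither lowercases nor removes punctuation, so tokens such as "A" and "a" stay distinct.

```python
import re

def get_preprocessed_caption(caption):
    # Collapse runs of whitespace (including the trailing newline) into one space
    caption = re.sub(r'\s+', ' ', caption)
    caption = caption.strip()
    # Mark where the decoder should start and stop generating
    return "<start> " + caption + " <end>"

raw = "A dog  runs through\tthe grass .\n"  # hypothetical raw caption
print(get_preprocessed_caption(raw))
# <start> A dog runs through the grass . <end>
```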

In [6]:
images_captions_dict = {}

with open(dataset_path + "/captions.txt", "r", encoding="utf-8") as dataset_info:
    next(dataset_info) # Skip the header line: image,caption

    # Use a subset of the first 4,000 entries out of about 40,000
    for info_raw in list(dataset_info)[:4000]:
        # Split only at the first comma, since a caption may itself contain commas
        image_filename, raw_caption = info_raw.split(",", 1)
        caption = get_preprocessed_caption(raw_caption)

        if image_filename not in images_captions_dict:
            images_captions_dict[image_filename] = [caption]
        else:
            images_captions_dict[image_filename].append(caption)
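Splitting a line with `split(",")` breaks it at every comma, so a caption containing commas would be truncated to its first fragment. Passing `maxsplit=1` splits only at the first comma and keeps the caption whole, and `collections.defaultdict(list)` is a compact alternative to the explicit key check. The lines below are made-up examples in the same `image,caption` layout, not real Flickr 8K entries.

```python
from collections import defaultdict

# Hypothetical lines in the "image,caption" layout of captions.txt
lines = [
    "img_001.jpg,A man , wearing a hat , walks a dog .\n",
    "img_001.jpg,A man walks his dog .\n",
    "img_002.jpg,Two children play on the beach .\n",
]

naive = lines[0].split(",")
print(naive[1])  # only "A man " survives; the rest of the caption is dropped

captions = defaultdict(list)
for line in lines:
    # maxsplit=1 splits only at the first comma, keeping the caption intact
    filename, caption = line.rstrip("\n").split(",", 1)
    captions[filename].append(caption)

print(dict(captions))
```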

### Read images
Create a dictionary that maps each image filename to the features extracted by the pretrained encoder.

In [7]:
def load_image(image_path):
    img = tf.io.read_file(dataset_images_path + image_path)
    img = tf.image.decode_jpeg(img, channels=3)
    img = tf.image.resize(img, (img_height, img_width))
    img = tf.keras.applications.inception_v3.preprocess_input(img) # preprocessing needed for pre-trained model
    return img, image_path
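`tf.keras.applications.inception_v3.preprocess_input` rescales pixel values from [0, 255] to [-1, 1] (Keras' "tf" preprocessing mode), which is the input range the ImageNet weights expect. The NumPy function below reproduces only that arithmetic for illustration; `inception_style_scale` is a hypothetical helper, and the notebook should keep calling the Keras function.

```python
import numpy as np

def inception_style_scale(pixels):
    """Map pixel values from [0, 255] to [-1, 1], as in Keras' 'tf' mode."""
    pixels = np.asarray(pixels, dtype=np.float32)
    return pixels / 127.5 - 1.0

# 0, 127.5 and 255 map to -1, 0 and 1 respectively
print(inception_style_scale([0.0, 127.5, 255.0]))
```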

In [8]:
image_captions_dict_keys = list(images_captions_dict.keys())
image_dataset = tf.data.Dataset.from_tensor_slices(image_captions_dict_keys)
image_dataset = image_dataset.map(load_image, num_parallel_calls=tf.data.experimental.AUTOTUNE).batch(64)
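A quick sanity check on the pipeline size: with 4,000 caption lines and five captions per image, the subset contains about 800 images, so batches of 64 give 13 steps, matching the 13/13 in the progress bar of the extraction loop. This assumes each image's five captions appear on consecutive lines of the caption file; under that assumption the last batch holds only 32 images.

```python
import math

captions_read = 4000     # entries kept from captions.txt
captions_per_image = 5   # Flickr 8K pairs each image with five captions
batch_size = 64          # value passed to .batch() in the pipeline

# Assumes the 4,000 entries cover whole images (five consecutive lines each)
num_images = captions_read // captions_per_image
num_batches = math.ceil(num_images / batch_size)
last_batch = num_images - (num_batches - 1) * batch_size
print(num_images, num_batches, last_batch)  # 800 13 32
```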

In [9]:
images_dict = {}
encoder = get_encoder()
for img_tensor, path_tensor in tqdm(image_dataset):
    batch_features_tensor = encoder(img_tensor)
    
    # Loop over batch to save each element in images_dict
    for batch_features, path in zip(batch_features_tensor, path_tensor):
        decoded_path = path.numpy().decode("utf-8")
        images_dict[decoded_path] = batch_features.numpy()

100%|██████████| 13/13 [00:20<00:00,  1.60s/it]
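Each value stored in `images_dict` is a spatial feature map rather than a single vector. If the decoder expects one image embedding as its first input, one common option (an assumption here, not necessarily what this notebook does later) is global average pooling over the grid; keeping the grid as a sequence of location vectors is the alternative typically used by attention-based decoders. The random array below stands in for real features, with the 4x4x2048 shape expected for 180x180 inputs.

```python
import numpy as np

# Stand-in for one entry of images_dict: a 4x4 grid of 2048-channel features
rng = np.random.default_rng(0)
feature_map = rng.normal(size=(4, 4, 2048)).astype(np.float32)

# Option 1: average over the spatial grid -> one 2048-d vector per image
pooled = feature_map.mean(axis=(0, 1))

# Option 2: flatten the grid into 16 location vectors of 2048 channels each
locations = feature_map.reshape(-1, feature_map.shape[-1])

print(pooled.shape, locations.shape)  # (2048,) (16, 2048)
```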
