<a href="https://colab.research.google.com/github/mamunm/iamge_caption_generator/blob/main/notebooks/Flickr8k_Data_Processing.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

Guided notebook for image captioning, inspired by the Machine Learning Mastery blog.

In [None]:
!pip install wandb -qqq

In [None]:
# Mount Google Drive into this runtime's filesystem.
from google.colab import drive

drive_mount_point = '/content/drive'
drive.mount(drive_mount_point)

In [None]:
# import python modules
import os
import pickle
from tqdm import tqdm
import string 
from keras.applications.vgg16 import VGG16
from keras.preprocessing.image import load_img
from keras.preprocessing.image import img_to_array
from keras.applications.vgg16 import preprocess_input
from keras.models import Model
from collections import defaultdict

import numpy as np
#import wandb

In [None]:
# extract image features from a collection of images in a directory
def extract_features(directory):
    """Run every file in ``directory`` through a truncated VGG16 network.

    The network is VGG16 re-wired so that its output is the output of the
    model's second-to-last layer. Each file is loaded as a 224x224 image,
    converted to an array, given a leading batch axis of size 1, passed
    through the VGG16 ``preprocess_input`` and then predicted on.

    Args:
        directory: path of the folder whose entries are all loaded as images.

    Returns:
        dict mapping image id (the file name up to its first '.') to the
        value returned by ``predict`` for that image.
    """
    # re-use VGG16's inputs but stop at the penultimate layer
    vgg = VGG16()
    encoder = Model(inputs=vgg.inputs, outputs=vgg.layers[-2].output)

    features = {}
    for name in tqdm(os.listdir(directory)):
        pixels = img_to_array(
            load_img(directory + '/' + name, target_size=(224, 224))
        )
        # add the batch axis before applying the model-specific preprocessing
        batch = preprocess_input(pixels.reshape((1, *pixels.shape)))
        features[name.split('.')[0]] = encoder.predict(batch, verbose=0)
    return features

In [None]:
# extract features from all images, or reuse the pickled result if it exists
# NOTE: data_path is relative, so it is resolved against the current working
# directory of the kernel.
data_path = 'drive/MyDrive/image_captioning_data/Flickr8K'
features_path = os.path.join(data_path, 'features.pkl')
if not os.path.exists(features_path):
    directory = os.path.join(data_path, 'Flicker8k_Dataset')
    features = extract_features(directory)
    print(f'Extracted Features: {len(features)}')
    # the with-block flushes and closes the cache file even if dump() raises
    with open(features_path, 'wb') as features_file:
        pickle.dump(features, features_file)
else:
    # NOTE(security): unpickling can execute arbitrary code; only load a
    # features.pkl that was produced by this notebook.
    with open(features_path, 'rb') as features_file:
        features = pickle.load(features_file)


In [None]:
# read the whole caption token file into one string
token_path = os.path.join(data_path, 'Flickr8k_text/Flickr8k.token.txt')
with open(token_path, 'r') as f:
    doc = f.read()

In [None]:
# extract descriptions for images
def load_descriptions(doc):
    """Parse caption text into a mapping of image id to its descriptions.

    Every line of ``doc`` is split on whitespace; the first token is taken
    as the image identifier and the remaining tokens are re-joined with
    single spaces to form one description. The identifier is cut at its
    first '.', so all lines that share the same prefix before the first '.'
    are grouped under one key.

    Lines that do not contain both an identifier and at least one
    description word (blank, whitespace-only or id-only lines) are skipped.

    Args:
        doc: full text of the caption file as a single string.

    Returns:
        dict mapping image id (str) to a list of description strings, in
        the order the lines appear in ``doc``.
    """
    mapping = dict()
    for line in doc.split('\n'):
        tokens = line.split()
        # Count tokens, not characters: a whitespace-only line has no
        # tokens[0], and an id-only line would store an empty description.
        if len(tokens) < 2:
            continue
        image_id, image_desc = tokens[0], tokens[1:]
        # drop everything from the first '.' onwards (file extension etc.)
        image_id = image_id.split('.')[0]
        mapping.setdefault(image_id, []).append(' '.join(image_desc))
    return mapping

def clean_descriptions(descriptions):
    """Normalise every description string in place.

    For each description the words are lower-cased, stripped of all
    ``string.punctuation`` characters, and then dropped if they are shorter
    than two characters or not purely alphabetic. The surviving words are
    re-joined with single spaces and written back into the same list slot.

    Args:
        descriptions: dict mapping a key to a list of description strings;
            the lists are modified in place.

    Returns:
        None.
    """
    # translation table that deletes every punctuation character
    strip_punct = str.maketrans('', '', string.punctuation)
    for captions in descriptions.values():
        for idx, caption in enumerate(captions):
            normalised = (
                token.lower().translate(strip_punct)
                for token in caption.split()
            )
            # the length filter removes leftovers such as a lone 's' or 'a';
            # isalpha() removes tokens that contain digits
            kept = [
                token for token in normalised
                if len(token) > 1 and token.isalpha()
            ]
            captions[idx] = ' '.join(kept)

# convert the loaded descriptions into a vocabulary of words
def to_vocabulary(descriptions):
    """Collect the distinct words used across all descriptions.

    Args:
        descriptions: dict mapping a key to a list of description strings.

    Returns:
        set of every whitespace-separated word found in any description.
    """
    return {
        word
        for desc_list in descriptions.values()
        for desc in desc_list
        for word in desc.split()
    }

# save descriptions to file, one per line
def save_descriptions(descriptions, filename):
    """Write all descriptions to a text file, one per line.

    Each line has the form ``<key> <description>``. Lines are separated by
    a newline character and no trailing newline is written, so an empty
    ``descriptions`` produces an empty file.

    Args:
        descriptions: dict mapping a key (str) to a list of description
            strings.
        filename: path of the file to create or overwrite.

    Returns:
        None.
    """
    lines = [
        key + ' ' + desc
        for key, desc_list in descriptions.items()
        for desc in desc_list
    ]
    # the with-block closes the file even if write() raises
    with open(filename, 'w') as file:
        file.write('\n'.join(lines))

# group the raw caption lines by image id
descriptions = load_descriptions(doc)
num_images = len(descriptions)
print('Loaded: %d ' % num_images)

# normalise every caption (modifies the lists inside `descriptions`)
clean_descriptions(descriptions)

# distinct words that remain after cleaning
vocabulary = to_vocabulary(descriptions)
vocabulary_size = len(vocabulary)
print('Vocabulary Size: %d' % vocabulary_size)

# write the cleaned captions into the dataset folder
descriptions_path = os.path.join(data_path, 'descriptions.txt')
save_descriptions(descriptions, descriptions_path)