In [4]:
# libraries that will be used for the project
import numpy as np
import tensorflow as tf
#import matplotlib.pyplot as plt

import os
import re
from tensorflow import keras
from tensorflow.keras import layers
from tensorflow.keras.layers import TextVectorization

## Preparing Dataset

In [23]:
def load_data(file, images_dir="Flicker8k_Dataset", min_tokens=5, max_tokens=25):
    """Parse a Flickr8k-style caption file into per-image captions.

    Every non-blank line is expected to hold an image reference of the form
    ``<image name>#<caption index>``, a tab, and then the caption text.

    Args:
        file: Path of the caption file to read.
        images_dir: Directory joined in front of every image name.
        min_tokens: Captions with fewer whitespace-separated tokens are
            rejected, and their image is dropped entirely.
        max_tokens: Captions with more whitespace-separated tokens are
            rejected, and their image is dropped entirely.

    Returns:
        A ``(cap_map, text_data)`` tuple. ``cap_map`` maps each kept image
        path to its list of captions, each wrapped as
        ``"<start> ... <end>"``. ``text_data`` is the list of wrapped
        captions in file order; it can still contain captions of an image
        that was dropped by a later line of the file.

    Raises:
        ValueError: If a non-blank line contains no tab separator.
    """
    cap_map = {}
    text_data = []
    skip_img = set()

    with open(file) as cap_file:
        for line in cap_file:
            line = line.rstrip("\n")
            if not line.strip():
                # Tolerate blank lines (e.g. a trailing newline at end of file).
                continue

            # Split on the first tab only so a tab inside a caption is harmless.
            img, cap = line.split("\t", 1)

            # Drop the "#<caption index>" suffix to recover the image file name.
            img = os.path.join(images_dir, img.split("#")[0].strip())

            cap = cap.strip()
            n_tokens = len(cap.split())
            if n_tokens < min_tokens or n_tokens > max_tokens:
                # One out-of-range caption disqualifies the whole image.
                skip_img.add(img)
                continue

            if img.endswith("jpg") and img not in skip_img:
                cap = "<start> " + cap + " <end>"
                text_data.append(cap)
                # Only captions that passed the guard above are recorded, so
                # every caption in cap_map carries the start/end tokens.
                cap_map.setdefault(img, []).append(cap)

    # Remove images that were disqualified after some of their captions had
    # already been recorded.
    for img in skip_img:
        cap_map.pop(img, None)
    return cap_map, text_data

In [28]:
def split_train_val(cap, train_size=0.8, shuffle=True, seed=None):
    """Split an image -> captions mapping into training and validation dicts.

    Args:
        cap: Dictionary whose keys are the items to split.
        train_size: Fraction of the keys assigned to the training split.
        shuffle: Whether to shuffle the keys before splitting.
        seed: Optional seed for a reproducible shuffle. When ``None`` the
            global NumPy random state is used, as before.

    Returns:
        A ``(training_data, validation_data)`` tuple of dictionaries that
        together hold every key of ``cap`` exactly once.
    """
    all_img = list(cap.keys())
    if shuffle:
        if seed is None:
            np.random.shuffle(all_img)
        else:
            # A dedicated generator keeps the split reproducible without
            # touching the global random state.
            np.random.default_rng(seed).shuffle(all_img)

    # Use a separate name rather than rebinding the `train_size` fraction.
    n_train = int(len(all_img) * train_size)
    training_data = {img: cap[img] for img in all_img[:n_train]}
    validation_data = {img: cap[img] for img in all_img[n_train:]}

    return training_data, validation_data

In [24]:
# Parse the caption file into the image -> captions mapping (cap_map) and the
# list of caption strings (text_data) used later to build the vocabulary.
cap_map, text_data = load_data("Flickr8k_text/Flickr8k.token.txt")

In [29]:
# Split the captioned images with the function's defaults (80% training,
# shuffled) and report how many images ended up in each subset.
train_data, valid_data = split_train_val(cap_map)
print("Number of training samples: ", len(train_data))
print("Number of validation samples: ", len(valid_data))

Number of training samples:  6115
Number of validation samples:  1529


### Vectorizing the text data

In [34]:
def standardization(input_string):
    """Lowercase the input and delete every character listed in `strip_chars`.

    `strip_chars` is read from module scope when the function is called, so
    it must be defined before this function is first used.
    """
    lowered = tf.strings.lower(input_string)
    # Character class built from the escaped punctuation characters.
    punctuation_class = "[%s]" % re.escape(strip_chars)
    return tf.strings.regex_replace(lowered, punctuation_class, "")

In [35]:
# Punctuation removed from captions by `standardization`. The backslash is
# written as "\\" because "\]" is an invalid escape sequence in a normal
# string literal (it warns on current Python versions); the value is unchanged.
strip_chars = "!\"#$%&'()*+,-./:;<=>?@[\\]^_`{|}~"
# Keep "<" and ">" so the "<start>" / "<end>" markers survive standardization.
strip_chars = strip_chars.replace("<", "")
strip_chars = strip_chars.replace(">", "")

# Text-vectorization layer that turns caption strings into integer sequences.
# NOTE(review): load_data keeps captions of up to 25 tokens and then adds
# "<start>" and "<end>", so an output length of 25 looks like it can cut the
# "<end>" token off the longest captions — confirm this is intended.
vectorization = TextVectorization(
    max_tokens=10000,  # cap on the vocabulary size
    output_mode="int",  # emit integer token indices
    output_sequence_length=25,  # fixed length of every output sequence
    standardize=standardization,  # custom lowercase + punctuation stripping
)
# Build the vocabulary from the caption strings returned by load_data.
vectorization.adapt(text_data)

# Data augmentation for image data: a small pipeline of random transforms
# applied in the order listed.
# NOTE(review): presumably meant for training images only — confirm at the
# call site, which is not visible in this part of the notebook.
image_augmentation = keras.Sequential(
    [
        layers.RandomFlip("horizontal"),  # random left-right flip
        layers.RandomRotation(0.2),  # random rotation, factor 0.2
        layers.RandomContrast(0.3),  # random contrast change, factor 0.3
    ]
)