<a href="https://colab.research.google.com/github/kalana-mihiranga/Image-Captioning/blob/main/Phase_1_Image_Captioning.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [1]:
import os
os.environ["KERAS_BACKEND"] = "tensorflow"
import re
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
from keras import layers
from keras.applications import efficientnet
from keras.layers import TextVectorization
keras.utils.set_random_seed(111)

In [2]:
!wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_Dataset.zip
!wget -q https://github.com/jbrownlee/Datasets/releases/download/Flickr8k/Flickr8k_text.zip
!unzip -qq Flickr8k_Dataset.zip
!unzip -qq Flickr8k_text.zip
!rm Flickr8k_Dataset.zip Flickr8k_text.zip

In [3]:
IMAGES_PATH = "Flicker8k_Dataset"
IMAGE_SIZE = (255, 255)
VOCAB_SIZE = 10000
SEQ_LENGTH = 25
EMBED_DIM = 512
FF_DIM = 512
BATCH_SIZE = 64
EPOCHS = 10
AUTOTUNE = tf.data.AUTOTUNE

In [4]:
def parse_captions(file_path):
    with open(file_path, 'r') as file:
        lines = file.readlines()

    img_to_captions = {}
    all_captions = []
    excluded_images = set()

    for entry in lines:
        entry = entry.strip()
        image_file, sentence = entry.split("\t")
        image_file = os.path.join(IMAGES_PATH, image_file.split("#")[0].strip())
        word_list = sentence.strip().split()

        if not (5 <= len(word_list) <= SEQ_LENGTH):
            excluded_images.add(image_file)
            continue

        if image_file.endswith(".jpg") and image_file not in excluded_images:
            caption = "<start> " + sentence.strip() + " <end>"
            all_captions.append(caption)
            img_to_captions.setdefault(image_file, []).append(caption)

    for image in excluded_images:
        img_to_captions.pop(image, None)

    return img_to_captions, all_captions


def split_data(data_dict, split_ratio=0.8, shuffle_data=True):
    image_paths = list(data_dict.keys())
    if shuffle_data:
        np.random.shuffle(image_paths)

    split_index = int(len(image_paths) * split_ratio)
    train_subset = {k: data_dict[k] for k in image_paths[:split_index]}
    val_subset = {k: data_dict[k] for k in image_paths[split_index:]}

    return train_subset, val_subset


# Load and split the data
caption_dict, corpus = parse_captions("Flickr8k.token.txt")
train_captions, val_captions = split_data(caption_dict)

print("Training image-caption pairs:", len(train_captions))
print("Validation image-caption pairs:", len(val_captions))


Training image-caption pairs: 6114
Validation image-caption pairs: 1529


In [8]:
# Define characters to be removed from text (except < and > for tokens)
punctuation_to_strip = r"""!"#$%&'()*+,-./:;<=>?@[\]^_`{|}~"""
punctuation_to_strip = punctuation_to_strip.replace("<", "").replace(">", "")

# Custom function for cleaning and standardizing caption text
def standardize_text(text_input):
    lowercase_text = tf.strings.lower(text_input)
    cleaned_text = tf.strings.regex_replace(lowercase_text, f"[{re.escape(punctuation_to_strip)}]", "")
    return cleaned_text

# Initialize the text vectorizer with vocabulary learning
text_vectorizer = TextVectorization(
    max_tokens=VOCAB_SIZE,
    output_mode="int",
    output_sequence_length=SEQ_LENGTH,
    standardize=standardize_text,
)

# Fit the vectorizer on the caption corpus
text_vectorizer.adapt(text_data)

# Define image augmentation pipeline for visual robustness
augmentation_pipeline = keras.Sequential([
    layers.RandomFlip("horizontal"),
    layers.RandomRotation(0.2),
    layers.RandomContrast(0.3),
    layers.RandomZoom(0.1),
    layers.RandomTranslation(0.1, 0.1),
])


In [9]:
# Function to decode, resize, and normalize an image from a file path
def load_and_preprocess_image(image_path):
    image_data = tf.io.read_file(image_path)
    image_tensor = tf.image.decode_jpeg(image_data, channels=3)
    image_tensor = tf.image.resize(image_tensor, IMAGE_SIZE)
    return tf.image.convert_image_dtype(image_tensor, dtype=tf.float32)

# Function to pair preprocessed image with tokenized caption
def prepare_sample(image_path, caption_texts):
    return load_and_preprocess_image(image_path), text_vectorizer(caption_texts)

# Function to build a tf.data pipeline
def create_tf_dataset(image_paths, caption_lists):
    dataset = tf.data.Dataset.from_tensor_slices((image_paths, caption_lists))
    dataset = dataset.shuffle(BATCH_SIZE * 8)
    dataset = dataset.map(prepare_sample, num_parallel_calls=AUTOTUNE)
    dataset = dataset.batch(BATCH_SIZE).prefetch(AUTOTUNE)
    return dataset

# Build the training and validation datasets
train_dataset = create_tf_dataset(list(train_data.keys()), list(train_data.values()))
valid_dataset = create_tf_dataset(list(valid_data.keys()), list(valid_data.values()))


In [10]:
# CNN-based feature extractor using EfficientNetB0
def build_cnn_feature_extractor():
    base_cnn = efficientnet.EfficientNetB0(
        input_shape=(*IMAGE_SIZE, 3),
        include_top=False,
        weights="imagenet"
    )
    base_cnn.trainable = False  # Freeze backbone
    features = layers.Reshape((-1, base_cnn.output.shape[-1]))(base_cnn.output)
    return keras.Model(inputs=base_cnn.input, outputs=features)

# Transformer-style encoder layer
class VisionEncoder(layers.Layer):
    def __init__(self, embed_dim, projection_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.norm1 = layers.LayerNormalization()
        self.norm2 = layers.LayerNormalization()
        self.attention = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim)
        self.dense_proj = layers.Dense(embed_dim, activation="relu")

    def call(self, inputs, training=False, mask=None):
        x = self.norm1(inputs)
        x = self.dense_proj(x)
        attention_output = self.attention(query=x, value=x, key=x, training=training)
        return self.norm2(x + attention_output)

# Positional Embedding for the decoder
class TokenPositionalEmbedding(layers.Layer):
    def __init__(self, vocab_size, sequence_length, embed_dim, **kwargs):
        super().__init__(**kwargs)
        self.token_embed = layers.Embedding(input_dim=vocab_size, output_dim=embed_dim)
        self.position_embed = layers.Embedding(input_dim=sequence_length, output_dim=embed_dim)
        self.scale = tf.math.sqrt(tf.cast(embed_dim, tf.float32))

    def call(self, token_input):
        seq_len = tf.shape(token_input)[-1]
        positions = tf.range(start=0, limit=seq_len, delta=1)
        x = self.token_embed(token_input) * self.scale
        pos = self.position_embed(positions)
        return x + pos

    def compute_mask(self, inputs, mask=None):
        return tf.math.not_equal(inputs, 0)

# Transformer decoder block
class CaptionDecoder(layers.Layer):
    def __init__(self, embed_dim, ff_dim, num_heads, **kwargs):
        super().__init__(**kwargs)
        self.self_attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, dropout=0.1)
        self.cross_attn = layers.MultiHeadAttention(num_heads=num_heads, key_dim=embed_dim, dropout=0.1)
        self.ff1 = layers.Dense(ff_dim, activation="relu")
        self.ff2 = layers.Dense(embed_dim)
        self.dropout1 = layers.Dropout(0.3)
        self.dropout2 = layers.Dropout(0.5)
        self.norm1 = layers.LayerNormalization()
        self.norm2 = layers.LayerNormalization()
        self.norm3 = layers.LayerNormalization()
        self.embed = TokenPositionalEmbedding(
            vocab_size=VOCAB_SIZE, sequence_length=SEQ_LENGTH, embed_dim=EMBED_DIM
        )
        self.final_dense = layers.Dense(VOCAB_SIZE, activation="softmax")
        self.supports_masking = True

    def get_causal_mask(self, x):
        seq_len = tf.shape(x)[1]
        i = tf.range(seq_len)[:, tf.newaxis]
        j = tf.range(seq_len)
        mask = tf.cast(i >= j, dtype="int32")
        mask = tf.reshape(mask, (1, seq_len, seq_len))
        batch_size = tf.shape(x)[0]
        return tf.tile(mask, [batch_size, 1, 1])

    def call(self, tokens, encoder_output, training=False, mask=None):
        x = self.embed(tokens)
        causal_mask = self.get_causal_mask(x)
        padding_mask = tf.cast(mask[:, :, tf.newaxis], tf.int32) if mask is not None else None
        combined_mask = tf.minimum(padding_mask, causal_mask) if mask is not None else causal_mask

        x = self.self_attn(x, x, x, attention_mask=combined_mask, training=training)
        x = self.norm1(self.embed(tokens) + x)
        x = self.cross_attn(x, encoder_output, encoder_output, attention_mask=padding_mask, training=training)
        x = self.norm2(x + self.cross_attn(x, encoder_output, encoder_output, training=training))
        x = self.ff1(x)
        x = self.dropout1(x, training=training)
        x = self.ff2(x)
        x = self.norm3(x + x)
        return self.final_dense(self.dropout2(x, training=training))

# Full model with custom training logic
class CaptioningSystem(keras.Model):
    def __init__(self, feature_extractor, encoder_layer, decoder_layer, num_captions=5, augment_fn=None):
        super().__init__()
        self.backbone = feature_extractor
        self.encoder = encoder_layer
        self.decoder = decoder_layer
        self.loss_fn = keras.losses.SparseCategoricalCrossentropy(from_logits=False, reduction=None)
        self.loss_metric = keras.metrics.Mean(name="loss")
        self.acc_metric = keras.metrics.Mean(name="accuracy")
        self.caption_count = num_captions
        self.augment = augment_fn

    def compute_loss_accuracy(self, image_features, token_batch, training=True):
        enc_out = self.encoder(image_features, training=training)
        token_input = token_batch[:, :-1]
        token_target = token_batch[:, 1:]
        mask = tf.math.not_equal(token_target, 0)
        pred_tokens = self.decoder(token_input, enc_out, training=training, mask=mask)

        loss = self.loss_fn(token_target, pred_tokens)
        loss *= tf.cast(mask, tf.float32)
        loss = tf.reduce_sum(loss) / tf.reduce_sum(tf.cast(mask, tf.float32))

        acc = tf.equal(token_target, tf.argmax(pred_tokens, axis=-1))
        acc = tf.math.logical_and(acc, mask)
        acc = tf.reduce_sum(tf.cast(acc, tf.float32)) / tf.reduce_sum(tf.cast(mask, tf.float32))

        return loss, acc

    def train_step(self, batch):
        images, token_sequences = batch
        if self.augment:
            images = self.augment(images)
        embeddings = self.backbone(images)

        total_loss, total_acc = 0.0, 0.0

        for i in range(self.caption_count):
            with tf.GradientTape() as tape:
                loss, acc = self.compute_loss_accuracy(embeddings, token_sequences[:, i, :], training=True)
                total_loss += loss
                total_acc += acc

            variables = self.encoder.trainable_variables + self.decoder.trainable_variables
            gradients = tape.gradient(loss, variables)
            self.optimizer.apply_gradients(zip(gradients, variables))

        avg_loss = total_loss / self.caption_count
        avg_acc = total_acc / self.caption_count
        self.loss_metric.update_state(avg_loss)
        self.acc_metric.update_state(avg_acc)
        return {"loss": self.loss_metric.result(), "accuracy": self.acc_metric.result()}

    def test_step(self, batch):
        images, token_sequences = batch
        embeddings = self.backbone(images)

        total_loss, total_acc = 0.0, 0.0
        for i in range(self.caption_count):
            loss, acc = self.compute_loss_accuracy(embeddings, token_sequences[:, i, :], training=False)
            total_loss += loss
            total_acc += acc

        avg_loss = total_loss / self.caption_count
        avg_acc = total_acc / self.caption_count
        self.loss_metric.update_state(avg_loss)
        self.acc_metric.update_state(avg_acc)
        return {"loss": self.loss_metric.result(), "accuracy": self.acc_metric.result()}

    @property
    def metrics(self):
        return [self.loss_metric, self.acc_metric]

# Instantiate components
cnn_encoder_model = build_cnn_feature_extractor()
cnn_encoder_model.trainable = True

transformer_encoder = VisionEncoder(embed_dim=EMBED_DIM, projection_dim=FF_DIM, num_heads=1)
transformer_decoder = CaptionDecoder(embed_dim=EMBED_DIM, ff_dim=FF_DIM, num_heads=2)

caption_model = CaptioningSystem(
    feature_extractor=cnn_encoder_model,
    encoder_layer=transformer_encoder,
    decoder_layer=transformer_decoder,
    augment_fn=augmentation_pipeline
)


Downloading data from https://storage.googleapis.com/keras-applications/efficientnetb0_notop.h5
[1m16705208/16705208[0m [32m━━━━━━━━━━━━━━━━━━━━[0m[37m[0m [1m0s[0m 0us/step


In [None]:
# Define loss function with masking support for padded tokens
loss_function = keras.losses.SparseCategoricalCrossentropy(
    from_logits=False,
    reduction=None  # Needed to apply custom mask weighting
)

# Optional: Early stopping callback (useful in extended training)
# early_stopping_cb = keras.callbacks.EarlyStopping(
#     patience=3,
#     restore_best_weights=True,
#     monitor='val_loss'
# )

# Custom learning rate schedule with linear warm-up
class WarmupSchedule(keras.optimizers.schedules.LearningRateSchedule):
    def __init__(self, target_lr, warmup_steps):
        super().__init__()
        self.target_lr = target_lr
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        lr = self.target_lr * (step / warmup_steps)
        return tf.cond(
            step < warmup_steps,
            lambda: lr,
            lambda: self.target_lr
        )

# Compute total training steps and warm-up phase
total_steps = len(train_dataset) * EPOCHS
warmup_phase = total_steps // 10
lr_schedule = WarmupSchedule(target_lr=1e-4, warmup_steps=warmup_phase)

# Set up the optimizer using the custom schedule
optimizer = keras.optimizers.Adam(learning_rate=lr_schedule)

# Compile the model with optimizer and loss function
caption_model.compile(
    optimizer=optimizer,
    loss=loss_function
)

# Train the model
history = caption_model.fit(
    train_dataset,
    epochs=EPOCHS,
    validation_data=valid_dataset,
    # callbacks=[early_stopping_cb]  # Optional
)


Epoch 1/10
[1m15/96[0m [32m━━━[0m[37m━━━━━━━━━━━━━━━━━[0m [1m56:06[0m 42s/step - accuracy: 0.0443 - loss: 8.8507

In [None]:
# Build vocabulary and index mapping for decoding predictions
vocab_list = text_vectorizer.get_vocabulary()
index_to_token = dict(zip(range(len(vocab_list)), vocab_list))
max_caption_len = SEQ_LENGTH - 1
val_image_paths = list(valid_data.keys())

# Function to generate a caption for a randomly selected image
def predict_caption():
    # Select a random image from the validation set
    chosen_image_path = np.random.choice(val_image_paths)

    # Load and preprocess the image
    image_tensor = load_and_preprocess_image(chosen_image_path)
    img_array = image_tensor.numpy().clip(0, 1)
    plt.imshow(img_array)
    plt.axis('off')
    plt.show()

    # Extract features using the CNN encoder
    encoded_image = caption_model.backbone(tf.expand_dims(image_tensor, axis=0))
    visual_features = caption_model.encoder(encoded_image, training=False)

    # Initialize caption generation
    caption_input = "<start>"
    for _ in range(max_caption_len):
        tokenized_input = text_vectorizer([caption_input])[:, :-1]
        mask = tf.math.not_equal(tokenized_input, 0)

        predictions = caption_model.decoder(
            tokenized_input,
            visual_features,
            training=False,
            mask=mask
        )
        next_token_index = tf.argmax(predictions[0, -1, :]).numpy()
        next_token = index_to_token[next_token_index]

        if next_token == "<end>":
            break
        caption_input += f" {next_token}"

    # Format and display the final prediction
    generated_caption = caption_input.replace("<start>", "").replace("<end>", "").strip()
    print("📝 Predicted Caption:", generated_caption)

# Generate captions for sample validation images
predict_caption()
predict_caption()
predict_caption()
