# **Imports and Setup**

In [1]:
import sys
sys.path.append("..")  # Adjust if needed based on directory

import json
import pickle
import numpy as np
import torch
import torch.nn as nn
from pathlib import Path
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import DataLoader

from utils.dataloader import get_transforms, load_split_ids, build_caption_dataset
from utils.caption_dataset import CaptionDataset

# Device configuration: prefer CUDA, fall back to Apple MPS (Metal), otherwise CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
elif torch.backends.mps.is_available():
    device = torch.device("mps")
else:
    device = torch.device("cpu")
print("Using device:", device)

ModuleNotFoundError: No module named 'utils'

# **Load Vocabulary and Caption Data**

In [None]:
# Vocabulary: word -> integer index, produced by the preprocessing step.
word2idx = json.loads(Path("../data/processed/word2idx.json").read_text())

# Pre-computed caption sequences (presumably image id -> token-id sequences;
# confirm against the preprocessing code). NOTE: unpickling can execute
# arbitrary code - only load this file if it was generated locally/trusted.
image_caption_seqs = pickle.loads(Path("../data/processed/image_caption_seqs.pkl").read_bytes())

vocab_size = len(word2idx)

# **Dataset Preparation**

In [None]:
# Load image ids for each official split.
train_ids = load_split_ids("../data/Flickr8k_text/Flickr_8k.trainImages.txt")
val_ids   = load_split_ids("../data/Flickr8k_text/Flickr_8k.devImages.txt")
test_ids  = load_split_ids("../data/Flickr8k_text/Flickr_8k.testImages.txt")

# Transforms and image folder path.
transform_train = get_transforms("train")
transform_val = get_transforms("val")
# NOTE: the "Flicker8k" spelling is intentional - it matches the image path used
# for the sample-caption cell; verify against the extracted dataset folder.
image_folder = "../data/Flicker8k_Dataset"

# Build datasets. The test split reuses the validation transform; it is needed by
# the BLEU evaluation cell, which previously referenced an undefined `test_dataset`.
train_dataset = build_caption_dataset(train_ids, image_caption_seqs, word2idx, image_folder, transform_train)
val_dataset   = build_caption_dataset(val_ids, image_caption_seqs, word2idx, image_folder, transform_val)
test_dataset  = build_caption_dataset(test_ids, image_caption_seqs, word2idx, image_folder, transform_val)

# **Encoder with EfficientNet-B3**

In [None]:
from efficientnet_pytorch import EfficientNet

class EncoderEfficientNet(nn.Module):
    """EfficientNet-B3 feature extractor returning a fixed-size spatial grid.

    Args:
        encoded_image_size: side length of the output grid; backbone features are
            adaptively average-pooled to ``encoded_image_size x encoded_image_size``.
        fine_tune: if True, the last ``fine_tune_last_n_blocks`` MBConv blocks plus the
            final 1x1 conv head and its batch norm are trainable; the rest of the
            backbone is frozen. If False, the whole backbone is frozen.
        fine_tune_last_n_blocks: how many of the final MBConv blocks to unfreeze when
            ``fine_tune`` is True.
    """

    def __init__(self, encoded_image_size=14, fine_tune=True, fine_tune_last_n_blocks=3):
        super(EncoderEfficientNet, self).__init__()
        self.enc_image_size = encoded_image_size
        self.fine_tune_last_n_blocks = fine_tune_last_n_blocks

        # Load pretrained EfficientNet-B3
        self.efficientnet = EfficientNet.from_pretrained('efficientnet-b3')

        # Use the convolutional trunk (extract_features) rather than forward(),
        # so the classification head is skipped.
        self.features = self.efficientnet.extract_features

        # Resize feature map to a fixed spatial size.
        self.adaptive_pool = nn.AdaptiveAvgPool2d((encoded_image_size, encoded_image_size))

        self.fine_tune(fine_tune)

    def forward(self, images):
        """Encode images into a channels-last grid of shape (batch, enc, enc, channels)."""
        out = self.features(images)        # (batch_size, channels, H, W)
        out = self.adaptive_pool(out)      # (batch_size, channels, enc, enc)
        return out.permute(0, 2, 3, 1)     # (batch_size, enc, enc, channels)

    def fine_tune(self, fine_tune=True, num_blocks=None):
        """Freeze the backbone, then optionally unfreeze its last blocks and conv head.

        Args:
            fine_tune: whether to unfreeze anything after freezing the backbone.
            num_blocks: number of trailing MBConv blocks to unfreeze; defaults to
                ``self.fine_tune_last_n_blocks`` (3 if unset).
        """
        for p in self.efficientnet.parameters():
            p.requires_grad = False
        if not fine_tune:
            return

        if num_blocks is None:
            num_blocks = getattr(self, "fine_tune_last_n_blocks", 3)

        # Select blocks by position from the end of ``_blocks``. Matching fixed name
        # substrings such as "blocks.5" picks early blocks, because B3 has many more
        # blocks (26) than those indices suggest.
        blocks = list(self.efficientnet._blocks)
        trainable = blocks[max(len(blocks) - num_blocks, 0):]

        # The conv head + BN produce the final feature map the decoder attends over,
        # so they must adapt together with the unfrozen blocks.
        trainable += [self.efficientnet._conv_head, self.efficientnet._bn1]
        for module in trainable:
            for p in module.parameters():
                p.requires_grad = True

# **Bahdanau Attention**

In [None]:
class BahdanauAttention(nn.Module):
    """Additive (Bahdanau-style) attention over a set of encoder feature vectors.

    Each encoder position is scored against the decoder hidden state with
    ``full_att(relu(encoder_att(e) + decoder_att(h)))``; the scores are softmax-normalised
    over positions and used to take a weighted sum of the encoder features.
    """

    def __init__(self, encoder_dim, hidden_dim, attention_dim):
        super().__init__()
        # Layer creation order is kept stable so seeded initialisation and
        # state_dict keys are unaffected.
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(hidden_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)
        self.relu = nn.ReLU()
        self.softmax = nn.Softmax(dim=1)

    def forward(self, encoder_out, hidden_state):
        """Attend over ``encoder_out`` given ``hidden_state``.

        Args:
            encoder_out: (batch_size, num_pixels, encoder_dim)
            hidden_state: (batch_size, hidden_dim)

        Returns:
            (context, alpha): context is (batch_size, encoder_dim), alpha is the
            (batch_size, num_pixels) attention distribution.
        """
        pixel_proj = self.encoder_att(encoder_out)
        state_proj = self.decoder_att(hidden_state).unsqueeze(1)

        energies = self.full_att(self.relu(pixel_proj + state_proj)).squeeze(2)
        weights = self.softmax(energies)

        weighted_pixels = encoder_out * weights.unsqueeze(2)
        return weighted_pixels.sum(dim=1), weights

# **GRU Decoder with Attention**

In [None]:
class DecoderGRUWithAttention(nn.Module):
    """GRU caption decoder that attends over encoder features at every step.

    Args:
        attention_dim: hidden size of the additive attention layer.
        embed_dim: word-embedding size.
        hidden_dim: GRU hidden-state size.
        vocab_size: number of output classes (vocabulary entries).
        encoder_dim: channel size of the encoder features (1536 is the value used for
            the EfficientNet-B3 encoder in this notebook).
        dropout: dropout probability applied to the GRU output before ``fc``.
    """

    def __init__(self, attention_dim, embed_dim, hidden_dim, vocab_size, encoder_dim=1536, dropout=0.5):
        super(DecoderGRUWithAttention, self).__init__()
        self.encoder_dim = encoder_dim
        self.embed_dim = embed_dim
        self.hidden_dim = hidden_dim
        self.vocab_size = vocab_size
        self.dropout = dropout

        self.attention = BahdanauAttention(encoder_dim, hidden_dim, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_dim)
        self.dropout_layer = nn.Dropout(dropout)

        self.gru = nn.GRU(embed_dim + encoder_dim, hidden_dim, batch_first=True)
        self.fc = nn.Linear(hidden_dim, vocab_size)
        self.init_hidden = nn.Linear(encoder_dim, hidden_dim)

    def forward(self, encoder_out, captions):
        """Teacher-forced decoding.

        Args:
            encoder_out: encoder features; flattened here to (B, num_pixels, encoder_dim).
            captions: (B, max_len) token ids; token ``t`` is the input at step ``t``.

        Returns:
            (B, max_len, vocab_size) unnormalised scores, one row per input position.
        """
        batch_size = captions.size(0)
        max_len = captions.size(1)

        # reshape (not view) so non-contiguous encoder outputs are also accepted.
        encoder_out = encoder_out.reshape(batch_size, -1, self.encoder_dim)  # (B, num_pixels, encoder_dim)
        embeddings = self.embedding(captions)                                # (B, max_len, embed_dim)
        hidden = self.init_hidden(encoder_out.mean(dim=1)).unsqueeze(0)     # (1, B, hidden_dim)

        # Collect per-step logits and stack once, instead of allocating a zero buffer
        # on the host and copying it to the device on every forward pass.
        step_logits = []
        for t in range(max_len):
            context, _ = self.attention(encoder_out, hidden.squeeze(0))              # (B, encoder_dim)
            input_t = torch.cat([embeddings[:, t, :], context], dim=1).unsqueeze(1)  # (B, 1, embed+encoder)
            output, hidden = self.gru(input_t, hidden)                               # (B, 1, hidden_dim)
            step_logits.append(self.fc(self.dropout_layer(output.squeeze(1))))       # (B, vocab_size)

        if not step_logits:
            # Zero-length captions: torch.stack cannot stack an empty list.
            return encoder_out.new_zeros((batch_size, 0, self.vocab_size))
        return torch.stack(step_logits, dim=1)

# **Model Wrapper**

In [None]:
class CaptioningModel(nn.Module):
    """Thin wrapper chaining an image encoder and a caption decoder.

    ``forward(images, captions)`` returns ``decoder(encoder(images), captions)``.
    """

    def __init__(self, encoder, decoder):
        super(CaptioningModel, self).__init__()
        self.encoder = encoder
        self.decoder = decoder

    def forward(self, images, captions):
        encoder_out = self.encoder(images)
        return self.decoder(encoder_out, captions)

# **Training Function**

# **Instantiate and Train**

In [None]:
# NOTE(review): the "Training Function" section has no code cell, so `train_model`
# was undefined on a fresh kernel (NameError). This implementation mirrors the batch
# format `(images, captions, lengths)` and the teacher-forcing call
# `model(images, captions[:, :-1])` used by `evaluate_model` below.
def train_model(model, train_dataset, val_dataset, word2idx, device,
                batch_size=32, epochs=10, patience=3, lr=1e-4):
    """Train ``model`` with teacher forcing and early stopping on validation loss.

    Args:
        model: captioning model called as ``model(images, captions[:, :-1])`` that
            returns per-step vocabulary scores of shape (B, T-1, vocab_size).
        train_dataset, val_dataset: datasets whose default-collated batches are
            ``(images, captions, lengths)`` with padded caption-id tensors
            (assumed from how ``evaluate_model`` consumes the test set - TODO confirm).
        word2idx: vocabulary; positions holding ``'<pad>'`` are ignored by the loss.
        device: device to train on.
        batch_size: mini-batch size for both loaders.
        epochs: maximum number of epochs.
        patience: stop after this many consecutive epochs without a validation-loss
            improvement.
        lr: Adam learning rate.

    Returns:
        ``model`` itself, with the weights from the best validation epoch loaded.
    """
    import copy

    pad_idx = word2idx["<pad>"]
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
    criterion = nn.CrossEntropyLoss(ignore_index=pad_idx)
    # Only optimise parameters left trainable (the encoder freezes most of its backbone).
    optimizer = torch.optim.Adam([p for p in model.parameters() if p.requires_grad], lr=lr)

    def run_epoch(loader, training):
        # One pass over `loader`; returns the mean batch loss.
        model.train(training)
        total_loss, num_batches = 0.0, 0
        with torch.set_grad_enabled(training):
            for images, captions, _lengths in loader:
                images, captions = images.to(device), captions.to(device)
                # Input is every token but the last; target is every token but the first.
                logits = model(images, captions[:, :-1])
                targets = captions[:, 1:]
                loss = criterion(logits.reshape(-1, logits.size(-1)), targets.reshape(-1))
                if training:
                    optimizer.zero_grad()
                    loss.backward()
                    optimizer.step()
                total_loss += loss.item()
                num_batches += 1
        return total_loss / max(num_batches, 1)

    best_val_loss = float("inf")
    best_state = copy.deepcopy(model.state_dict())
    epochs_without_improvement = 0

    for epoch in range(1, epochs + 1):
        train_loss = run_epoch(train_loader, training=True)
        val_loss = run_epoch(val_loader, training=False)
        print(f"Epoch {epoch}/{epochs} - train loss {train_loss:.4f} - val loss {val_loss:.4f}")

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            best_state = copy.deepcopy(model.state_dict())
            epochs_without_improvement = 0
        else:
            epochs_without_improvement += 1
            if epochs_without_improvement >= patience:
                print(f"Early stopping after epoch {epoch} (no improvement for {patience} epochs).")
                break

    model.load_state_dict(best_state)
    return model


# Hyperparameters
embed_dim = 256
hidden_dim = 512
attention_dim = 256
dropout = 0.5

encoder = EncoderEfficientNet(encoded_image_size=14, fine_tune=True)
decoder = DecoderGRUWithAttention(
    attention_dim=attention_dim,
    embed_dim=embed_dim,
    hidden_dim=hidden_dim,
    vocab_size=vocab_size,
    encoder_dim=1536,  # EfficientNet-B3 output channels
    dropout=dropout
)

model = CaptioningModel(encoder, decoder).to(device)

# Train
trained_model = train_model(
    model=model,
    train_dataset=train_dataset,
    val_dataset=val_dataset,
    word2idx=word2idx,
    device=device,
    batch_size=8,
    epochs=20,
    patience=3,
    lr=1e-4
)

# **Save the Model**

In [None]:
# Save the object the evaluation and caption cells actually use. `trained_model` can
# differ from `model` if `train_model` returns a separate (e.g. best-checkpoint) copy.
torch.save(trained_model.state_dict(), "../data/experiment_tumadhir_model.pth")
# Pickles the whole module: loading it needs these class definitions importable, and
# newer torch.load defaults (weights_only=True) require weights_only=False for it.
torch.save(trained_model, "../data/experiment_tumadhir_model_full.pth")

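# **BLEU Score Evaluation on the Test Set**

Note: `evaluate_model` scores teacher-forced predictions (the ground-truth previous word is fed at every step), so these BLEU scores are optimistic compared with free-running caption generation.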

In [None]:
from nltk.translate.bleu_score import corpus_bleu
from tqdm import tqdm

def evaluate_model(model, test_dataset, word2idx, device='cuda', batch_size=32):
    """Compute corpus BLEU-1..4 on ``test_dataset`` from teacher-forced predictions.

    NOTE: predictions are produced with teacher forcing (the ground-truth previous
    token is fed at each step), so scores are optimistic relative to free-running
    generation. Each caption is scored against itself as the single reference.

    Args:
        model: captioning model called as ``model(images, captions[:, :-1])``
            returning per-step vocabulary scores.
        test_dataset: dataset whose default-collated batches are
            ``(images, captions, lengths)``.
        word2idx: vocabulary containing ``'<pad>'``, ``'<start>'`` and ``'<end>'``.
        device: device to run evaluation on.
        batch_size: evaluation batch size.

    Returns:
        dict with keys ``'bleu1'``..``'bleu4'`` (also printed).
    """
    model.eval()
    pad_idx = word2idx['<pad>']
    start_idx = word2idx['<start>']
    end_idx = word2idx['<end>']

    def to_tokens(ids):
        # Keep tokens up to the first <end>, dropping <pad>/<start>. Whatever the model
        # emits after <end> (e.g. at padded positions) is not part of the caption.
        tokens = []
        for w in ids:
            if w == end_idx:
                break
            if w != pad_idx and w != start_idx:
                tokens.append(w)
        return tokens

    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    references = []
    hypotheses = []

    with torch.no_grad():
        for images, captions, _lengths in tqdm(test_loader, desc="Evaluating on Test Set"):
            images, captions = images.to(device), captions.to(device)
            outputs = model(images, captions[:, :-1])  # Teacher forcing for consistency
            preds = torch.argmax(outputs, dim=2)

            for ref, pred in zip(captions.tolist(), preds.tolist()):
                references.append([to_tokens(ref)])
                hypotheses.append(to_tokens(pred))

    # Uniform n-gram weights; each tuple sums to 1 (BLEU-3 previously used 0.33 * 3).
    weight_sets = {
        'bleu1': (1, 0, 0, 0),
        'bleu2': (0.5, 0.5, 0, 0),
        'bleu3': (1 / 3, 1 / 3, 1 / 3, 0),
        'bleu4': (0.25, 0.25, 0.25, 0.25),
    }
    scores = {name: corpus_bleu(references, hypotheses, weights=w) for name, w in weight_sets.items()}

    print("\nFinal BLEU Scores on Test Set:")
    for n in range(1, 5):
        print(f"BLEU-{n} = {scores[f'bleu{n}']:.4f}")

    return scores

# **Run Evaluation**

In [None]:
bleu_scores = evaluate_model(
    trained_model,
    test_dataset,
    word2idx,
    device=device,
    batch_size=32,
)

# **Caption Generation for Sample Images**

In [None]:
from PIL import Image

def generate_caption(model, image_path, word2idx, idx2word, transform, max_len=20):
    """Greedily decode a caption for a single image.

    Args:
        model: captioning model exposing ``encoder`` and a ``decoder`` with
            ``encoder_dim``, ``init_hidden``, ``embedding``, ``attention``, ``gru``
            and ``fc`` attributes (all used below).
        image_path: path to an image file readable by PIL.
        word2idx: vocabulary containing ``'<start>'`` and ``'<end>'``.
        idx2word: reverse vocabulary; keys may be strings (as loaded from JSON) or ints.
        transform: callable turning a PIL image into an image tensor for the encoder.
        max_len: maximum number of decoding steps.

    Returns:
        The generated caption as a space-separated string (without <start>/<end>);
        indices missing from ``idx2word`` become ``"<unk>"``.
    """
    model.eval()
    # Run on the device the model's weights live on instead of relying on a global.
    first_param = next(model.parameters(), None)
    model_device = first_param.device if first_param is not None else torch.device("cpu")

    # Context manager closes the underlying file handle.
    with Image.open(image_path) as img:
        image = img.convert("RGB")
    image_tensor = transform(image).unsqueeze(0).to(model_device)

    start_idx = word2idx["<start>"]
    end_idx = word2idx["<end>"]
    decoder = model.decoder
    caption = []

    with torch.no_grad():
        encoder_out = model.encoder(image_tensor)
        encoder_out = encoder_out.reshape(1, -1, decoder.encoder_dim)

        hidden = decoder.init_hidden(encoder_out.mean(dim=1)).unsqueeze(0)
        word_input = torch.tensor([[start_idx]], device=model_device)

        for _ in range(max_len):
            embedding = decoder.embedding(word_input).squeeze(1)
            context, _ = decoder.attention(encoder_out, hidden.squeeze(0))
            input_t = torch.cat([embedding, context], dim=1).unsqueeze(1)
            output, hidden = decoder.gru(input_t, hidden)
            logits = decoder.fc(output.squeeze(1))

            word_idx = logits.argmax(dim=1).item()
            if word_idx == end_idx:
                break

            # JSON-loaded mappings have string keys; also accept int-keyed dicts.
            word = idx2word.get(str(word_idx))
            if word is None:
                word = idx2word.get(word_idx, "<unk>")
            caption.append(word)
            word_input = torch.tensor([[word_idx]], device=model_device)

    return " ".join(caption)

# **Use the Function to Generate a Caption**

In [None]:
# Reverse vocabulary (index -> word). JSON object keys are strings, which is why
# generate_caption looks words up by str(index).
with open("../data/processed/idx2word.json") as idx2word_file:
    idx2word = json.load(idx2word_file)

# Same preprocessing as the validation split.
transform = get_transforms("val")

# A single image from the Flickr8k image folder.
sample_img = "../data/Flicker8k_Dataset/1000268201_693b08cb0e.jpg"

caption = generate_caption(
    model=trained_model,
    image_path=sample_img,
    word2idx=word2idx,
    idx2word=idx2word,
    transform=transform,
)
print("Generated Caption:", caption)