In [None]:
import csv
import json
import os
from collections import Counter

import matplotlib.pyplot as plt
import nltk
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.transforms as transforms
import tqdm
from jiwer import wer
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from PIL import Image
from rouge_score import rouge_scorer
from torch.nn.utils.rnn import pad_sequence
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from torch.utils.data import Dataset
from torch.utils.data import DataLoader

# Order matters here: the second Vocabulary import shadows the first.
from vocabulary_class import Vocabulary
from P2_Resnet_LSTM_flikr_8k.vocabulary_class import Vocabulary
from P2_Resnet_LSTM_flikr_8k.flickr_dataset import FlickrDataset
from model import AttentionEncoderViT, AttentionDecoderRNN, TextEncoder

In [None]:
# Run on the GPU when PyTorch can see one, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
# Project directories, relative to the notebook's working directory.
# Forward slashes are used everywhere: Python accepts them on Windows as well
# as on POSIX systems, whereas the previous "..\\..\\" form only resolved on
# Windows and was mixed with "/" in the paths derived below.
datasets = "../../datasets"
models = "../../models"

# NOTE(review): caption_model is loaded here but not referenced by any other
# cell visible in this notebook — confirm it is still needed.
caption_model = torch.load(f"{models}/caption_features_flickr8k.pt")

IMAGES_PATH = f"{datasets}/flickr8k/images"  # Directory with training images
CAPTIONS_PATH = f"{datasets}/flickr8k/captions.txt"  # Caption file
TEST_IMAGES_PATH = "../test_images"  # Directory with test images

In [None]:
# NOTE(review): these module-level names look like leftovers. build_vocab
# creates its own local `counter` and `tokens`, so nothing it computes is
# stored in these two objects.
tokens = []
counter = Counter()

import csv

def build_vocab(json_path, threshold=5, limit=None):
    """Build a vocabulary and an image -> captions map from a captions file.

    Args:
        json_path: Path of the captions file. Despite the parameter name the
            file is parsed as CSV: one header row, then rows whose first
            field is the image name and whose remaining content is the
            caption.
        threshold: Minimum number of occurrences a token needs (counted over
            the lower-cased, NLTK-tokenized captions) to be added to the
            vocabulary.
        limit: Optional maximum number of caption rows to read. ``None`` or
            ``0`` reads the whole file.

    Returns:
        A ``(vocab, image_captions)`` tuple. ``vocab`` is a ``Vocabulary``
        that has received every sufficiently frequent token through
        ``add_word``; ``image_captions`` maps each image name to the list of
        its captions as they appear in the file (original casing).
    """
    counter = Counter()
    image_captions = {}
    count = 0

    # Read the file that was actually requested; the argument used to be
    # ignored in favour of the global CAPTIONS_PATH.
    with open(json_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header (image,caption); tolerate an empty file

        for row in reader:
            if len(row) < 2:
                continue

            # A caption containing an unquoted comma is split into several
            # fields by the CSV reader; stitch it back together instead of
            # failing on a two-name unpack.
            img_name = row[0]
            caption = ",".join(row[1:])
            image_captions.setdefault(img_name, []).append(caption)

            counter.update(nltk.tokenize.word_tokenize(caption.lower()))
            count += 1
            if limit and count >= limit:
                break

    vocab = Vocabulary()
    for word, cnt in counter.items():
        if cnt >= threshold:
            vocab.add_word(word)

    return vocab, image_captions

In [None]:
# Build the vocabulary (tokens seen at least 5 times) together with the
# image name -> captions mapping read from the captions file.
vocab, image_captions = build_vocab(CAPTIONS_PATH, threshold=5)
print("Total vocabulary size:", len(vocab))

In [None]:
# Number of distinct image names found in the captions file.
len(image_captions)

In [None]:
# Preprocessing applied to every image before it is fed to the encoder.
transform = transforms.Compose([
    # Force a fixed 224x224 input size (aspect ratio is not preserved).
    transforms.Resize((224, 224)),
    transforms.ToTensor(),  # PIL image -> float tensor
    # Per-channel mean / std; these are the values commonly used with
    # ImageNet-pretrained backbones.
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])

In [None]:
def collate_fn(batch):
    """Collate (image, caption, image_id) samples into one padded batch.

    The incoming list is sorted in place by caption length, longest first.
    Captions are right-padded with zeros up to the longest caption.

    Args:
        batch: List of ``(image, caption, image_id)`` samples, where
            ``image`` and ``caption`` are tensors.

    Returns:
        ``(images, padded_caps, lengths, image_ids)``: the stacked image
        tensor, a ``[batch, max_len]`` long tensor of padded captions, the
        list of unpadded caption lengths and the tuple of image ids, all in
        the sorted order.
    """
    # In-place sort on purpose: callers see the same reordering as before.
    batch.sort(key=lambda sample: len(sample[1]), reverse=True)
    images, captions, image_ids = zip(*batch)

    stacked_images = torch.stack(images, 0)

    lengths = [len(caption) for caption in captions]
    padded_caps = torch.zeros((len(captions), max(lengths)), dtype=torch.long)
    for row, (caption, n_tokens) in enumerate(zip(captions, lengths)):
        padded_caps[row, :n_tokens] = caption[:n_tokens]

    return stacked_images, padded_caps, lengths, image_ids

In [None]:
# Dataset of Flickr8k samples; max_samples=None is passed so no sample cap
# is requested.
train_dataset = FlickrDataset(
    root=IMAGES_PATH,
    captions_path=CAPTIONS_PATH,
    vocab=vocab,
    transform=transform,
    max_samples=None
)

# Batches are produced by collate_fn, i.e. every batch is the 4-tuple
# (images, padded_caps, lengths, image_ids).
train_loader = DataLoader(
    train_dataset,
    batch_size=16,      # small batch size chosen by the author
    shuffle=True,
    num_workers=0,     # load in the main process (author's note: required on Windows)
    collate_fn=collate_fn,
    pin_memory=False   # pinned memory disabled (author's note: for debugging)
)

In [None]:
# Sanity-check one batch. collate_fn yields four fields, and `lengths` is a
# plain Python list (no .shape), so its size is reported with len().
images, captions, lengths, image_ids = next(iter(train_loader))
print(images.shape, captions.shape, len(lengths))

In [None]:
# Number of samples the dataset exposes.
print(len(train_dataset))

In [None]:
# Inspect a single sample. collate_fn unpacks three fields per sample
# (image, caption, image id), so a two-name unpack fails; any fields after
# the caption are collected into `_extra`.
image, caption, *_extra = train_dataset[0]

print(type(image))
print(image.shape)        # after transform
print(caption)
print(len(caption))

In [None]:
# Image encoder and attention decoder, both placed on the selected device.
encoder = AttentionEncoderViT().to(device)

decoder = AttentionDecoderRNN(
    embed_size=256,
    hidden_size=512,
    vocab_size=len(vocab),
    encoder_dim=768
).to(device)

# Token-level cross entropy; positions holding the <pad> index are excluded
# from the loss and the targets are label-smoothed.
criterion = nn.CrossEntropyLoss(
    ignore_index=vocab.word2idx["<pad>"],
    label_smoothing=0.1
)

# NOTE(review): `params` is not used by any cell visible in this notebook;
# the optimizer below is built from the two parameter groups directly.
params = list(encoder.parameters()) + list(decoder.parameters())
# Separate learning rates: a small one for the encoder, a larger one for the
# decoder.
optimizer = torch.optim.AdamW([
    {"params": encoder.parameters(), "lr": 5e-5},
    {"params": decoder.parameters(), "lr": 1e-3},
], weight_decay=1e-4)

# Gradient scaler used together with autocast in the training loop.
scaler = torch.cuda.amp.GradScaler()

In [None]:
# Mean training loss of every epoch, in order.
train_losses = []
total_epochs = 15

# Make the parameters of the second-to-last ViT block trainable.
# NOTE(review): which encoder parameters are frozen by default is decided in
# AttentionEncoderViT, which is not visible here — confirm that [-2] (and not
# the last block) is the intended choice.
for p in encoder.vit.blocks[-2].parameters():
    p.requires_grad = True

for epoch in range(total_epochs):
    encoder.train()
    decoder.train()

    # Teacher forcing starts at 0.9, decays by 5% per epoch and never drops
    # below 0.1.
    teacher_forcing_ratio = max(0.9 * (0.95 ** epoch), 0.1)

    total_train_loss = 0

    # collate_fn yields (images, captions, lengths, image_ids); the ids are
    # not needed for training. Unpacking only three names raised ValueError.
    for images, captions, lengths, _image_ids in tqdm.tqdm(train_loader):
        images = images.to(device, non_blocking=True)
        captions = captions.to(device, non_blocking=True)

        optimizer.zero_grad(set_to_none=True)
        with torch.cuda.amp.autocast():
            features = encoder(images)

            outputs = decoder(
                features,
                captions,
                lengths=lengths,
                teacher_forcing_ratio=teacher_forcing_ratio
            )
            # Targets are the captions shifted left by one token, cut to the
            # number of steps the decoder actually produced.
            targets = captions[:, 1:]
            targets = targets[:, :outputs.size(1)]

            loss = criterion(
                outputs.reshape(-1, len(vocab)),
                targets.reshape(-1)
            )

        # Scaled backward / step / update, as required when using GradScaler.
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()
        total_train_loss += loss.item()

    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)

    print(
        f"Epoch [{epoch+1}/{total_epochs}] | "
        f"Loss: {avg_train_loss:.4f} | "
        f"Teacher forcing: {teacher_forcing_ratio:.3f}"
    )

    torch.cuda.empty_cache()


In [None]:
def generate_caption_beam(
    image,
    encoder,
    decoder,
    vocab,
    beam_size=5,
    max_len=20,
    length_penalty=0.7,
):
    """Generate a caption for one image with beam search.

    Both models are switched to eval mode (and left in it). Decoding runs
    without gradient tracking, on the device that holds the encoder's
    parameters.

    Each hypothesis keeps its *raw* cumulative log-probability. Length
    normalisation (``log_prob_sum / len(seq) ** length_penalty``) is applied
    only when hypotheses are ranked, so the penalty is never compounded
    across decoding steps.

    Args:
        image: A single, already transformed image tensor without a batch
            dimension.
        encoder: Module mapping a batch of images to per-patch features;
            dimension 1 of its output is treated as the patch axis.
        decoder: Module exposing ``embed``, ``attention(encoder_out, h,
            coverage) -> (context, alpha, coverage)``, ``lstm(input, (h, c))
            -> (h, c)`` with a ``hidden_size`` attribute, and ``fc``.
        vocab: Object with ``word2idx`` (containing ``"<start>"`` and
            ``"<end>"``) and ``idx2word``.
        beam_size: Number of hypotheses kept after every step.
        max_len: Maximum number of decoding steps.
        length_penalty: Exponent of the length normalisation used for
            ranking.

    Returns:
        The best caption as a space-separated string, without the start and
        end markers.
    """
    encoder.eval()
    decoder.eval()

    # Follow the encoder's own device instead of relying on a notebook-level
    # global; fall back to the image's device for a parameter-free encoder.
    try:
        run_device = next(encoder.parameters()).device
    except StopIteration:
        run_device = image.device

    start = vocab.word2idx["<start>"]
    end = vocab.word2idx["<end>"]

    def _ranking_score(seq, log_prob_sum):
        # Length-normalised score, used for comparison only.
        return log_prob_sum / (len(seq) ** length_penalty)

    with torch.no_grad():
        # Batch of one image -> features; presumably [1, num_patches, dim].
        encoder_out = encoder(image.unsqueeze(0).to(run_device))
        num_patches = encoder_out.size(1)
        coverage0 = torch.zeros(1, num_patches, device=run_device)

        # Each beam: (token ids, raw log-prob sum, h, c, attention coverage).
        beams = [([start], 0.0, None, None, coverage0)]
        completed = []

        for _ in range(max_len):
            candidates = []

            for seq, log_prob_sum, h, c, coverage in beams:
                if seq[-1] == end:
                    completed.append((seq, _ranking_score(seq, log_prob_sum)))
                    continue

                inputs = torch.tensor(
                    [[seq[-1]]], dtype=torch.long, device=run_device
                )
                emb = decoder.embed(inputs).squeeze(1)

                if h is None:
                    # First step: start from an all-zero recurrent state.
                    h = torch.zeros(
                        1, decoder.lstm.hidden_size, device=run_device
                    )
                    c = torch.zeros_like(h)

                context, _alpha, coverage = decoder.attention(
                    encoder_out, h, coverage
                )
                h, c = decoder.lstm(torch.cat((emb, context), dim=1), (h, c))

                log_probs = torch.log_softmax(decoder.fc(h), dim=1)
                # Never ask topk for more entries than the vocabulary has.
                top = torch.topk(log_probs, min(beam_size, log_probs.size(1)))

                for token, log_prob in zip(
                    top.indices[0].tolist(), top.values[0].tolist()
                ):
                    candidates.append(
                        (seq + [token], log_prob_sum + log_prob, h, c, coverage)
                    )

            beams = sorted(
                candidates,
                key=lambda beam: _ranking_score(beam[0], beam[1]),
                reverse=True,
            )[:beam_size]
            if not beams:
                # Every hypothesis has finished; nothing left to expand.
                break

    # Hypotheses still open after max_len steps compete with finished ones.
    completed += [
        (seq, _ranking_score(seq, log_prob_sum))
        for seq, log_prob_sum, _, _, _ in beams
    ]
    best_seq = max(completed, key=lambda item: item[1])[0]

    return " ".join(
        vocab.idx2word[i] for i in best_seq if i not in (start, end)
    )


In [None]:
# Persist the trained weights (state dicts only) ...
torch.save(encoder.state_dict(), f"{models}/encoder-Vit-LSTM-Attention.pth")
torch.save(decoder.state_dict(), f"{models}/decoder-VLA.pth")
# ... and the vocabulary object itself. torch.save pickles the whole object,
# which is why it is read back with weights_only=False in the loading cell.
torch.save(vocab, f"{models}/vocab-VLA.pkl")

In [None]:
# SECURITY NOTE: weights_only=False unpickles arbitrary Python objects, which
# can execute code. Only load a vocab file produced by this notebook.
vocab = torch.load(f"{models}/vocab-VLA.pkl", weights_only=False)
# Rebuild both models with the same hyper-parameters used for training, then
# restore their weights.
encoder = AttentionEncoderViT().to(device)

decoder = AttentionDecoderRNN(
    embed_size=256,
    hidden_size=512,
    vocab_size=len(vocab),
    encoder_dim=768
).to(device)

encoder.load_state_dict(torch.load(f"{models}/encoder-Vit-LSTM-Attention.pth", weights_only=True, map_location=device))
# NOTE(review): strict=False silently ignores missing or unexpected keys, so a
# checkpoint / architecture mismatch would go unnoticed here — confirm this is
# intentional.
decoder.load_state_dict(torch.load(f"{models}/decoder-VLA.pth", weights_only=True, map_location=device), strict=False)

In [None]:
from rouge_score import rouge_scorer

def tokens_to_string(tokens):
    """Return a token list joined by single spaces; pass anything else through unchanged."""
    return " ".join(tokens) if isinstance(tokens, list) else tokens

def compute_rouge_l(references, hypotheses):
    """Average best-reference ROUGE-L F-measure over a corpus.

    Args:
        references: One entry per sample; each entry is a list of reference
            captions (token lists or strings).
        hypotheses: One predicted caption per sample (token list or string).

    Returns:
        The mean, over all samples, of the highest ROUGE-L F-measure obtained
        against any of the sample's references. ``0.0`` when there is nothing
        to score (previously a ZeroDivisionError).

    Raises:
        ValueError: If a sample has no reference captions.
    """
    # Nothing to score: return early, before building the scorer.
    if not references or not hypotheses:
        return 0.0

    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

    scores = []
    for refs, hyp in zip(references, hypotheses):
        hyp_str = tokens_to_string(hyp)
        # Keep only the score against the closest reference.
        scores.append(
            max(
                scorer.score(tokens_to_string(ref), hyp_str)['rougeL'].fmeasure
                for ref in refs
            )
        )

    return sum(scores) / len(scores)

from jiwer import wer
def compute_wer(references, hypotheses):
    """Average word error rate, using the best matching reference per sample.

    Captions may be given as token lists or as strings. Token lists are
    joined into a single sentence before scoring: ``jiwer.wer`` interprets a
    list as a list of *sentences*, so passing token lists straight through
    (as the previous version did, after building normalised copies it never
    used) scores every token as its own sentence.

    Args:
        references: One entry per sample; each entry is a list of reference
            captions.
        hypotheses: One predicted caption per sample.

    Returns:
        The mean over all samples of the lowest WER against any of the
        sample's references. ``0.0`` when there is nothing to score.

    Raises:
        ValueError: If a sample has no reference captions.
    """
    def _as_sentence(caption):
        return " ".join(caption) if isinstance(caption, list) else caption

    wers = []
    for refs, hyp in zip(references, hypotheses):
        hyp_text = _as_sentence(hyp)
        # Best match: the reference with the lowest error rate.
        wers.append(min(wer(_as_sentence(ref), hyp_text) for ref in refs))

    if not wers:
        return 0.0
    return sum(wers) / len(wers)

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

def compute_bleu4(references, hypotheses):
    """Corpus-level BLEU-4 with NLTK smoothing method 4.

    Args:
        references: One entry per sample; each entry is a list of reference
            captions (token lists, or strings that are split on whitespace).
        hypotheses: One predicted caption per sample (token list, or a string
            that is split on whitespace).

    Returns:
        The corpus BLEU score computed with uniform 1- to 4-gram weights.
    """
    def _as_tokens(caption):
        # Lists are taken as already tokenized; strings are whitespace-split.
        return caption if isinstance(caption, list) else caption.split()

    ref_tokens = [[_as_tokens(ref) for ref in refs] for refs in references]
    hyp_tokens = [_as_tokens(hyp) for hyp in hypotheses]

    return corpus_bleu(
        ref_tokens,
        hyp_tokens,
        weights=(0.25, 0.25, 0.25, 0.25),
        smoothing_function=SmoothingFunction().method4,
    )

In [None]:
# Caption-quality evaluation (BLEU-4, ROUGE-L, WER)
def wer_best_reference(hyp_tokens, refs_tokens):
    """Return the lowest WER between a tokenized hypothesis and its tokenized references.

    Both sides are joined into plain sentences before being scored.
    """
    hyp_sentence = " ".join(hyp_tokens)
    per_reference = [wer(" ".join(ref), hyp_sentence) for ref in refs_tokens]
    return min(per_reference)

# Parallel lists: references[k] holds the tokenized reference captions of the
# k-th evaluated sample, hypotheses[k] the tokenized generated caption.
references = []
hypotheses = []

encoder.eval()
decoder.eval()

# Build reference captions per image: image name -> list of token lists.
image_to_refs = {}

for img_name, caps in train_dataset.image_captions.items():
    refs = []
    for cap in caps:
        tokens = nltk.tokenize.word_tokenize(cap.lower())
        refs.append(tokens)
    image_to_refs[img_name] = refs

# NOTE(review): the scores below are computed on train_loader, i.e. on the
# data the model was trained on — they are not held-out metrics.
with torch.no_grad():
    for images, captions, lengths, image_ids in tqdm.tqdm(train_loader):
        images = images.to(device)

        # Beam search decodes one image at a time.
        for i in range(images.size(0)):
            hyp = generate_caption_beam(
                images[i],
                encoder,
                decoder,
                vocab,
                beam_size=5,
                max_len=20
            ).split()

            img_name = image_ids[i]
            refs = image_to_refs[img_name]   # every reference caption stored for this image

            references.append(refs)
            hypotheses.append(hyp)

bleu4 = compute_bleu4(references, hypotheses)
print(f"blue 4 scores:: {bleu4:.4f}")

rougue_score = compute_rouge_l(references, hypotheses)
print(f"Rouge Score: {rougue_score:.4f}")

# Per-sample WER against the closest reference, then averaged.
wers = []

for hyp, refs in zip(hypotheses, references):
    wers.append(wer_best_reference(hyp, refs))

final_wer = sum(wers) / len(wers)
print(f"WER Score: {final_wer:.4f}")

In [None]:
def recall_at_k(similarity, k):
    """Fraction of rows whose own index is among that row's top-k columns.

    Args:
        similarity: Similarity matrix; row ``i`` is expected to match
            column ``i``.
        k: Number of highest-scoring columns inspected per row.

    Returns:
        The recall as a Python float.
    """
    _, top_columns = similarity.topk(k, dim=1)
    row_ids = torch.arange(
        similarity.size(0), device=similarity.device
    ).unsqueeze(1)
    hits = top_columns.eq(row_ids).any(dim=1)
    return hits.float().mean().item()

def extract_image_embeddings(dataloader, encoder, device):
    """Encode every image served by ``dataloader`` into one unit-length vector.

    Args:
        dataloader: Iterable of batches whose first field is the image
            tensor. Any further fields (captions, lengths, image ids, ...)
            are ignored, so both 3- and 4-field batches are accepted.
        encoder: Module mapping an image batch to per-patch features;
            presumably shaped [batch, patches, dim].
        device: Device the images are moved to before encoding.

    Returns:
        A 2-D tensor with one L2-normalised row per image, in loader order.
        The encoder is switched to eval mode and left in it.
    """
    encoder.eval()
    image_embeddings = []

    with torch.no_grad():
        for batch in dataloader:
            # Index instead of unpacking: the notebook's collate_fn yields
            # four fields, which broke the previous three-name unpack.
            images = batch[0].to(device)
            patch_emb = encoder(images)
            feats = patch_emb.mean(dim=1)  # average over the patch axis
            # Fully qualified call: the notebook never defined the `F` alias.
            feats = torch.nn.functional.normalize(feats, dim=1)
            image_embeddings.append(feats)

    return torch.cat(image_embeddings, dim=0)

def extract_text_embeddings(dataloader, text_encoder, device):
    """Encode every caption served by ``dataloader`` with ``text_encoder``.

    Args:
        dataloader: Iterable of batches whose second field is the padded
            caption tensor and whose third field is the list of caption
            lengths. Any further fields are ignored.
        text_encoder: Module called as ``text_encoder(captions, lengths)``.
        device: Device the captions are moved to. The lengths tensor is
            created on the CPU and passed on as is.

    Returns:
        The per-batch outputs concatenated along dimension 0, on the CPU.
        The text encoder is switched to eval mode and left in it.
    """
    text_encoder.eval()
    all_sent = []

    with torch.no_grad():
        for batch in dataloader:
            # Index instead of unpacking: the notebook's collate_fn yields
            # four fields, which broke the previous three-name unpack.
            captions = batch[1].to(device)
            lengths = torch.tensor(batch[2])

            sent = text_encoder(captions, lengths)
            all_sent.append(sent.cpu())

    return torch.cat(all_sent, dim=0)

def _retrieval_device():
    """Return the GPU when one is available, otherwise the CPU."""
    return torch.device("cuda" if torch.cuda.is_available() else "cpu")


def _ground_truth_ranks(query_emb, gallery_emb, batch_size, mask_self=False):
    """Rank of gallery item ``i`` in the similarity ordering of query ``i``.

    Queries are processed in chunks of ``batch_size`` rows. With
    ``mask_self`` the similarity between query ``i`` and gallery item ``i``
    is overwritten with -1 before ranking.
    """
    n_queries = query_emb.size(0)
    if n_queries > gallery_emb.size(0):
        raise ValueError(
            "index-aligned retrieval needs a gallery entry for every query"
        )

    ranks = []
    for first in range(0, n_queries, batch_size):
        chunk = query_emb[first:first + batch_size]
        sim = chunk @ gallery_emb.T
        gt = torch.arange(first, first + chunk.size(0), device=sim.device)

        if mask_self:
            # The self-similarity of chunk row j sits in column first + j,
            # not on the diagonal of this [B, N] slice.
            sim[torch.arange(chunk.size(0), device=sim.device), gt] = -1

        order = sim.argsort(dim=1, descending=True)
        # Column of the single match in every row == rank of the ground truth.
        ranks.append((order == gt.unsqueeze(1)).nonzero(as_tuple=True)[1])

    if not ranks:
        return torch.empty(0, dtype=torch.long)
    return torch.cat(ranks)


def _recall_summary(ranks):
    """Turn a tensor of 0-based ranks into R@1 / R@5 / R@10."""
    return {
        "R@1":  (ranks < 1).float().mean().item(),
        "R@5":  (ranks < 5).float().mean().item(),
        "R@10": (ranks < 10).float().mean().item()
    }


def image_to_text_retrieval(image_emb, text_emb, batch_size=512):
    """Recall@{1,5,10} for retrieving text ``i`` with image ``i`` as query.

    Similarity is the dot product of the embeddings. The computation runs on
    the GPU when one is available and on the CPU otherwise.

    NOTE(review): the ground truth is purely positional (row i of both
    inputs must describe the same sample) — confirm the two embedding
    matrices were extracted in the same order.

    Raises:
        ValueError: If there are more images than texts.
    """
    target = _retrieval_device()
    ranks = _ground_truth_ranks(
        image_emb.to(target), text_emb.to(target), batch_size
    )
    return _recall_summary(ranks)


def text_to_image_retrieval(image_emb, text_emb, batch_size=512):
    """Recall@{1,5,10} for retrieving image ``i`` with text ``i`` as query.

    Same conventions as ``image_to_text_retrieval`` with the roles swapped.

    Raises:
        ValueError: If there are more texts than images.
    """
    target = _retrieval_device()
    ranks = _ground_truth_ranks(
        text_emb.to(target), image_emb.to(target), batch_size
    )
    return _recall_summary(ranks)


def text_to_text_retrieval(text_emb, batch_size=512):
    """Recall@{1,5,10} of each text against all texts, self-similarity masked.

    The self-similarity of every query is set to -1 before ranking. The
    previous implementation called ``fill_diagonal_`` on each [B, N] slice,
    which masks the wrong entries for every batch after the first.

    NOTE(review): the ground truth looked up after masking is still the
    query's own index, so the reported recall measures where the masked
    entry lands. A meaningful text-to-text metric needs real positives
    (e.g. other captions of the same image) — confirm the intended
    definition.
    """
    text_emb = text_emb.to(_retrieval_device())
    ranks = _ground_truth_ranks(text_emb, text_emb, batch_size, mask_self=True)
    return _recall_summary(ranks)

def image_to_image_retrieval(image_emb):
    """Recall@{1,5,10} of every image against all images, self-similarity masked.

    The full similarity matrix is built in one go and its diagonal is set
    to -1 before ``recall_at_k`` is evaluated.

    NOTE(review): ``recall_at_k`` counts a hit when a row's own index is in
    its top-k, which is exactly the entry masked here — confirm the intended
    ground truth for this metric.
    """
    sim = image_emb @ image_emb.t()
    sim.fill_diagonal_(-1)  # remove self-matching

    cutoffs = (("R@1", 1), ("R@5", 5), ("R@10", 10))
    return {label: recall_at_k(sim, k) for label, k in cutoffs}

In [None]:
# One normalised embedding per sample served by train_loader.
# NOTE(review): train_loader shuffles, so row order differs between passes.
image_emb = extract_image_embeddings(train_loader, encoder, device)

In [None]:
# NOTE(review): this TextEncoder is freshly constructed and no weights are
# loaded into it anywhere in the visible notebook, so the text embeddings
# below come from an untrained model — confirm.
text_encoder = TextEncoder(
    vocab_size=len(vocab),
    embed_size=256,
    hidden_size=512
).to(device)

# Second, independently shuffled pass over train_loader.
text_emb = extract_text_embeddings(train_loader, text_encoder, device)

In [None]:
# Both matrices have one row per loader sample; the widths depend on the
# encoders. NOTE(review): the dot products in the retrieval functions need
# the two widths to be equal — confirm.
print(image_emb.shape)  # [num_samples, image feature width]
print(text_emb.shape)   # [num_samples, text feature width]

In [None]:
# Recall@{1,5,10} for the four retrieval directions.
print("Image → Text:", image_to_text_retrieval(image_emb, text_emb))
print("Text → Image:", text_to_image_retrieval(image_emb, text_emb))
print("Text → Text :", text_to_text_retrieval(text_emb))
print("Image → Image:", image_to_image_retrieval(image_emb))

In [None]:
# Caption up to 20 images from the test directory.
# Sorting makes the selection reproducible; os.listdir returns entries in
# arbitrary order.
for file_name in sorted(os.listdir(TEST_IMAGES_PATH))[:20]:
    # Load image
    img_path = os.path.join(TEST_IMAGES_PATH, file_name)
    pil_image = Image.open(img_path).convert("RGB")

    # Display the original picture. The transformed tensor is normalised with
    # a per-channel mean/std, so its values fall outside the range imshow can
    # render and the colours would be clipped.
    plt.imshow(pil_image)
    plt.axis('off')
    plt.show()

    image = transform(pil_image)
    caption = generate_caption_beam(image, encoder, decoder, vocab)
    print(f"Image: {file_name}\nCaption: {caption}\n")
