In [None]:
import torch

# Target device for the model and tensors in later cells: GPU when PyTorch sees one.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
import os

# Data/model roots relative to this notebook. os.path.join uses the platform's
# separator, so the same cell works on Windows and POSIX (the previous literal
# "..\\..\\datasets\\" paths only resolved on Windows).
datasets = os.path.join("..", "..", "datasets")
models = os.path.join("..", "..", "models")

# NOTE(review): `caption_model` is not referenced by any later cell shown here; it is
# kept so the notebook namespace stays the same. Confirm whether it is still needed.
caption_model = torch.load(os.path.join(models, "caption_features_flickr8k.pt"))

IMAGES_PATH = os.path.join(datasets, "flickr8k", "images")  # directory with the image files
CAPTIONS_PATH = os.path.join(datasets, "flickr8k", "captions.txt")  # CSV with an image,caption header
TEST_IMAGES_PATH = os.path.join("..", "test_images")  # directory with test images (unused below)

In [None]:
import csv
import json
from collections import Counter

import nltk
import tqdm

from vocabulary_class import Vocabulary

# Fetch nltk's 'punkt_tab' resource (presumably what nltk.tokenize.word_tokenize needs).
nltk.download('punkt_tab')

# Module-level placeholders; build_vocab and FlickrDataset use their own locals instead.
counter = Counter()
tokens = []

def build_vocab(json_path, threshold=5, limit=None):
    """Build a Vocabulary from a captions CSV and group the raw captions by image.

    Args:
        json_path: path of the captions file to read. Despite the name, it is parsed as
            a CSV whose first row is an ``image,caption`` header.
        threshold: minimum number of occurrences for a token to be added to the vocabulary.
        limit: if truthy, stop after this many caption rows have been processed.

    Returns:
        ``(vocab, image_captions)``. ``vocab`` is a ``Vocabulary`` containing every
        lower-cased nltk token seen at least ``threshold`` times. ``image_captions`` maps
        each image file name to its captions (original case), in file order.
    """
    counter = Counter()
    image_captions = {}
    count = 0
    # Read the file the caller passed in. Previously the global CAPTIONS_PATH was
    # always opened and this argument was silently ignored.
    with open(json_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header: image,caption (tolerates an empty file)

        for row in reader:
            if len(row) < 2:
                continue  # blank or malformed line
            img_name = row[0]
            # Re-join a caption that contained unquoted commas rather than crashing
            # on tuple unpacking.
            caption = ",".join(row[1:])
            image_captions.setdefault(img_name, []).append(caption)

            counter.update(nltk.tokenize.word_tokenize(caption.lower()))
            count += 1
            if limit and count >= limit:
                break

    vocab = Vocabulary()
    for word, cnt in counter.items():
        if cnt >= threshold:
            vocab.add_word(word)

    return vocab, image_captions

In [None]:
# Vocabulary over all caption rows (tokens seen >= 5 times) plus captions grouped by image.
vocab, image_captions = build_vocab(json_path=CAPTIONS_PATH, threshold=5)
print("Total vocabulary size:", len(vocab))

In [None]:
import torch
from torch.utils.data import Dataset
from PIL import Image
import nltk
import os
import random

# Dataset class updated to handle the Flickr8k dataset format (image,caption CSV rows).


import torch
from torch.utils.data import Dataset
from PIL import Image
import nltk
import os
import random
import csv

class FlickrDataset(Dataset):
    """Flickr8k caption dataset with one sample per (image, caption) row of the CSV.

    Args:
        root: directory containing the image files named in the CSV.
        captions_path: CSV file with a header row followed by ``image,caption`` rows.
        vocab: object exposing ``word2idx`` with ``<start>``, ``<end>`` and ``<unk>`` keys.
        transform: optional callable applied to the loaded RGB image.
        max_samples: if truthy, keep only the first ``max_samples`` rows.

    Each item is ``(image, caption_tensor, img_name)``. ``caption_tensor`` is a 1-D
    int64 tensor ``[<start>, token ids..., <end>]``.
    """

    def __init__(self, root, captions_path, vocab, transform=None, max_samples=None):
        self.root = root
        self.vocab = vocab
        self.transform = transform

        self.samples = []  # (img_name, caption)

        with open(captions_path, "r", encoding="utf-8", newline="") as f:
            reader = csv.reader(f)
            next(reader, None)  # skip the header; tolerate an empty file
            for row in reader:
                # Skip blank/malformed rows (as build_vocab does) instead of failing
                # on tuple unpacking; re-join captions split on unquoted commas.
                if len(row) < 2:
                    continue
                self.samples.append((row[0], ",".join(row[1:])))

        if max_samples:
            self.samples = self.samples[:max_samples]

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        img_name, caption = self.samples[index]
        img_path = os.path.join(self.root, img_name)

        # The context manager closes the underlying file once convert() has copied
        # the pixels.
        with Image.open(img_path) as img:
            image = img.convert("RGB")
        if self.transform:
            image = self.transform(image)

        word2idx = self.vocab.word2idx
        unk_idx = word2idx["<unk>"]
        # Lower-case before tokenizing. build_vocab and the BLEU references lower-case
        # captions too, so mixed-case words (e.g. a sentence-initial "A") would
        # otherwise map to <unk>.
        tokens = nltk.tokenize.word_tokenize(caption.lower())
        caption_indices = (
            [word2idx["<start>"]]
            + [word2idx.get(t, unk_idx) for t in tokens]
            + [word2idx["<end>"]]
        )

        caption_tensor = torch.tensor(caption_indices, dtype=torch.long)

        return image, caption_tensor, img_name


In [None]:
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

from torch.nn.utils.rnn import pad_sequence
import torch

# Resize to 224x224, convert to a tensor, then normalise each channel with the
# mean/std constants below.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
    ]
)


def collate_fn(batch, pad_idx=0):
    """Collate ``(image, caption, image_id)`` samples into one padded batch.

    Args:
        batch: sequence of ``(image tensor, 1-D caption index tensor, image id)`` tuples.
        pad_idx: value used to right-pad captions to the longest caption in the batch.
            Defaults to 0, matching the previous hard-coded value. When the
            vocabulary's ``<pad>`` index differs, pass ``vocab.word2idx["<pad>"]``
            (e.g. via ``functools.partial``) so padding agrees with code that detects
            padding through the vocabulary.

    Returns:
        ``(images, captions, ids)``. ``images`` is the stacked batch, ``captions`` is
        ``[B, T_max]``, and ``ids`` is a tuple of the image ids.
    """
    images, captions, ids = zip(*batch)

    images = torch.stack(images, dim=0)
    captions = pad_sequence(captions, batch_first=True, padding_value=pad_idx)

    return images, captions, ids

from torch.utils.data import DataLoader
# from flickr_dataset  import FlickrDataset 

# One item per row of the captions file, in file order.
test_dataset = FlickrDataset(
    IMAGES_PATH,
    CAPTIONS_PATH,
    vocab,
    transform=transform,
)

test_loader = DataLoader(
    test_dataset,
    batch_size=32,
    shuffle=False,          # keep file order so outputs line up with the ids
    collate_fn=collate_fn,
    num_workers=0,          # load in the main process
)

print(len(test_dataset))

# Inspect one sample: (transformed image, caption index tensor, image id).
image, caption, ids = test_dataset[0]
print(type(image))
print(image.shape)        # after transform
print(caption)
print(len(caption))

In [None]:
from model import TransformerEncoderViT
from model import TransformerCaptionDecoder

# Vocabulary and weights of the captioning model (presumably the vocab used in training).
# NOTE(review): weights_only=False unpickles arbitrary objects; only load trusted files.
vocab = torch.load(f"{models}/vocab-VT.pkl", weights_only=False)
encoder = TransformerEncoderViT().to(device)
decoder = TransformerCaptionDecoder(embed_size=256, vocab_size=len(vocab)).to(device)

encoder.load_state_dict(torch.load(f"{models}/encoder-Vit-Transformer.pth", map_location=device))
decoder.load_state_dict(torch.load(f"{models}/decoder-VT.pth", map_location=device))

# test_dataset (and therefore test_loader) was built with the vocabulary from
# build_vocab(). Re-point it at the model's vocabulary so the caption indices fed to
# the decoder, and the indices decoded back to words for BLEU references, use the same
# word<->index mapping as the model.
test_dataset.vocab = vocab

In [None]:
from rouge_score import rouge_scorer

def tokens_to_string(tokens):
    """Join a token list with single spaces; return any non-list value unchanged."""
    return " ".join(tokens) if isinstance(tokens, list) else tokens

def compute_rouge_l(references, hypotheses):
    """Mean ROUGE-L F-measure, using the best-scoring reference for each hypothesis.

    Args:
        references: per hypothesis, a list of references (strings or token lists).
        hypotheses: strings or token lists, aligned with ``references``.

    Returns:
        Average over hypotheses of the maximum ROUGE-L F-measure across its references.
    """
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

    def best_f(refs, hyp):
        hyp_text = tokens_to_string(hyp)
        return max(
            scorer.score(tokens_to_string(ref), hyp_text)['rougeL'].fmeasure
            for ref in refs
        )

    per_sample = [best_f(refs, hyp) for refs, hyp in zip(references, hypotheses)]
    return sum(per_sample) / len(per_sample)

from jiwer import wer
def compute_wer(references, hypotheses):
    """
    Computes WER using best matching reference per hypothesis.

    Args:
        references: per hypothesis, a list of reference captions; each reference may be
            a string or a list of tokens.
        hypotheses: captions as strings or token lists, aligned with ``references``.

    Returns:
        Mean over hypotheses of the minimum WER across that hypothesis's references.
        Raises ZeroDivisionError when ``hypotheses`` is empty.
    """
    def _as_sentence(caption):
        # jiwer treats a list as a collection of separate sentences, so a token list
        # must be joined into one sentence string before scoring. The previous code
        # built tokenized copies but then passed the raw inputs to wer().
        return " ".join(caption) if isinstance(caption, list) else caption

    wers = []
    for refs, hyp in zip(references, hypotheses):
        hyp_sentence = _as_sentence(hyp)
        wers.append(min(wer(_as_sentence(ref), hyp_sentence) for ref in refs))  # best match

    return sum(wers) / len(wers)

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
def clean_caption(tokens, special_tokens={"<start>", "<end>", "<pad>", "<unk>"}):
    """Return a new list of ``tokens`` without any ``special_tokens``, order preserved."""
    kept = []
    for token in tokens:
        if token not in special_tokens:
            kept.append(token)
    return kept

def compute_bleu_scores(
    encoder,
    decoder,
    dataloader,
    vocab,
    device,
    decode_fn,          # greedy or beam decode function
    max_len=30,
    captions_by_image=None,
):
    """Corpus BLEU-4 of generated captions against reference captions.

    Args:
        encoder, decoder: captioning model parts; both are put in eval mode.
        dataloader: yields ``(images, captions, ids)`` batches (as built by collate_fn).
        vocab: vocabulary with ``word2idx`` / ``idx2word``.
        device: device the images are moved to before decoding.
        decode_fn: called as ``decode_fn(image=, encoder=, decoder=, vocab=, max_len=)``
            and must return a space-separated caption string.
        max_len: maximum caption length passed to ``decode_fn``.
        captions_by_image: optional mapping from image id to its raw reference captions.
            Defaults to the notebook-level ``image_captions`` returned by build_vocab.

    When an image id has captions in that mapping, all of them (lower-cased and
    nltk-tokenized) are the references and the image is decoded only once. Otherwise
    the row's own caption tensor is the single reference.

    Returns:
        ``{"BLEU-4": float}``.
    """
    if captions_by_image is None:
        captions_by_image = image_captions  # notebook global from build_vocab()

    encoder.eval()
    decoder.eval()

    special_ids = {
        vocab.word2idx["<start>"],
        vocab.word2idx["<end>"],
        vocab.word2idx["<pad>"],
        vocab.word2idx["<unk>"],
    }

    references = []      # one entry per hypothesis: list of reference token lists
    hypotheses = []      # one entry per hypothesis: list of tokens
    decoded_ids = set()  # image ids already scored against their full reference set

    with torch.no_grad():
        for images, captions, ids in dataloader:
            images = images.to(device)

            for i in range(images.size(0)):
                img_id = ids[i]
                raw_refs = captions_by_image.get(img_id)

                if raw_refs:
                    # Later rows of the same image would add an identical
                    # (hypothesis, references) pair. That leaves corpus BLEU
                    # unchanged, so decode each such image only once.
                    if img_id in decoded_ids:
                        continue
                    decoded_ids.add(img_id)
                    refs = [clean_caption(nltk.word_tokenize(c.lower())) for c in raw_refs]
                else:
                    # No grouped captions for this id: use this row's caption only.
                    refs = [[
                        vocab.idx2word[idx]
                        for idx in captions[i].tolist()
                        if idx not in special_ids
                    ]]

                pred_sentence = decode_fn(
                    image=images[i],
                    encoder=encoder,
                    decoder=decoder,
                    vocab=vocab,
                    max_len=max_len,
                )
                hypotheses.append(clean_caption(pred_sentence.split()))
                # Exactly one references entry per hypothesis, as corpus_bleu requires
                # (the old code appended two, misaligning the lists).
                references.append(refs)

    bleu4 = corpus_bleu(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25))

    return {"BLEU-4": bleu4}

In [None]:
@torch.no_grad()
def greedy_decode_transformer(image, encoder, decoder, vocab, max_len=30):
    """Greedily decode a caption for one image, never choosing <unk>.

    Args:
        image: a single image tensor without a batch dim (one is added here). It must
            already be on the model's device; generated tokens are created there.
        encoder: called as ``encoder(images)`` to produce the decoder memory.
        decoder: called as ``decoder(memory, tokens)``; the last position of its output
            (``[:, -1, :]``) is read as next-token logits over the vocabulary.
        vocab: vocabulary with ``word2idx`` / ``idx2word``.
        max_len: maximum number of tokens generated after <start>.

    Returns:
        The generated words joined by spaces, with <start>/<end>/<pad>/<unk> removed.
    """
    encoder.eval()
    decoder.eval()

    bos = vocab.word2idx["<start>"]
    eos = vocab.word2idx["<end>"]
    unk = vocab.word2idx["<unk>"]

    memory = encoder(image.unsqueeze(0))

    generated = torch.tensor([[bos]], device=image.device)

    for _ in range(max_len):
        logits = decoder(memory, generated)
        # Mask <unk> on a copy of the last-step logits, then pick the best token.
        # (A chained `next_token = logits[...] = -1e9` would bind next_token to the
        # float and break torch.cat below.)
        step_logits = logits[:, -1, :].clone()
        step_logits[:, unk] = float("-inf")
        next_token = step_logits.argmax(dim=-1, keepdim=True)  # [1, 1]

        generated = torch.cat([generated, next_token], dim=1)

        if next_token.item() == eos:
            break

    words = []
    for idx in generated.squeeze(0).tolist():
        w = vocab.idx2word[idx]
        if w in ("<start>", "<end>", "<pad>", "<unk>"):
            continue
        words.append(w)

    return " ".join(words)
        


In [None]:
# Corpus BLEU of greedy captions over the whole test loader.
bleu_scores = compute_bleu_scores(
    encoder,
    decoder,
    test_loader,
    vocab,
    device,
    greedy_decode_transformer,
    max_len=30,
)
print(bleu_scores)

In [None]:
import torch
import torch.nn.functional as F
import math

def recall_at_k(similarity, k):
    """Fraction of rows whose own index (the diagonal entry) is among their top-k scores.

    similarity: [N, N] similarity matrix; query i's correct match is taken to be column i.
    """
    row_count = similarity.size(0)
    top_columns = torch.topk(similarity, k, dim=1).indices
    diagonal = torch.arange(row_count).view(-1, 1).to(similarity.device)
    hits = (top_columns == diagonal).any(dim=1)
    return hits.float().mean().item()

@torch.no_grad()
def extract_image_embeddings(loader, encoder, decoder, device):
    """Mean-pooled, L2-normalised image embeddings for every item in ``loader``.

    Encoder outputs are passed through ``decoder.enc_proj`` and averaged over dim 1.

    Returns:
        ``(embeddings on CPU, list of image ids aligned with the rows)``.
    """
    encoder.eval()
    decoder.eval()

    emb_batches = []
    id_list = []

    for batch_images, _, batch_ids in tqdm.tqdm(loader):
        projected = decoder.enc_proj(encoder(batch_images.to(device)))  # project to decoder width
        pooled = F.normalize(projected.mean(dim=1), dim=1)
        emb_batches.append(pooled.cpu())
        id_list.extend(batch_ids)

    return torch.cat(emb_batches, dim=0), id_list


import torch.nn as nn

@torch.no_grad()
def extract_text_embeddings(encoder, decoder, loader, vocab, device):
    """L2-normalised caption embeddings taken from the caption decoder's hidden states.

    Each caption from ``loader`` is embedded with ``decoder.embed`` (scaled by
    sqrt(embedding_dim)) and ``decoder.pos``. It is then run through ``decoder.decoder``
    with a causal mask and a padding mask, against ``decoder.enc_proj`` of its image's
    encoder output. The hidden state at the last non-pad position is the embedding.

    Returns:
        ``(embeddings on CPU, list of image ids aligned with the rows)``.
    """
    encoder.eval()
    decoder.eval()

    pad_idx = vocab.word2idx["<pad>"]
    scale = math.sqrt(decoder.embed.embedding_dim)

    emb_batches = []
    id_list = []

    for images, captions, image_ids in tqdm.tqdm(loader):
        images = images.to(device)
        captions = captions.to(device)

        # Projected image features act as the decoder memory.
        memory = decoder.enc_proj(encoder(images))

        is_pad = captions == pad_idx
        causal_mask = nn.Transformer.generate_square_subsequent_mask(captions.size(1)).to(device)

        # Same embedding pipeline as the decoder's own forward (scale, then positions).
        tgt = decoder.pos(decoder.embed(captions) * scale)
        hidden = decoder.decoder(
            tgt=tgt,
            memory=memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=is_pad,
        )  # [B, T, D]

        last_real = (~is_pad).sum(dim=1) - 1
        picked = hidden[torch.arange(hidden.size(0)), last_real]

        emb_batches.append(F.normalize(picked, dim=1).cpu())
        id_list.extend(image_ids)

    return torch.cat(emb_batches, dim=0), id_list

In [None]:

def image_to_text_retrieval(image_emb, text_emb, image_ids, text_ids, batch_size=512, device=None):
    """Image→text Recall@K, where any caption sharing the query's image id is a hit.

    Args:
        image_emb: [N_img, D] query embeddings (expected L2-normalised, as produced by
            extract_image_embeddings, so ``@`` gives cosine similarity).
        text_emb: [N_txt, D] caption embeddings.
        image_ids / text_ids: ids aligned with the rows of each tensor.
        batch_size: number of image queries scored per matmul.
        device: where to compute; defaults to CUDA when available, else CPU (the old
            hard-coded ``.cuda()`` failed on CPU-only machines).

    Returns:
        ``{"R@1", "R@5", "R@10"}``: the fraction of queries whose best-ranked matching
        caption is within the top K (all 0.0 when there are no queries). Raises
        KeyError if an image id has no caption.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    image_emb = image_emb.to(device)
    text_emb = text_emb.to(device)

    # Image id -> every caption row for it. Each image has several captions, and
    # matching any of them counts (previously only the last one did).
    id_to_text_rows = {}
    for row, tid in enumerate(text_ids):
        id_to_text_rows.setdefault(tid, []).append(row)

    ranks = []
    for start in range(0, image_emb.size(0), batch_size):
        sim = image_emb[start:start + batch_size] @ text_emb.T  # [B, N_txt]
        order = sim.argsort(dim=1, descending=True)

        for j, img_id in enumerate(image_ids[start:start + batch_size]):
            gt_rows = torch.tensor(id_to_text_rows[img_id], device=order.device)
            # nonzero() lists positions in ascending order, so [0] is the best rank.
            hit_positions = torch.isin(order[j], gt_rows).nonzero(as_tuple=True)[0]
            ranks.append(hit_positions[0].item())

    if not ranks:
        return {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0}

    ranks = torch.tensor(ranks)
    return {
        "R@1":  (ranks < 1).float().mean().item(),
        "R@5":  (ranks < 5).float().mean().item(),
        "R@10": (ranks < 10).float().mean().item()
    }

def text_to_image_retrieval(image_emb, text_emb, image_ids, text_ids, batch_size=512, device=None):
    """Text→image Recall@K against a gallery with one entry per distinct image id.

    FlickrDataset yields one row per caption, so ``image_emb`` holds repeated copies of
    the same image. Ranking over those copies let each other image occupy several top
    slots and made ties with the ground truth's own copies inflate its rank. The
    gallery therefore keeps only the first row of each image id.

    Args:
        image_emb: [N_img, D] image embeddings (expected L2-normalised).
        text_emb: [N_txt, D] caption query embeddings.
        image_ids / text_ids: ids aligned with the rows of each tensor.
        batch_size: number of caption queries scored per matmul.
        device: where to compute; defaults to CUDA when available, else CPU.

    Returns:
        ``{"R@1", "R@5", "R@10"}`` (all 0.0 when there are no queries). Raises KeyError
        if a caption's id has no image.
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    text_emb = text_emb.to(device)
    image_emb = image_emb.to(device)

    keep_rows = []      # first row index of each distinct image id
    id_to_gallery = {}  # image id -> position in the de-duplicated gallery
    for row, iid in enumerate(image_ids):
        if iid not in id_to_gallery:
            id_to_gallery[iid] = len(keep_rows)
            keep_rows.append(row)
    gallery = image_emb[torch.tensor(keep_rows, dtype=torch.long, device=image_emb.device)]

    ranks = []
    for start in range(0, text_emb.size(0), batch_size):
        sim = text_emb[start:start + batch_size] @ gallery.T  # [B, N_unique_images]
        order = sim.argsort(dim=1, descending=True)

        for j, txt_id in enumerate(text_ids[start:start + batch_size]):
            gt = id_to_gallery[txt_id]
            ranks.append((order[j] == gt).nonzero(as_tuple=True)[0].item())

    if not ranks:
        return {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0}

    ranks = torch.tensor(ranks)
    return {
        "R@1":  (ranks < 1).float().mean().item(),
        "R@5":  (ranks < 5).float().mean().item(),
        "R@10": (ranks < 10).float().mean().item()
    }

def text_to_text_retrieval(text_emb, text_ids, batch_size=512, device=None):
    """Caption→caption Recall@K: other captions of the same image count as hits.

    Each query's similarity to itself is masked to -inf before ranking. Otherwise the
    query always takes rank 0 and its true matches start at rank 1, which pinned R@1
    near 0.

    Args:
        text_emb: [N, D] caption embeddings (expected L2-normalised).
        text_ids: image id of each row.
        batch_size: number of queries scored per matmul.
        device: where to compute; defaults to CUDA when available, else CPU.

    Returns:
        ``{"R@1", "R@5", "R@10"}`` over queries that have at least one other caption
        with the same id (all 0.0 when there are none).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    text_emb = text_emb.to(device)

    # image_id -> list of caption rows
    id_to_rows = {}
    for row, img_id in enumerate(text_ids):
        id_to_rows.setdefault(img_id, []).append(row)

    ranks = []
    for start in range(0, text_emb.size(0), batch_size):
        sim = text_emb[start:start + batch_size] @ text_emb.T  # [B, N]
        local = torch.arange(sim.size(0), device=sim.device)
        sim[local, local + start] = float("-inf")  # exclude each query from its own ranking
        order = sim.argsort(dim=1, descending=True)

        for j in range(order.size(0)):
            query_row = start + j
            gt_rows = [r for r in id_to_rows[text_ids[query_row]] if r != query_row]
            if not gt_rows:
                continue  # only one caption for this image -> nothing to retrieve

            gt = torch.tensor(gt_rows, device=order.device)
            # nonzero() lists positions ascending, so [0] is the best-ranked match.
            ranks.append(torch.isin(order[j], gt).nonzero(as_tuple=True)[0][0].item())

    if not ranks:
        return {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0}

    ranks = torch.tensor(ranks)

    return {
        "R@1":  (ranks < 1).float().mean().item(),
        "R@5":  (ranks < 5).float().mean().item(),
        "R@10": (ranks < 10).float().mean().item()
    }

def image_to_image_retrieval(image_emb, image_ids, batch_size=512, device=None):
    """
    Image→image Recall@K: other rows with the same image id count as hits.

    Each query's similarity to itself is masked to -inf before ranking, so the query
    cannot occupy a top slot.

    NOTE(review): with FlickrDataset every image appears once per caption, so rows
    sharing an id are copies of the same image and this metric is trivially high.

    image_emb : Tensor [N, D] (expected L2-normalised)
    image_ids : list of image_ids aligned with the rows
    batch_size: number of queries scored per matmul
    device    : where to compute; defaults to CUDA when available, else CPU

    Returns {"R@1", "R@5", "R@10"} over queries that have another row with the same id
    (all 0.0 when there are none).
    """
    if device is None:
        device = "cuda" if torch.cuda.is_available() else "cpu"
    image_emb = image_emb.to(device)

    # image_id -> list of row indices
    id_to_rows = {}
    for row, img_id in enumerate(image_ids):
        id_to_rows.setdefault(img_id, []).append(row)

    ranks = []
    for start in range(0, image_emb.size(0), batch_size):
        sim = image_emb[start:start + batch_size] @ image_emb.T  # [B, N]
        local = torch.arange(sim.size(0), device=sim.device)
        sim[local, local + start] = float("-inf")  # exclude each query from its own ranking
        order = sim.argsort(dim=1, descending=True)

        for j in range(order.size(0)):
            query_row = start + j
            gt_rows = [r for r in id_to_rows[image_ids[query_row]] if r != query_row]
            if not gt_rows:
                continue  # only one row for this id -> skip

            gt = torch.tensor(gt_rows, device=order.device)
            # nonzero() lists positions ascending, so [0] is the best-ranked match.
            ranks.append(torch.isin(order[j], gt).nonzero(as_tuple=True)[0][0].item())

    if len(ranks) == 0:
        return {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0}

    ranks = torch.tensor(ranks)

    return {
        "R@1":  (ranks < 1).float().mean().item(),
        "R@5":  (ranks < 5).float().mean().item(),
        "R@10": (ranks < 10).float().mean().item()
    }

In [None]:
@torch.no_grad()
def generate_caption_transformer(image, encoder, decoder, vocab, max_len=30):
    """Greedy caption for one image, with at most ``max_len - 1`` generated tokens.

    The image gets a batch dim and is moved to the notebook-level ``device``. The
    argmax of the decoder's last-position logits is appended until <end> appears.
    Words are returned up to (not including) <end>; <start> and <pad> are dropped,
    while <unk> is kept.
    """
    encoder.eval()
    decoder.eval()

    start_id = vocab.word2idx["<start>"]
    end_id = vocab.word2idx["<end>"]

    memory = encoder(image.unsqueeze(0).to(device))  # [1, N, D]
    tokens = torch.tensor([[start_id]], device=device)

    steps = 0
    while steps < max_len - 1:
        best = decoder(memory, tokens)[:, -1].argmax(dim=-1, keepdim=True)  # [1, 1]
        tokens = torch.cat([tokens, best], dim=1)
        steps += 1
        if best.item() == end_id:
            break

    caption_words = []
    for token_id in tokens.squeeze(0).tolist():
        word = vocab.idx2word[token_id]
        if word == "<end>":
            break
        if word not in ("<start>", "<pad>"):
            caption_words.append(word)

    return " ".join(caption_words)


In [None]:
# One embedding per dataset row; both passes iterate test_loader in the same (unshuffled) order.
image_emb, image_ids = extract_image_embeddings(loader=test_loader, encoder=encoder, decoder=decoder, device=device)
text_emb, text_ids = extract_text_embeddings(encoder=encoder, decoder=decoder, loader=test_loader, vocab=vocab, device=device)

In [None]:
# Both embedding passes read test_loader in order, so their id lists should agree.
for label, id_list in (("Sample image_ids:", image_ids), ("Sample text_ids :", text_ids)):
    print(label, id_list[:5])
for label, id_list in (("image_ids unique:", image_ids), ("text_ids unique :", text_ids)):
    print(label, len(set(id_list)))

In [None]:
# Recall@1/5/10 for each retrieval direction, evaluated and printed in order.
for label, evaluate in (
    ("Image → Text:", lambda: image_to_text_retrieval(image_emb, text_emb, image_ids, text_ids)),
    ("Text → Image:", lambda: text_to_image_retrieval(image_emb, text_emb, image_ids, text_ids)),
    ("Text → Text :", lambda: text_to_text_retrieval(text_emb, text_ids)),
    ("Image → Image:", lambda: image_to_image_retrieval(image_emb, image_ids)),
):
    print(label, evaluate())