In [3]:
import torch
# Pick the compute device once for the whole notebook: use CUDA when PyTorch
# reports it as available, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"
print(device)

cuda


In [None]:
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

from torch.nn.utils.rnn import pad_sequence
import torch

# Preprocessing pipeline applied to every image: resize to 224x224, convert the
# PIL image to a tensor, then normalise each of the three channels with a fixed
# mean and standard deviation.
# NOTE(review): the constants match the widely used ImageNet channel statistics;
# confirm they are the ones the encoder was trained with.
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(), 
    transforms.Normalize((0.485, 0.456, 0.406), (0.229, 0.224, 0.225)),
])


def collate_fn(batch):
    """Merge a list of (image, caption) samples into two batched tensors.

    Args:
        batch: sequence of ``(image_tensor, caption_tensor)`` pairs. All image
            tensors must share one shape; captions are 1-D and may differ in
            length.

    Returns:
        ``(images, captions)`` where ``images`` stacks the image tensors along a
        new leading dimension and ``captions`` is batch-first, right-padded
        with the value 0 up to the longest caption in the batch.
    """
    image_tensors, caption_tensors = zip(*batch)

    stacked_images = torch.stack(image_tensors, 0)
    # Captions have variable length, so pad them to a rectangular batch.
    padded_captions = pad_sequence(caption_tensors, batch_first=True, padding_value=0)

    return stacked_images, padded_captions

from torch.utils.data import DataLoader
from flickr_dataset  import FlickrDataset 

# Build the captioning dataset and wrap it in a DataLoader.
# NOTE: this cell reads IMAGES_PATH, CAPTIONS_PATH and vocab, which are assigned
# in cells that appear further down the notebook; those cells must be executed
# before this one.
train_dataset = FlickrDataset(
    root=IMAGES_PATH,
    captions_path=CAPTIONS_PATH,
    vocab=vocab,
    transform=transform,
    max_samples=None
)

# Batches of 32 in dataset order (no shuffling), loaded in the main process;
# collate_fn pads the variable-length captions within each batch.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=False,
    num_workers=0,
    collate_fn= collate_fn
)

# Quick sanity check: dataset size plus the type, shape and caption of sample 0.
print(len(train_dataset))
image, caption = train_dataset[0]

print(type(image))
print(image.shape)        # after transform
print(caption)
print(len(caption))

In [1]:
# Filesystem locations used by the dataset cell and the qualitative test cell.
# All three are relative paths, so they resolve against the notebook's working
# directory.
IMAGES_PATH = "../phase_1/data/flickr8k/images"  # Directory with training images
CAPTIONS_PATH = "../phase_1/data/flickr8k/captions.txt"  # Caption file
TEST_IMAGES_PATH = "../phase_1/test copy/"  # Directory with test images

In [5]:
from model import TransformerEncoderViT
from model import TransformerDecoder

# Load the saved vocabulary object.
# NOTE(review): weights_only=False lets torch.load unpickle arbitrary Python
# objects, so models/vocab.pkl must come from a trusted source.
vocab = torch.load("models/vocab.pkl", weights_only=False)
# Instantiate both networks on `device`; the decoder's output vocabulary is
# sized from the loaded vocab.
# NOTE(review): the positional 256 passed to the encoder is presumably its
# embedding size, matching embed_size=256 on the decoder; confirm in model.py.
encoder = TransformerEncoderViT(256).to(device)
decoder = TransformerDecoder(embed_size=256, vocab_size=len(vocab)).to(device)

# Restore the saved weights, remapping stored tensors onto `device`.
encoder.load_state_dict(torch.load("models/encoder.pth", map_location=device))
decoder.load_state_dict(torch.load("models/decoder.pth", map_location=device))

<All keys matched successfully>

In [6]:
from rouge_score import rouge_scorer

def tokens_to_string(tokens):
    """Join a list of tokens with single spaces; return any non-list input unchanged."""
    return " ".join(tokens) if isinstance(tokens, list) else tokens

def compute_rouge_l(references, hypotheses):
    """Mean ROUGE-L F-measure, taking the best-scoring reference per hypothesis.

    Args:
        references: one entry per sample, each an iterable of reference
            captions (token lists or strings).
        hypotheses: one predicted caption per sample (token list or string).

    Returns:
        The average over samples of the highest ROUGE-L F-measure obtained
        against any of that sample's references.
    """
    scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

    best_per_sample = []
    for reference_group, hypothesis in zip(references, hypotheses):
        hypothesis_text = tokens_to_string(hypothesis)
        # Keep only the most favourable reference for this hypothesis.
        best_per_sample.append(
            max(
                scorer.score(tokens_to_string(reference), hypothesis_text)['rougeL'].fmeasure
                for reference in reference_group
            )
        )

    return sum(best_per_sample) / len(best_per_sample)

from jiwer import wer
def compute_wer(references, hypotheses):
    """
    Computes WER using best matching reference per hypothesis.

    Args:
        references: one entry per sample, each an iterable of reference
            captions. A caption may be a string or a list/tuple of word tokens.
        hypotheses: one predicted caption per sample, string or token sequence.

    Returns:
        The mean over samples of the lowest WER obtained against any of that
        sample's references, or 0.0 when there are no samples.
    """
    def _as_sentence(caption):
        # jiwer scores sentences given as strings; a list argument is read as a
        # list of separate sentences, so token sequences are joined into one
        # sentence before scoring.
        if isinstance(caption, (list, tuple)):
            return " ".join(caption)
        return caption

    best_wers = []
    for refs, hyp in zip(references, hypotheses):
        hyp_sentence = _as_sentence(hyp)
        # Lowest error rate across this sample's references (best match).
        best_wers.append(min(wer(_as_sentence(ref), hyp_sentence) for ref in refs))

    if not best_wers:
        # Nothing to score; avoid dividing by zero.
        return 0.0

    return sum(best_wers) / len(best_wers)

from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
def clean_caption(tokens, special_tokens={"<start>", "<end>", "<pad>", "<unk>"}):
    """Return the tokens that are not in ``special_tokens``, preserving order."""
    kept = []
    for word in tokens:
        if word in special_tokens:
            continue
        kept.append(word)
    return kept

def compute_bleu_scores(
    encoder,
    decoder,
    dataloader,
    vocab,
    device,
    decode_fn,          # greedy or beam decode function
    max_len=30
):
    """
    Generate a caption for every image in ``dataloader`` and score it with BLEU-4.

    Args:
        encoder, decoder: models put into eval mode and forwarded to ``decode_fn``.
        dataloader: iterable yielding ``(images, captions)`` batches, where
            ``captions[i]`` is a sequence of vocabulary indices for ``images[i]``.
        vocab: object exposing ``word2idx`` and ``idx2word`` mappings.
        device: device the image batch is moved to before decoding.
        decode_fn: callable invoked with keyword arguments ``image``,
            ``encoder``, ``decoder``, ``vocab`` and ``max_len``; it must return
            the predicted caption as a whitespace-separated string.
        max_len: forwarded to ``decode_fn``.

    Returns:
        ``{"BLEU-4": score}`` computed with ``corpus_bleu`` over all samples,
        using the single ground-truth caption of each sample as its only
        reference.
    """
    encoder.eval()
    decoder.eval()

    # Loop-invariant: indices of the marker tokens stripped from references.
    special_ids = {
        vocab.word2idx[token]
        for token in ("<start>", "<end>", "<pad>", "<unk>")
    }

    references = []   # per sample: list holding one list of words
    hypotheses = []   # per sample: list of predicted words

    # Evaluation only: make sure no autograd graph is built even when the
    # supplied decode_fn does not disable gradients itself.
    with torch.no_grad():
        for images, captions in dataloader:
            images = images.to(device)

            for i in range(images.size(0)):
                # ---- generate caption ----
                pred_sentence = decode_fn(
                    image=images[i],
                    encoder=encoder,
                    decoder=decoder,
                    vocab=vocab,
                    max_len=max_len
                )
                hypotheses.append(clean_caption(pred_sentence.split()))

                # ---- single reference caption, marker tokens removed ----
                ref_words = [
                    vocab.idx2word[idx]
                    for idx in captions[i].tolist()
                    if idx not in special_ids
                ]
                # corpus_bleu expects a list of references per hypothesis.
                references.append([ref_words])

    # ---- BLEU score (uniform weights over 1- to 4-grams) ----
    bleu4 = corpus_bleu(references, hypotheses, weights=(0.25, 0.25, 0.25, 0.25))

    return {
        "BLEU-4": bleu4
    }

In [7]:
@torch.no_grad()
def greedy_decode_transformer(image, encoder, decoder, vocab, max_len=30):
    """Greedily decode a caption for one image.

    Args:
        image: a single image tensor without a batch dimension; a batch
            dimension of size 1 is added before encoding.
        encoder: callable mapping the batched image to the decoder memory.
        decoder: callable ``decoder(memory, tokens)`` returning logits whose
            last time step is used to pick the next token.
        vocab: object exposing ``word2idx`` and ``idx2word``.
        max_len: maximum number of tokens generated after ``<start>``.

    Returns:
        The generated words joined by spaces, with ``<start>``, ``<end>``,
        ``<pad>`` and ``<unk>`` removed.
    """
    encoder.eval()
    decoder.eval()

    start_id = vocab.word2idx["<start>"]
    end_id = vocab.word2idx["<end>"]

    batched_image = image.unsqueeze(0)
    memory = encoder(batched_image)

    # Running sequence of token ids, seeded with the start marker.
    token_ids = torch.tensor([[start_id]], device=batched_image.device)

    for _ in range(max_len):
        step_logits = decoder(memory, token_ids)
        best_next = step_logits[:, -1].argmax(dim=-1, keepdim=True)
        token_ids = torch.cat([token_ids, best_next], dim=1)

        # Stop as soon as the end marker has been emitted.
        if best_next.item() == end_id:
            break

    markers = ("<start>", "<end>", "<pad>", "<unk>")
    decoded = [vocab.idx2word[token_id] for token_id in token_ids.squeeze(0).tolist()]

    return " ".join(word for word in decoded if word not in markers)
        


In [8]:

# Score greedy decoding with BLEU-4 over every batch produced by train_loader.
# NOTE: train_loader is created in the dataset cell near the top of the
# notebook; that cell has to be executed before this one.
bleu_scores = compute_bleu_scores(
    encoder=encoder,
    decoder=decoder,
    dataloader=train_loader,
    vocab=vocab,
    device=device,
    decode_fn=greedy_decode_transformer,
    max_len=30
)

print(bleu_scores)

NameError: name 'train_loader' is not defined

In [None]:
import torch
import torch.nn.functional as F

def recall_at_k(similarity, k):
    """
    Fraction of rows whose own index appears among their k highest-scoring columns.

    Args:
        similarity: [N, N] similarity matrix; the correct match for row ``i``
            is taken to be column ``i``.
        k: number of top candidates inspected per row. Values larger than the
            number of columns are clamped, since ``topk`` rejects them.

    Returns:
        Recall@k as a Python float; 0.0 for a matrix with no rows.
    """
    num_rows = similarity.size(0)
    if num_rows == 0:
        # No queries: avoid the NaN produced by the mean of an empty tensor.
        return 0.0

    k = min(k, similarity.size(1))
    topk = similarity.topk(k, dim=1).indices
    targets = torch.arange(num_rows, device=similarity.device).unsqueeze(1)
    correct = (topk == targets).any(dim=1)
    return correct.float().mean().item()

def extract_image_embeddings(dataloader, encoder, device):
    """Encode every image batch into one L2-normalised vector per image.

    Args:
        dataloader: iterable of ``(images, _)`` pairs; the second item is ignored.
        encoder: model called on each image batch; its output is averaged over
            dimension 1 to obtain one vector per image.
        device: device each image batch is moved to before encoding.

    Returns:
        All per-image vectors concatenated along dimension 0, each normalised
        to unit length along dimension 1.
    """
    encoder.eval()

    with torch.no_grad():
        # Per batch: encode, mean-pool over dimension 1, then L2-normalise.
        pooled_batches = [
            F.normalize(encoder(image_batch.to(device)).mean(dim=1), dim=1)
            for image_batch, _ in dataloader
        ]

    return torch.cat(pooled_batches, dim=0)

@torch.no_grad()
def extract_text_embeddings(
    encoder,
    decoder,
    dataloader,
    vocab,
    device
):
    """Compute one L2-normalised sentence embedding per caption.

    Each caption is run through the decoder stack (conditioned on its image)
    and the decoder output at the caption's last non-padding position is used
    as the sentence embedding.

    Args:
        encoder: image encoder producing the decoder memory.
        decoder: object exposing ``embed``, ``pos_embed``, ``embed_size`` and
            ``decoder`` (the underlying decoder stack).
        dataloader: iterable of ``(images, captions)`` batches, ``captions``
            being a 2-D tensor of token indices.
        vocab: object whose ``word2idx["<pad>"]`` gives the padding index.
        device: device on which the batches are processed.

    Returns:
        ``(embeddings, ids)``: a CPU tensor with one row per caption and a list
        of running indices ``0..N-1`` aligned with those rows.
    """
    # Imported locally because the notebook never imports math at module level,
    # which made the sqrt call below fail with a NameError.
    import math

    encoder.eval()
    decoder.eval()

    all_embeddings = []
    all_ids = []

    pad_idx = vocab.word2idx["<pad>"]

    for images, captions in dataloader:
        images = images.to(device)
        captions = captions.to(device)

        # ---- Encode images ----
        memory = encoder(images)

        # ---- Transformer decoder forward ----
        # NOTE(review): this is meant to replicate the decoder's own forward
        # pass (embedding + positions, scaled by sqrt(embed_size)); confirm it
        # matches TransformerDecoder.forward in model.py.
        B, T = captions.shape
        positions = torch.arange(T, device=device).unsqueeze(0)

        x = decoder.embed(captions) + decoder.pos_embed(positions)
        x = x * math.sqrt(decoder.embed_size)

        # True above the diagonal: position t may not attend to later positions.
        causal_mask = torch.triu(
            torch.ones(T, T, device=device), diagonal=1
        ).bool()

        padding_mask = captions == pad_idx

        out = decoder.decoder(
            tgt=x,
            memory=memory,
            tgt_mask=causal_mask,
            tgt_key_padding_mask=padding_mask
        )

        # ---- extract last valid token per sentence ----
        # Count of non-pad tokens minus one; this is the last real position
        # provided padding only appears at the end of each caption.
        last_valid = (~padding_mask).sum(dim=1) - 1
        # Row indices are created on the same device as the column indices.
        sent_emb = out[torch.arange(B, device=device), last_valid]

        sent_emb = F.normalize(sent_emb, dim=1)

        all_embeddings.append(sent_emb.cpu())
        # Running (global) ids, so they stay aligned with the concatenated rows
        # instead of restarting at 0 for every batch.
        first_id = len(all_ids)
        all_ids.extend(range(first_id, first_id + B))

    return torch.cat(all_embeddings, dim=0), all_ids

def image_to_text_retrieval(image_emb, text_emb, batch_size=512):
    """Recall@{1,5,10} for retrieving the matching text of each image.

    Row ``i`` of ``image_emb`` is assumed to correspond to row ``i`` of
    ``text_emb``. Every image is scored against all texts by dot product and
    the rank of its own text is recorded.

    Args:
        image_emb: [N, D] image embeddings (queries).
        text_emb: [M, D] text embeddings (gallery), with ``M >= N``.
        batch_size: number of queries scored per chunk, bounding the size of
            the similarity matrix held in memory.

    Returns:
        Dict with keys ``"R@1"``, ``"R@5"`` and ``"R@10"``.

    Raises:
        ValueError: if there are more images than texts, in which case some
            images would have no target row.
    """
    # Use the GPU when one is present, but keep working on CPU-only machines.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image_emb = image_emb.to(dev)
    text_emb = text_emb.to(dev)

    N = image_emb.size(0)
    if N > text_emb.size(0):
        raise ValueError(
            f"image_emb has {N} rows but text_emb only has {text_emb.size(0)}"
        )

    rank_chunks = []

    for i in range(0, N, batch_size):
        img_batch = image_emb[i:i + batch_size]          # [B, D]
        sim = img_batch @ text_emb.T                     # [B, M]

        gt = torch.arange(i, i + img_batch.size(0), device=dev)
        sorted_idx = sim.argsort(dim=1, descending=True)

        # Each row of sorted_idx contains its target exactly once, so the
        # argmax of the match mask is the target's rank (0 = best).
        rank_chunks.append((sorted_idx == gt.unsqueeze(1)).int().argmax(dim=1))

    if rank_chunks:
        ranks = torch.cat(rank_chunks)
    else:
        ranks = torch.empty(0, dtype=torch.long)

    return {
        "R@1":  (ranks < 1).float().mean().item(),
        "R@5":  (ranks < 5).float().mean().item(),
        "R@10": (ranks < 10).float().mean().item()
    }

def text_to_image_retrieval(image_emb, text_emb, batch_size=512):
    """Recall@{1,5,10} for retrieving the matching image of each text.

    Row ``i`` of ``text_emb`` is assumed to correspond to row ``i`` of
    ``image_emb``. Every text is scored against all images by dot product and
    the rank of its own image is recorded.

    Args:
        image_emb: [M, D] image embeddings (gallery), with ``M >= N``.
        text_emb: [N, D] text embeddings (queries).
        batch_size: number of queries scored per chunk, bounding the size of
            the similarity matrix held in memory.

    Returns:
        Dict with keys ``"R@1"``, ``"R@5"`` and ``"R@10"``.

    Raises:
        ValueError: if there are more texts than images, in which case some
            texts would have no target row.
    """
    # Use the GPU when one is present, but keep working on CPU-only machines.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    image_emb = image_emb.to(dev)
    text_emb = text_emb.to(dev)

    N = text_emb.size(0)
    if N > image_emb.size(0):
        raise ValueError(
            f"text_emb has {N} rows but image_emb only has {image_emb.size(0)}"
        )

    rank_chunks = []

    for i in range(0, N, batch_size):
        txt_batch = text_emb[i:i + batch_size]           # [B, D]
        sim = txt_batch @ image_emb.T                    # [B, M]

        gt = torch.arange(i, i + txt_batch.size(0), device=dev)
        sorted_idx = sim.argsort(dim=1, descending=True)

        # Each row of sorted_idx contains its target exactly once, so the
        # argmax of the match mask is the target's rank (0 = best).
        rank_chunks.append((sorted_idx == gt.unsqueeze(1)).int().argmax(dim=1))

    if rank_chunks:
        ranks = torch.cat(rank_chunks)
    else:
        ranks = torch.empty(0, dtype=torch.long)

    return {
        "R@1":  (ranks < 1).float().mean().item(),
        "R@5":  (ranks < 5).float().mean().item(),
        "R@10": (ranks < 10).float().mean().item()
    }

def text_to_text_retrieval(text_emb, batch_size=512, group_ids=None):
    """Recall@{1,5,10} for retrieving a related text from the other texts.

    Every text is scored against all texts by dot product with its own entry
    excluded. A query counts as a hit at k when one of its k best-scoring
    neighbours belongs to the same group.

    Args:
        text_emb: [N, D] text embeddings.
        batch_size: number of queries scored per chunk.
        group_ids: optional length-N sequence or tensor of integer ids; texts
            sharing an id (for example captions of the same image) are
            positives for each other. When omitted every text forms its own
            group, so no positives exist once the self-match is excluded and
            all recalls are 0.0; pass ``group_ids`` to get a meaningful score.

    Returns:
        Dict with keys ``"R@1"``, ``"R@5"`` and ``"R@10"``.
    """
    # Use the GPU when one is present, but keep working on CPU-only machines.
    dev = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    text_emb = text_emb.to(dev)

    N = text_emb.size(0)
    if group_ids is None:
        groups = torch.arange(N, device=dev)
    else:
        groups = torch.as_tensor(group_ids, device=dev)

    cutoffs = (1, 5, 10)
    hit_chunks = {k: [] for k in cutoffs}

    for i in range(0, N, batch_size):
        txt_batch = text_emb[i:i + batch_size]           # [B, D]
        sim = txt_batch @ text_emb.T                     # [B, N]

        # Row j of this chunk is global item i + j, so its self-similarity sits
        # in column i + j (not on the diagonal of the chunk).
        rows = torch.arange(txt_batch.size(0), device=dev)
        self_cols = rows + i
        sim[rows, self_cols] = float("-inf")

        positives = groups[self_cols].unsqueeze(1) == groups.unsqueeze(0)  # [B, N]
        positives[rows, self_cols] = False

        sorted_idx = sim.argsort(dim=1, descending=True)
        # hits_by_rank[r, p] is True when the p-th best neighbour of query r is
        # a positive.
        hits_by_rank = positives.gather(1, sorted_idx)
        for k in cutoffs:
            hit_chunks[k].append(hits_by_rank[:, :k].any(dim=1))

    def _recall(k):
        # With no queries the mean is undefined; report NaN.
        if not hit_chunks[k]:
            return float("nan")
        return torch.cat(hit_chunks[k]).float().mean().item()

    return {
        "R@1":  _recall(1),
        "R@5":  _recall(5),
        "R@10": _recall(10)
    }

# def image_to_image_retrieval(image_emb, batch_size=512):
#     image_emb  = image_emb.cuda()

#     N = image_emb.size(0)
#     ranks = []

#     for i in range(0, N, batch_size):
#         img_batch = image_emb[i:i+batch_size]           # [B, D]
#         sim = img_batch @ image_emb.T                     # [B, N]
#         sim.fill_diagonal_(-1)

#         gt = torch.arange(i, min(i+batch_size, N)).cuda()
#         sorted_idx = sim.argsort(dim=1, descending=True)

#         for j in range(sorted_idx.size(0)):
#             rank = (sorted_idx[j] == gt[j]).nonzero(as_tuple=True)[0].item()
#             ranks.append(rank)

#     ranks = torch.tensor(ranks)
#     return {
#         "R@1":  (ranks < 1).float().mean().item(),
#         "R@5":  (ranks < 5).float().mean().item(),
#         "R@10": (ranks < 10).float().mean().item()
#     }

def image_to_image_retrieval(image_emb, group_ids=None):
    """Recall@{1,5,10} for retrieving a related image from the other images.

    Every image is scored against all images by dot product with its own entry
    excluded. A query counts as a hit at k when one of its k best-scoring
    neighbours belongs to the same group.

    Args:
        image_emb: [N, D] image embeddings; the computation stays on the
            device the embeddings are already on.
        group_ids: optional length-N sequence or tensor of integer ids; images
            sharing an id are positives for each other. When omitted every
            image forms its own group, so no positives exist once the
            self-match is excluded and all recalls are 0.0; pass ``group_ids``
            to get a meaningful score.

    Returns:
        Dict with keys ``"R@1"``, ``"R@5"`` and ``"R@10"``; all 0.0 when there
        are no images.
    """
    n = image_emb.size(0)
    if n == 0:
        return {"R@1": 0.0, "R@5": 0.0, "R@10": 0.0}

    sim = image_emb @ image_emb.t()

    # Remove self-matching; -inf guarantees the self entry is ranked last.
    sim.fill_diagonal_(float("-inf"))

    if group_ids is None:
        groups = torch.arange(n, device=sim.device)
    else:
        groups = torch.as_tensor(group_ids, device=sim.device)

    positives = groups.unsqueeze(1) == groups.unsqueeze(0)   # [N, N]
    diag = torch.arange(n, device=sim.device)
    positives[diag, diag] = False

    order = sim.argsort(dim=1, descending=True)
    # hits_by_rank[r, p] is True when the p-th best neighbour of query r is a
    # positive; slicing past N is harmless, so small N needs no special case.
    hits_by_rank = positives.gather(1, order)

    return {
        "R@1": hits_by_rank[:, :1].any(dim=1).float().mean().item(),
        "R@5": hits_by_rank[:, :5].any(dim=1).float().mean().item(),
        "R@10": hits_by_rank[:, :10].any(dim=1).float().mean().item(),
    }

In [None]:
@torch.no_grad()
def generate_caption_transformer(image, encoder, decoder, vocab, max_len=30):
    """Greedily generate a caption for one image on the notebook-level ``device``.

    Args:
        image: a single image tensor without a batch dimension.
        encoder: callable mapping the batched image to the decoder memory.
        decoder: callable ``decoder(memory, tokens)`` returning logits whose
            last time step is used to pick the next token.
        vocab: object exposing ``word2idx`` and ``idx2word``.
        max_len: maximum sequence length including the ``<start>`` token, so
            at most ``max_len - 1`` tokens are generated.

    Returns:
        The words before the first ``<end>`` joined by spaces, with
        ``<start>`` and ``<pad>`` removed.
    """
    encoder.eval()
    decoder.eval()

    start_id = vocab.word2idx["<start>"]
    end_id = vocab.word2idx["<end>"]

    # Add the batch dimension and move the image to the global device.
    memory = encoder(image.unsqueeze(0).to(device))

    token_ids = torch.tensor([[start_id]], device=device)

    for _ in range(max_len - 1):
        step_logits = decoder(memory, token_ids)
        next_id = step_logits[:, -1].argmax(dim=-1, keepdim=True)
        token_ids = torch.cat([token_ids, next_id], dim=1)

        if next_id.item() == end_id:
            break

    caption_words = []
    for token_id in token_ids.squeeze(0).tolist():
        word = vocab.idx2word[token_id]
        # Everything from the end marker onwards is discarded.
        if word == "<end>":
            break
        if word not in ("<start>", "<pad>"):
            caption_words.append(word)

    return " ".join(caption_words)


In [None]:
#  test the model on a few images
import os
import matplotlib.pyplot as plt
from PIL import Image

# Caption up to 20 files from the test directory and show each picture with
# its generated caption. os.listdir returns entries in arbitrary order, so the
# names are sorted to make the selection reproducible.
for file_name in sorted(os.listdir(TEST_IMAGES_PATH))[:20]:
    # Load image
    img_path = os.path.join(TEST_IMAGES_PATH, file_name)
    pil_image = Image.open(img_path).convert("RGB")

    # Display the original picture rather than the transformed tensor: after
    # Normalize the values fall outside [0, 1], which imshow clips, so the
    # normalised tensor renders with wrong colours.
    plt.imshow(pil_image)
    plt.axis('off')
    plt.show()

    # The model still receives the resized, normalised tensor.
    image = transform(pil_image)
    caption = generate_caption_transformer(image, encoder, decoder, vocab)
    print(f"Image: {file_name}\nCaption: {caption}\n")
