In [None]:
import torch
# Run on the GPU whenever PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
import os

# Data and checkpoint roots, relative to the notebook's working directory.
# os.path.join uses the platform separator, so these resolve on Linux/macOS as
# well as Windows (hard-coded "\\" separators only work on Windows). The values
# stay plain strings, so f"{models}/..." style joins in later cells still work.
datasets = os.path.join("..", "..", "datasets")
models = os.path.join("..", "..", "models")

# map_location lets a file saved from a GPU session load on a CPU-only machine.
# NOTE(review): caption_model is not referenced in the cells shown here — confirm it is still needed.
caption_model = torch.load(os.path.join(models, "caption_features_flickr8k.pt"), map_location=device)

IMAGES_PATH = os.path.join(datasets, "flickr8k", "images")  # Directory with training images
CAPTIONS_PATH = os.path.join(datasets, "flickr8k", "captions.txt")  # Caption file
TEST_IMAGES_PATH = os.path.join("..", "test_images")  # Directory with test images

In [None]:
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms

from torch.nn.utils.rnn import pad_sequence
import torch

# Preprocessing applied to every input image: resize to a fixed 224x224,
# convert to a float tensor, then normalize each channel with the standard
# ImageNet mean/std values.
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)


In [None]:
import torch.nn.functional as F

@torch.no_grad()
def generate_caption_beam(
    image,
    encoder,
    decoder,
    vocab,
    beam_size=5,
    max_len=30,
    length_penalty=0.7,
    device=None,
):
    """Generate a caption for one image using length-normalised beam search.

    Args:
        image: preprocessed image tensor without a batch axis; one is added
            here with ``unsqueeze(0)``.
        encoder: module mapping the batched image to the ``memory`` passed to
            ``decoder``. It is switched to eval mode.
        decoder: module called as ``decoder(memory, seq)``. Its output's last
            time step (``logits[:, -1]``) scores the next token. It is switched
            to eval mode.
        vocab: object exposing ``word2idx`` (must contain ``"<start>"`` and
            ``"<end>"``) and ``idx2word``.
        beam_size: hypotheses kept after each step. This is also the number of
            continuations expanded per unfinished hypothesis.
        max_len: maximum number of decoding steps.
        length_penalty: exponent ``alpha`` in the ranking key
            ``score / len(seq) ** alpha``. 0 disables length normalisation.
        device: device to run on. When None, the device of the encoder's first
            parameter is used (CPU if the encoder has no parameters).

    Returns:
        The best caption as a space-joined string. ``<start>``/``<pad>`` are
        dropped and the output is truncated at the first ``<end>``.
    """
    encoder.eval()
    decoder.eval()

    # Infer the device from the model instead of relying on a notebook global.
    if device is None:
        first_param = next(encoder.parameters(), None)
        device = first_param.device if first_param is not None else torch.device("cpu")

    bos = vocab.word2idx["<start>"]
    eos = vocab.word2idx["<end>"]

    memory = encoder(image.unsqueeze(0).to(device))  # [1, N, D]

    def _is_finished(seq):
        # A hypothesis is complete once its last token is <end>.
        return seq[0, -1].item() == eos

    def _rank_key(candidate):
        # Length-normalised log-probability; seq has shape [1, L].
        seq, score = candidate
        return score / (seq.size(1) ** length_penalty)

    beams = [(torch.tensor([[bos]], device=device), 0.0)]

    for _ in range(max_len):
        # When every beam has emitted <end>, further steps would only re-append
        # and re-sort the same beams. Stopping here gives the identical result.
        if all(_is_finished(seq) for seq, _score in beams):
            break

        candidates = []
        for seq, score in beams:
            if _is_finished(seq):
                # Carry finished hypotheses forward unchanged so they keep competing.
                candidates.append((seq, score))
                continue

            logits = decoder(memory, seq)
            log_probs = F.log_softmax(logits[:, -1], dim=-1)
            topk_logp, topk_idx = log_probs.topk(beam_size)

            for k in range(beam_size):
                next_seq = torch.cat([seq, topk_idx[:, k].unsqueeze(1)], dim=1)
                candidates.append((next_seq, score + topk_logp[0, k].item()))

        beams = sorted(candidates, key=_rank_key, reverse=True)[:beam_size]

    best_seq = beams[0][0].squeeze(0).tolist()

    caption = []
    for idx in best_seq:
        word = vocab.idx2word[idx]
        if word in ("<start>", "<pad>"):
            continue
        if word == "<end>":
            break
        caption.append(word)

    return " ".join(caption)

In [None]:
from model import TransformerEncoderViT
from model import TransformerCaptionDecoder

# NOTE(review): weights_only=False performs a full pickle load, which can execute
# arbitrary code. Only load vocab-VT.pkl from a trusted source.
vocab = torch.load(f"{models}/vocab-VT.pkl", weights_only=False)

# Architecture hyper-parameters (e.g. embed_size) must match the checkpoints:
# load_state_dict is strict by default and raises on missing/unexpected keys
# or shape mismatches.
encoder = TransformerEncoderViT().to(device)
decoder = TransformerCaptionDecoder(embed_size=256, vocab_size=len(vocab)).to(device)

# These files are passed straight to load_state_dict, so they should hold only
# tensors. weights_only=True refuses arbitrary pickled objects.
encoder.load_state_dict(
    torch.load(f"{models}/encoder-Vit-Transformer.pth", map_location=device, weights_only=True)
)
decoder.load_state_dict(
    torch.load(f"{models}/decoder-VT.pth", map_location=device, weights_only=True)
)

# Switch to inference mode (affects layers such as dropout). The trailing ';'
# keeps Jupyter from printing the full module repr.
encoder.eval()
decoder.eval();

In [None]:
#  test the model on a few images
import os
import matplotlib.pyplot as plt
from PIL import Image
import csv

# Number of distinct images to caption; set to None to run over the whole caption file.
NUM_TEST_IMAGES = 5


def _first_caption_per_image(captions_path, limit=None):
    """Return ``(image_name, reference_caption)`` pairs, one per image, in file order.

    The header row and blank lines are skipped. Only the first caption listed
    for each image is kept. At most ``limit`` pairs are returned (all when None).
    """
    pairs = []
    seen = set()
    # newline="" is what the csv module expects for files passed to csv.reader.
    with open(captions_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip the header row
        for row in reader:
            if limit is not None and len(pairs) >= limit:
                break
            if not row:  # csv.reader yields [] for blank lines
                continue
            img_name, reference = row
            if img_name in seen:
                continue
            seen.add(img_name)
            pairs.append((img_name, reference))
    return pairs


for img_name, reference in _first_caption_per_image(CAPTIONS_PATH, NUM_TEST_IMAGES):
    img_path = f"{IMAGES_PATH}/{img_name}"
    # Context manager closes the underlying file; convert() loads the pixels first.
    with Image.open(img_path) as pil_img:
        rgb_image = pil_img.convert("RGB")

    generated = generate_caption_beam(transform(rgb_image), encoder, decoder, vocab)

    # Show the original image rather than the normalised model input.
    fig, ax = plt.subplots()
    ax.imshow(rgb_image)
    ax.set_title(generated)
    ax.axis("off")
    plt.show()
    plt.close(fig)

    print(f"Image: {img_name}\nCaption: {generated}\nReference: {reference}\n")
