In [None]:
!pip install pycocoevalcap

In [None]:
import torch
import torchvision.transforms as transforms
import torch.nn as nn
import nltk 
import tqdm
import json
import csv
import numpy as np
import torch.nn.functional as F

from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from PIL import Image
from torch.utils.data import DataLoader
from vocabulary_class import Vocabulary
from rouge_score import rouge_scorer
from jiwer import wer
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction
from collections import Counter
from vocabulary_class import Vocabulary
from flickr_dataset import FlickrDataset
from model import TransformerEncoderViT, TransformerCaptionDecoder
from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
from pycocoevalcap.cider.cider import Cider

In [None]:
# Pick the compute device once; later cells read this module-level name.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
# Base directories, relative to the notebook's working directory.
# Forward slashes are used throughout because Python's file APIs accept them
# on Windows as well as on Linux/macOS, whereas backslash separators only
# resolve on Windows.
datasets = "../../datasets"
models = "../../models"

# NOTE(review): torch.load unpickles the file, so only load checkpoints that
# come from a trusted source. `caption_model` is not referenced by any other
# cell visible in this notebook -- confirm whether this load is still needed.
caption_model = torch.load(f"{models}/caption_features_flickr8k.pt")

IMAGES_PATH = f"{datasets}/flickr8k/images"  # Directory with training images
CAPTIONS_PATH = f"{datasets}/flickr8k/captions.txt"  # Caption file
TEST_IMAGES_PATH = "../test_images"  # Directory with test images

In [None]:
# Module-level scratch containers. build_vocab() binds its own local `counter`
# and `tokens`, so it does not fill these two objects.
tokens, counter = [], Counter()

def build_vocab(json_path, threshold=5, limit=None):
    """Build a vocabulary and an image -> captions mapping from a captions file.

    The file is parsed as CSV: one header row, then rows whose first field is
    the image name and whose remaining field(s) form the caption.

    Args:
        json_path: Path of the captions file to read. The parameter name is
            kept for backward compatibility; the content is read as CSV.
        threshold: Minimum number of occurrences a token needs in order to be
            added to the vocabulary.
        limit: Optional maximum number of caption rows to process. A falsy
            value (None or 0) means "read everything".

    Returns:
        A ``(vocab, image_captions)`` tuple, where ``vocab`` is a
        ``Vocabulary`` holding every token seen at least ``threshold`` times
        and ``image_captions`` maps each image name to the list of its raw
        (not lower-cased) captions in file order.
    """
    counter = Counter()
    image_captions = {}
    count = 0

    # Read the file the caller asked for (not the CAPTIONS_PATH global).
    # newline="" is what the csv module recommends for files handed to a reader.
    with open(json_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header: image,caption (tolerates an empty file)

        for row in reader:
            if len(row) < 2:
                continue

            # A caption containing unquoted commas is split into several
            # fields by the CSV reader; stitch them back together instead of
            # failing on tuple unpacking.
            img_name = row[0]
            caption = ",".join(row[1:])
            image_captions.setdefault(img_name, []).append(caption)

            # Token statistics are case-insensitive.
            counter.update(nltk.tokenize.word_tokenize(caption.lower()))

            count += 1
            if limit and count >= limit:
                break

    vocab = Vocabulary()
    for word, cnt in counter.items():
        if cnt >= threshold:
            vocab.add_word(word)

    return vocab, image_captions

In [None]:
# Build the vocabulary from the captions file; tokens seen fewer than
# `threshold` times are left out of it.
vocab, image_captions = build_vocab(json_path=CAPTIONS_PATH, threshold=5)
print("Total vocabulary size:", len(vocab))

In [None]:
# Image preprocessing applied to every sample: resize to 224x224, convert to a
# tensor, then normalise each channel with the given mean / std.
# (These mean/std values match the commonly used ImageNet statistics.)
transform = transforms.Compose(
    [
        transforms.Resize(size=(224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(
            mean=(0.485, 0.456, 0.406),
            std=(0.229, 0.224, 0.225),
        ),
    ]
)

In [None]:
def collate_fn(batch):
    """Collate ``(image, caption, image_id)`` samples into one batch.

    Args:
        batch: Sequence of 3-tuples. The images must be tensors of identical
            shape (they are stacked); the captions are 1-D tensors that may
            differ in length.

    Returns:
        ``(images, captions, image_ids)`` where ``images`` is the stacked
        image tensor, ``captions`` is a batch-first tensor right-padded with 0
        up to the longest caption, and ``image_ids`` is a tuple of the ids in
        batch order.
    """
    image_tensors, token_seqs, image_ids = zip(*batch)

    stacked_images = torch.stack(image_tensors)

    # NOTE(review): padding uses the literal 0 -- presumably the index of
    # "<pad>" in the vocabulary; confirm against Vocabulary.
    padded_captions = pad_sequence(token_seqs, batch_first=True, padding_value=0)

    return stacked_images, padded_captions, image_ids

In [None]:
# Training data.
# NOTE(review): no `split` argument is passed here, while the evaluation cell
# passes split="val" -- confirm which split FlickrDataset uses by default.
train_dataset = FlickrDataset(
    root=IMAGES_PATH, captions_path=CAPTIONS_PATH,
    vocab=vocab, transform=transform,
    max_samples=None,
)

# Shuffled mini-batches of 16, loaded in the main process (num_workers=0).
train_loader = DataLoader(
    train_dataset, batch_size=16, shuffle=True,
    num_workers=0, collate_fn=collate_fn,
)

# Sanity check: dataset size, then the first sample's image and caption.
print(len(train_dataset))
image, caption, _ = train_dataset[0]

print(type(image))
print(image.shape)        # after transform
print(caption)
print(len(caption))

In [None]:
# Encoder / decoder pair, moved to the selected device.
encoder = TransformerEncoderViT().to(device)
decoder = TransformerCaptionDecoder(embed_size=256, vocab_size=len(vocab)).to(device)

# Target positions holding the "<pad>" index do not contribute to the loss.
criterion = nn.CrossEntropyLoss(ignore_index=vocab.word2idx["<pad>"])

# Optimise only the parameters that require gradients (encoder first, then
# decoder), so anything frozen inside the models is skipped.
params = [p for p in encoder.parameters() if p.requires_grad]
params += [p for p in decoder.parameters() if p.requires_grad]

optimizer = torch.optim.AdamW(params, lr=3e-4, weight_decay=1e-2)

In [None]:
for epoch in range(10):
    # generate_caption_beam() switches both modules to eval mode. Put them
    # back in training mode at the start of every epoch so that re-running
    # this cell after an evaluation cell does not train with eval-mode layers.
    # NOTE(review): if TransformerEncoderViT deliberately keeps a frozen
    # sub-module in eval mode, confirm that .train() does not undo that.
    encoder.train()
    decoder.train()

    total_train_loss = 0
    for images, captions, image_ids in tqdm.tqdm(train_loader):
        images = images.to(device)
        captions = captions.to(device)

        # The decoder is fed every token except the last and is trained to
        # predict the sequence shifted one position to the left.
        captions_in = captions[:, :-1]
        targets     = captions[:, 1:]

        optimizer.zero_grad(set_to_none=True)

        memory = encoder(images)
        logits = decoder(memory, captions_in)

        # Flatten the leading dimensions so the loss sees one row of logits
        # per target token.
        loss = criterion(
            logits.reshape(-1, logits.size(-1)),
            targets.reshape(-1)
        )

        loss.backward()
        # Clip the global gradient norm to 1.0 before the optimiser step.
        torch.nn.utils.clip_grad_norm_(params, 1.0)
        optimizer.step()

        total_train_loss += loss.item()

    # Mean per-batch training loss for this epoch.
    print(f"Epoch {epoch}: {total_train_loss / len(train_loader):.4f}")


In [None]:
# Persist the trained weights (state_dicts only) under the `models` directory.
torch.save(encoder.state_dict(), f"{models}/encoder-Vit-Transformer.pth")
torch.save(decoder.state_dict(), f"{models}/decoder-VT.pth")
# The vocabulary is saved as a whole Python object (torch.save pickles it), so
# loading it later requires the Vocabulary class to be importable.
torch.save(vocab, f"{models}/vocab-VT.pkl")

In [None]:
# Evaluation data: same images directory, captions file, vocabulary and
# preprocessing as training, restricted to the "val" split and capped at
# 1000 samples.
eval_dataset = FlickrDataset(
    root=IMAGES_PATH, captions_path=CAPTIONS_PATH,
    vocab=vocab, transform=transform,
    split="val",   # or test list
    max_samples=1000,
)

# Fixed order (no shuffling) so the collected hypotheses are reproducible.
eval_loader = DataLoader(
    eval_dataset, batch_size=16, shuffle=False,
    num_workers=0, collate_fn=collate_fn,
)

In [None]:
@torch.no_grad()
def generate_caption_beam(
    image,
    encoder,
    decoder,
    vocab,
    beam_size=5,
    max_len=30,
    length_penalty=0.7
):
    """Generate a caption for one image with beam search.

    Args:
        image: A single image tensor without a batch dimension; a batch
            dimension of size 1 is added before it is passed to the encoder.
        encoder: Callable module mapping the batched image to the memory that
            the decoder attends to.
        decoder: Callable module ``decoder(memory, seq)`` returning logits
            whose last position along dim 1 scores the next token.
        vocab: Object exposing ``word2idx`` (with "<start>" and "<end>") and
            ``idx2word``.
        beam_size: Number of partial captions kept after every step.
        max_len: Maximum number of decoding steps.
        length_penalty: Exponent applied to the sequence length when ranking
            beams (score / length ** length_penalty).

    Returns:
        The best caption as a space-separated string, without "<start>" and
        "<pad>" tokens and truncated at the first "<end>".

    Side effects:
        Puts ``encoder`` and ``decoder`` into eval mode and leaves them there.
    """
    encoder.eval()
    decoder.eval()

    bos = vocab.word2idx["<start>"]
    eos = vocab.word2idx["<end>"]

    image = image.unsqueeze(0).to(device)
    memory = encoder(image)  # presumably [1, N, D] -- TODO confirm against the encoder

    def _ranking_key(candidate):
        # Length-normalised cumulative log-probability; the length counts
        # every token in the sequence, including "<start>".
        seq, score = candidate
        return score / (seq.size(1) ** length_penalty)

    # Each beam is (token tensor of shape [1, T], cumulative log-probability).
    beams = [(torch.tensor([[bos]], device=device), 0.0)]

    for _ in range(max_len):
        # Once every beam has emitted "<end>" further steps cannot change the
        # ranking, so stop early instead of re-sorting the same list.
        if all(seq[0, -1].item() == eos for seq, _score in beams):
            break

        all_candidates = []

        for seq, score in beams:
            # Finished beams are carried over unchanged.
            if seq[0, -1].item() == eos:
                all_candidates.append((seq, score))
                continue

            logits = decoder(memory, seq)
            log_probs = F.log_softmax(logits[:, -1], dim=-1)

            # topk raises if k exceeds the vocabulary size, so clamp the
            # number of expansions for very small vocabularies.
            num_expansions = min(beam_size, log_probs.size(-1))
            topk_logp, topk_idx = log_probs.topk(num_expansions)

            for k in range(num_expansions):
                next_seq = torch.cat(
                    [seq, topk_idx[:, k].unsqueeze(1)], dim=1
                )

                new_score = score + topk_logp[0, k].item()

                all_candidates.append((next_seq, new_score))

        beams = sorted(all_candidates, key=_ranking_key, reverse=True)[:beam_size]

    best_seq = beams[0][0].squeeze(0).tolist()

    caption = []
    for idx in best_seq:
        word = vocab.idx2word[idx]
        if word in ("<start>", "<pad>"):
            continue
        if word == "<end>":
            break
        caption.append(word)

    return " ".join(caption)

In [None]:
# Parallel lists: references[k] holds the tokenised reference captions of the
# k-th evaluated image, hypotheses[k] the tokens of its generated caption.
references = []   # list[list[list[str]]]
hypotheses = []   # list[list[str]]

for images, _, image_ids in eval_loader:
    for i in range(images.size(0)):
        img_name = image_ids[i]

        # All reference captions of this image, lower-cased and tokenised.
        refs = [
            nltk.word_tokenize(cap.lower())
            for cap in eval_dataset.image_captions[img_name]
        ]
        references.append(refs)

        # Beam-search caption for the same image, split on whitespace.
        image = images[i].to(device)
        generated = generate_caption_beam(image, encoder, decoder, vocab)
        hyp = generated.strip().split()
        hypotheses.append(hyp)

In [None]:
# Eyeball check: tokens of the first generated caption next to the first
# reference caption of the same image.
print("HYP:", hypotheses[0])
print("REF:", references[0][0])

In [None]:
# Mean number of tokens per generated caption.
print("Avg hyp length:", np.mean(list(map(len, hypotheses))))

In [None]:
# Corpus-level BLEU without smoothing: unigrams only (BLEU-1) and unigrams +
# bigrams weighted equally (BLEU-2).
bleu1 = corpus_bleu(
    list_of_references=references,
    hypotheses=hypotheses,
    weights=(1, 0, 0, 0),
)
bleu2 = corpus_bleu(
    list_of_references=references,
    hypotheses=hypotheses,
    weights=(0.5, 0.5, 0, 0),
)
print("BLEU-1:", bleu1)
print("BLEU-2:", bleu2)

In [None]:
# BLEU-4 with equal weights on 1- to 4-grams, using NLTK's smoothing method 4.
smooth = SmoothingFunction().method4

bleu4 = corpus_bleu(
    list_of_references=references,
    hypotheses=hypotheses,
    weights=(0.25,) * 4,
    smoothing_function=smooth,
)

print(f"Smoothed BLEU-4: {bleu4:.4f}")

In [None]:
# ROUGE-L F-measure: for each hypothesis keep the best score over its
# references, then average those best scores over the evaluation set.
scorer = rouge_scorer.RougeScorer(["rougeL"], use_stemmer=True)
rouge_scores = []

for hyp, refs in zip(hypotheses, references):
    hyp_str = " ".join(hyp)
    # score(target, prediction): the reference is the target.
    best = max(
        map(
            lambda r: scorer.score(" ".join(r), hyp_str)["rougeL"].fmeasure,
            refs,
        )
    )
    rouge_scores.append(best)

print(f"ROUGE-L: {sum(rouge_scores)/len(rouge_scores):.4f}")

In [None]:
def best_wer(hyp, refs):
    """Return the lowest word error rate of ``hyp`` against any reference.

    Args:
        hyp: Hypothesis as a list of tokens.
        refs: Iterable of references, each a list of tokens.

    Returns:
        The minimum of ``wer(reference_text, hypothesis_text)`` over ``refs``,
        where both texts are the tokens joined with single spaces.
    """
    hypothesis_text = " ".join(hyp)
    error_rates = (
        wer(" ".join(ref_tokens), hypothesis_text) for ref_tokens in refs
    )
    return min(error_rates)

# Average, over all evaluated images, of the best (lowest) WER against any of
# the image's references.
wer_score = sum(map(best_wer, hypotheses, references)) / len(hypotheses)

print(f"WER: {wer_score:.4f}")

In [None]:
# Cider.compute_score is given two dicts keyed by the same sample index:
# index -> list of reference strings, and index -> one-element list holding
# the hypothesis string.
refs_dict = {
    idx: list(map(" ".join, ref_group))
    for idx, ref_group in enumerate(references)
}
hyps_dict = {
    idx: [" ".join(hyp_tokens)]
    for idx, hyp_tokens in enumerate(hypotheses)
}

cider, _ = Cider().compute_score(refs_dict, hyps_dict)
print(f"CIDEr: {cider:.4f}")
