In [None]:
!pip install nltk

In [None]:
import torch
from torch.nn.utils.rnn import pad_sequence
from torch.utils.data import Dataset
from PIL import Image
import torchvision.transforms as transforms
from torch.utils.data import DataLoader
from flickr_dataset import FlickrDataset 
from model import EncoderCNN, DecoderRNN
import torch.nn as nn
from vocabulary_class import Vocabulary
from rouge_score import rouge_scorer
from jiwer import wer
from nltk.translate.bleu_score import corpus_bleu, SmoothingFunction

In [None]:
# Run on the GPU when CUDA is usable, otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = "cuda"
else:
    device = "cpu"
print(device)

In [None]:
# Project directories, relative to the notebook's working directory.
# Forward slashes are accepted by Python's file APIs on Windows as well as on
# Linux/macOS; the previous backslash literals only resolved on Windows.
datasets = "../../datasets"
models = "../../models"

# NOTE(review): `caption_model` is not referenced by any later cell visible in
# this notebook; the load is kept in case something outside this view needs it.
caption_model = torch.load(f"{models}/caption_features_flickr8k.pt")

IMAGES_PATH = f"{datasets}/flickr8k/images"  # Directory with training images
CAPTIONS_PATH = f"{datasets}/flickr8k/captions.txt"  # Caption file
TEST_IMAGES_PATH = "../test_images"  # Directory with test images

In [None]:
# Standard library
import csv
import json
from collections import Counter

# Third-party
import nltk
import tqdm

# Project-local
from vocabulary_class import Vocabulary

# Fetch the 'punkt_tab' resource used by nltk's word tokenizer.
nltk.download('punkt_tab')

# NOTE(review): build_vocab creates its own local `counter` and `tokens`, so
# these module-level objects are not read by it; kept unchanged.
counter = Counter()
tokens = []

def build_vocab(json_path, threshold=5, limit=None):
    """Build a word vocabulary and an image -> captions map from a captions CSV.

    Args:
        json_path: Path of the captions file to read. Despite the name (kept
            for backward compatibility) the file is parsed as CSV with an
            ``image,caption`` header row.
        threshold: Minimum number of occurrences a token needs in order to be
            added to the vocabulary.
        limit: Optional maximum number of caption rows to consume; ``None``
            (or 0) reads the whole file.

    Returns:
        ``(vocab, image_captions)`` where ``vocab`` is a ``Vocabulary`` that
        received every sufficiently frequent lower-cased token via
        ``add_word``, and ``image_captions`` maps each image name to the list
        of its captions exactly as they appear in the file.
    """
    counter = Counter()
    image_captions = {}
    count = 0

    # Read the file the caller asked for (previously the argument was ignored
    # and the global CAPTIONS_PATH was always opened). newline="" is what the
    # csv module expects so that quoted fields with line breaks parse correctly.
    with open(json_path, "r", encoding="utf-8", newline="") as f:
        reader = csv.reader(f)
        next(reader, None)  # skip header: image,caption (tolerates an empty file)

        for row in reader:
            if len(row) < 2:
                continue
            # An unquoted comma inside a caption produces extra columns; glue
            # them back together instead of failing on the tuple unpack.
            img_name, caption = row[0], ",".join(row[1:])
            image_captions.setdefault(img_name, []).append(caption)

            # Token frequencies are counted case-insensitively.
            counter.update(nltk.tokenize.word_tokenize(caption.lower()))
            count += 1
            if limit and count >= limit:
                break

    vocab = Vocabulary()
    for word, cnt in counter.items():
        if cnt >= threshold:
            vocab.add_word(word)

    return vocab, image_captions

In [None]:
# Build the vocabulary (tokens that occur at least 5 times) together with the
# image-name -> captions map from the captions file.
vocab, image_captions = build_vocab(CAPTIONS_PATH, threshold=5)
print("Total vocabulary size:", len(vocab))

In [None]:
# Number of distinct image names collected by build_vocab (keys of the map).
len(image_captions)

In [None]:
# Per-channel mean / std used to normalise the RGB tensor (the standard
# ImageNet statistics).
_NORMALIZE_MEAN = (0.485, 0.456, 0.406)
_NORMALIZE_STD = (0.229, 0.224, 0.225)

# Preprocessing applied to every image: fixed 224x224 resize, conversion to a
# tensor, then channel-wise normalisation.
transform = transforms.Compose(
    [
        transforms.Resize((224, 224)),
        transforms.ToTensor(),
        transforms.Normalize(_NORMALIZE_MEAN, _NORMALIZE_STD),
    ]
)

In [None]:
def collate_fn(batch):
    """Collate ``(image, caption)`` samples into batch tensors.

    Args:
        batch: Sequence of ``(image, caption)`` pairs; images must be
            stackable tensors and captions 1-D tensors of token ids.

    Returns:
        ``(images, captions, lengths)``: the images stacked along a new first
        dimension, the captions right-padded with 0 to the longest caption in
        the batch (batch first), and a tensor with each caption's original
        length.
    """
    images, captions = zip(*batch)

    stacked_images = torch.stack(images, dim=0)
    # Record the true lengths before padding hides them.
    caption_lengths = torch.tensor(list(map(len, captions)))
    padded_captions = pad_sequence(captions, batch_first=True, padding_value=0)

    return stacked_images, padded_captions, caption_lengths

In [None]:
# Training dataset: images from IMAGES_PATH paired with captions from
# CAPTIONS_PATH, encoded with `vocab` and preprocessed with `transform`.
train_dataset = FlickrDataset(
    root=IMAGES_PATH,
    captions_path=CAPTIONS_PATH,
    vocab=vocab,
    transform=transform,
    max_samples=None
)

# Shuffle the samples every epoch. With shuffle=False each epoch replayed the
# exact same batches in file order, which hurts stochastic-gradient training.
# The evaluation cell that reuses this loader collects references and
# hypotheses side by side, so it does not depend on the iteration order.
train_loader = DataLoader(
    dataset=train_dataset,
    batch_size=32,
    shuffle=True,
    num_workers=0,
    collate_fn=collate_fn
)

In [None]:
# Number of samples exposed by the training dataset.
print(len(train_dataset))

In [None]:
# Sanity-check the first training sample: the dataset yields an
# (image, caption) pair.
image, caption = train_dataset[0]

print(type(image))
print(image.shape)        # shape after `transform` has been applied
print(caption)            # encoded caption as returned by the dataset
print(len(caption))       # caption length

In [None]:
# Image encoder and caption decoder, both moved to the selected device.
encoder = EncoderCNN(embed_size=256)
encoder = encoder.to(device)
decoder = DecoderRNN(embed_size=256, hidden_size=512, vocab_size=len(vocab))
decoder = decoder.to(device)

# Cross-entropy that skips <pad> targets and applies label smoothing.
# NOTE(review): collate_fn pads with the literal 0 -- confirm that
# vocab.word2idx["<pad>"] is 0 so padded positions really are ignored.
criterion = nn.CrossEntropyLoss(
    label_smoothing=0.1,
    ignore_index=vocab.word2idx["<pad>"],
)

params = [*encoder.parameters(), *decoder.parameters()]

# Two parameter groups so the encoder trains with a 10x smaller learning rate
# than the decoder.
# NOTE(review): the original comments called the encoder a "ViT" while the
# checkpoint is saved as "encoder-resnet50" -- confirm the actual backbone.
optimizer = torch.optim.Adam(
    [
        dict(params=encoder.parameters(), lr=1e-4),
        dict(params=decoder.parameters(), lr=1e-3),
    ]
)

In [None]:
# Train for 20 epochs; the encoder is frozen for the first 3 epochs and
# fine-tuned together with the decoder afterwards.
# NOTE(review): train/eval mode is not set here; re-running this cell after a
# cell that called .eval() would train with the models still in eval mode.
train_losses = []
for epoch in range(20):
    # The freeze flag depends only on the epoch, so flip it once per epoch
    # instead of re-walking every encoder parameter on every batch.
    train_encoder = epoch >= 3
    for p in encoder.parameters():
        p.requires_grad = train_encoder

    total_train_loss = 0
    for images, captions, lengths in tqdm.tqdm(train_loader):
        images = images.to(device)
        captions = captions.to(device)

        optimizer.zero_grad()

        features = encoder(images)
        # The decoder is fed the ground-truth captions (teacher forcing).
        outputs = decoder(features, captions, lengths)
        # Output position t is scored against caption token t + 1: drop the
        # first target token and the last output step so the two line up.
        targets = captions[:, 1:]
        outputs = outputs[:, :-1, :]
        loss = criterion(outputs.reshape(-1, len(vocab)),
                         targets.reshape(-1))

        total_train_loss += loss.item()

        loss.backward()
        # Only the decoder's gradients are clipped.
        torch.nn.utils.clip_grad_norm_(decoder.parameters(), max_norm=5.0)
        optimizer.step()

    avg_train_loss = total_train_loss / len(train_loader)
    train_losses.append(avg_train_loss)
    print(f"Epoch {epoch}: Train={avg_train_loss:.4f}")

In [None]:
def generate_caption(image, encoder, decoder, vocab, max_len=20):
    """Greedily decode a caption for a single image.

    Args:
        image: Image tensor without a batch dimension (one is added here).
        encoder: Module mapping the batched image to a feature vector.
        decoder: Module exposing ``embed``, ``lstm`` and ``linear``.
        vocab: Object with ``word2idx`` and ``idx2word`` mappings that contain
            the ``<start>`` and ``<end>`` tokens.
        max_len: Maximum number of decoding steps taken after the first
            predicted token, i.e. at most ``max_len + 1`` tokens are produced.
            The default of 20 matches the previously hard-coded limit.

    Returns:
        The predicted tokens joined by single spaces. Decoding stops right
        after ``<end>`` is predicted, and that token is part of the string.

    Side effects:
        Puts ``encoder`` and ``decoder`` into eval mode.
    """
    encoder.eval()
    decoder.eval()

    start_token = vocab.word2idx["<start>"]
    end_token = vocab.word2idx["<end>"]
    sampled_ids = []

    # Inference only: do not record an autograd graph.
    with torch.no_grad():
        image = image.unsqueeze(0).to(device)
        feature = encoder(image).unsqueeze(1)  # add a sequence dimension

        # First step: the LSTM sees the image feature followed by
        # embed(<start>); the output at the last position predicts word one.
        start = torch.tensor([[start_token]], dtype=torch.long, device=image.device)
        lstm_input = torch.cat((feature, decoder.embed(start)), dim=1)
        states = None

        for _ in range(max_len + 1):
            hiddens, states = decoder.lstm(lstm_input, states)
            predicted = decoder.linear(hiddens[:, -1, :]).argmax(dim=1).item()
            sampled_ids.append(predicted)

            # Also checked on the very first prediction, so decoding no longer
            # continues past an immediate <end>.
            if predicted == end_token:
                break

            # Later steps feed only the previously predicted token; the image
            # information is carried by the LSTM state.
            next_token = torch.tensor([[predicted]], dtype=torch.long, device=image.device)
            lstm_input = decoder.embed(next_token)

    return " ".join(vocab.idx2word[token_id] for token_id in sampled_ids)

In [None]:
# Persist the trained weights (state dicts only) and the vocabulary object.
# The vocabulary is saved as a whole Python object, so it is pickled and must
# be reloaded with weights_only=False.
torch.save(encoder.state_dict(), f"{models}//encoder-resnet50.pth")
torch.save(decoder.state_dict(), f"{models}//decoder-rrnn-lstm.pth")
torch.save(vocab, f"{models}//vocab.pkl")

In [None]:

# Reload the vocabulary and rebuild both networks before restoring weights.
# weights_only=False unpickles an arbitrary Python object -- only load a
# vocab.pkl that this notebook produced itself.
vocab = torch.load(f"{models}/vocab.pkl", weights_only=False)

# Use the same keyword arguments as the training cell so the rebuilt
# architecture cannot drift from the trained one through positional ordering.
encoder = EncoderCNN(embed_size=256).to(device)
decoder = DecoderRNN(
    embed_size=256,
    hidden_size=512,
    vocab_size=len(vocab)
).to(device)

encoder.load_state_dict(torch.load(f"{models}/encoder-resnet50.pth", weights_only=True, map_location=device))
decoder.load_state_dict(torch.load(f"{models}/decoder-rrnn-lstm.pth", weights_only=True, map_location=device))

In [None]:
def tokens_to_string(tokens):
    """Join a list of tokens with single spaces; return any other value as is."""
    return " ".join(tokens) if isinstance(tokens, list) else tokens

def compute_rouge_l(references, hypotheses):
    """Mean best-reference ROUGE-L F-measure.

    Args:
        references: One list of reference captions per sample; each caption is
            a string or a list of tokens.
        hypotheses: One predicted caption per sample (string or token list).

    Returns:
        The average over samples of the highest ROUGE-L F-measure obtained
        against any of that sample's references.
    """
    rouge = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=True)

    best_per_sample = []
    for sample_refs, sample_hyp in zip(references, hypotheses):
        hyp_text = tokens_to_string(sample_hyp)
        # Keep only the score against the closest reference.
        best_per_sample.append(
            max(
                rouge.score(tokens_to_string(ref), hyp_text)['rougeL'].fmeasure
                for ref in sample_refs
            )
        )

    return sum(best_per_sample) / len(best_per_sample)

def compute_wer(references, hypotheses):
    """Mean word error rate, using the best-matching reference per hypothesis.

    Args:
        references: One list of reference captions per sample; each caption is
            a string or a list of tokens.
        hypotheses: One predicted caption per sample (string or token list).

    Returns:
        The average over samples of the lowest WER obtained against any of
        that sample's references.
    """
    def _as_text(caption):
        # jiwer reads a list as a list of *sentences*, so a token list must be
        # joined into one sentence for word-level alignment to happen.
        return " ".join(caption) if isinstance(caption, list) else caption

    wers = []
    for refs, hyp in zip(references, hypotheses):
        hyp_text = _as_text(hyp)
        wers.append(min(wer(_as_text(ref), hyp_text) for ref in refs))  # best match

    return sum(wers) / len(wers)

def compute_bleu4(references, hypotheses):
    """Corpus-level BLEU-4 with smoothing.

    Args:
        references: One list of reference captions per sample; each caption is
            a string (split on whitespace) or an already tokenized list.
        hypotheses: One predicted caption per sample (string or token list).

    Returns:
        The value returned by ``corpus_bleu`` for uniform 1- to 4-gram weights
        and ``SmoothingFunction().method4``.
    """
    def _tokenize(caption):
        # Token lists pass through untouched; strings are split on whitespace.
        return caption if isinstance(caption, list) else caption.split()

    tokenized_refs = [
        [_tokenize(ref) for ref in sample_refs] for sample_refs in references
    ]
    tokenized_hyps = [_tokenize(hyp) for hyp in hypotheses]

    return corpus_bleu(
        tokenized_refs,
        tokenized_hyps,
        weights=(0.25, 0.25, 0.25, 0.25),
        smoothing_function=SmoothingFunction().method4,
    )

In [None]:
# BLEU-4 / ROUGE-L / WER calculation over the training loader.
# NOTE(review): the decoder is fed the ground-truth captions here, so these
# are teacher-forced predictions on training data, not free-running captions
# from generate_caption -- confirm this is the intended evaluation.
references = []
hypotheses = []
encoder.eval()
decoder.eval()
with torch.no_grad():
    for images, captions, lengths in tqdm.tqdm(train_loader):
        images = images.to(device)
        captions = captions.to(device)

        features = encoder(images)
        outputs = decoder(features, captions, lengths)

        _, predicted = outputs.max(2)  # most likely token id at every step

        for i in range(captions.size(0)):
            # Positions 1 .. length-1 of the caption, i.e. without <start>.
            positions = range(1, lengths[i])

            # A single reference per sample: the caption that was fed in.
            references.append(
                [[vocab.idx2word[captions[i][j].item()] for j in positions]]
            )
            # Output step j-1 is the prediction for caption token j.
            hypotheses.append(
                [vocab.idx2word[predicted[i][j - 1].item()] for j in positions]
            )

bleu4 = compute_bleu4(references, hypotheses)
print(f"blue 4 scores:: {bleu4:.4f}")

rougue_score = compute_rouge_l(references, hypotheses)
print(f"Rouge Score: {rougue_score:.4f}")

wer_score = compute_wer(references, hypotheses)
print(f"WER Score: {wer_score:.4f}")

In [None]:
# Caption up to 20 images from the test-image directory.
import os
import matplotlib.pyplot as plt
from PIL import Image

# os.listdir returns entries in arbitrary order; sort so that the same files
# are picked (and shown in the same order) on every run.
for file_name in sorted(os.listdir(TEST_IMAGES_PATH))[:20]:
    img_path = os.path.join(TEST_IMAGES_PATH, file_name)

    # Close the file handle as soon as the pixels have been converted.
    with Image.open(img_path) as img_file:
        pil_image = img_file.convert("RGB")

    # Display the original picture. The tensor produced by `transform` has
    # been mean/std normalised, so passing it to imshow gets clipped and shows
    # distorted colours.
    plt.imshow(pil_image)
    plt.axis('off')
    plt.show()

    image = transform(pil_image)
    # Inference only: avoid building an autograd graph.
    with torch.no_grad():
        caption = generate_caption(image, encoder, decoder, vocab)
    print(f"Image: {file_name}\nCaption: {caption}\n")