## Import libraries and data

In [1]:
# import libraries
import os
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import torchvision.models as models
from torch.utils.data import Dataset, DataLoader
from PIL import Image
import json
from torch.nn.utils.rnn import pad_sequence
import numpy as np
from tqdm import tqdm
from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction
from rouge_score import rouge_scorer
from bert_score import score as bert_score
from transformers import AutoModel, AutoTokenizer

  from .autonotebook import tqdm as notebook_tqdm


In [2]:
# load annotation
vocab_dir = "../../data/vocab"
with open(os.path.join(vocab_dir, "vocab.json"), "r", encoding="utf-8") as f:
    vocab = json.load(f)
    vocab = {k: int(v) for k, v in vocab.items()}
with open(os.path.join(vocab_dir, "idx_to_word.json"), "r", encoding="utf-8") as f:
    idx_to_word = json.load(f)
    idx_to_word = {int(k): v for k, v in idx_to_word.items()}
    inv_vocab = idx_to_word

## Define functions/classes

In [3]:
# Dataset class
class ImageCaptionDataset(Dataset):
    def __init__(self, image_dir, captions_file, vocab, transform=None):
        self.image_dir = image_dir
        self.vocab = vocab
        self.transform = transform

        self.data = []
        with open(captions_file, 'r', encoding='utf-8') as f:
            for line in f:
                img_name, caption = line.strip().split('\t')
                tokens = caption.strip().split()
                self.data.append((img_name, tokens))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        img_name, tokens = self.data[idx]
        image_path = os.path.join(self.image_dir, img_name)
        image = Image.open(image_path).convert("RGB")

        if self.transform:
            image = self.transform(image)

        caption = [self.vocab['<s>']] + [self.vocab.get(token, self.vocab['<unk>']) for token in tokens] + [self.vocab['</s>']]
        return image, torch.tensor(caption), img_name, tokens

def collate_fn(batch):
    images, captions, img_names, raw_tokens = zip(*batch)
    images = torch.stack(images)
    captions = pad_sequence(captions, batch_first=True, padding_value=vocab['<pad>'])
    return images, captions, img_names, raw_tokens

In [4]:
# Encoder CNN
class EncoderCNN(nn.Module):
    def __init__(self, encoded_size=256):
        super().__init__()
        resnet = models.resnet50(weights=models.ResNet50_Weights.IMAGENET1K_V1)
        for param in resnet.parameters():
            param.requires_grad = False
        modules = list(resnet.children())[:-2]
        self.resnet = nn.Sequential(*modules)
        self.fc = nn.Linear(2048, encoded_size)
        self.bn = nn.BatchNorm1d(encoded_size)
        self.adaptive_pool = nn.AdaptiveAvgPool2d((1, 1))

    def forward(self, images):
        features = self.resnet(images)
        features = self.adaptive_pool(features)
        features = features.view(features.size(0), -1)
        features = self.fc(features)
        features = self.bn(features)
        return features

# Attention mechanism
class Attention(nn.Module):
    def __init__(self, encoder_dim, decoder_dim, attention_dim):
        super().__init__()
        self.encoder_att = nn.Linear(encoder_dim, attention_dim)
        self.decoder_att = nn.Linear(decoder_dim, attention_dim)
        self.full_att = nn.Linear(attention_dim, 1)

    def forward(self, encoder_out, decoder_hidden):
        att1 = self.encoder_att(encoder_out)
        att2 = self.decoder_att(decoder_hidden).unsqueeze(1)
        att = self.full_att(torch.tanh(att1 + att2)).squeeze(2)
        alpha = torch.softmax(att, dim=1)
        context = (encoder_out * alpha.unsqueeze(2)).sum(dim=1)
        return context, alpha

# Decoder RNN
class DecoderRNN(nn.Module):
    def __init__(self, embed_size, hidden_size, vocab_size, encoder_dim=256, attention_dim=256):
        super().__init__()
        self.attention = Attention(encoder_dim, hidden_size, attention_dim)
        self.embedding = nn.Embedding(vocab_size, embed_size)
        self.lstm = nn.LSTMCell(embed_size + encoder_dim, hidden_size)
        self.fc = nn.Linear(hidden_size, vocab_size)
        self.dropout = nn.Dropout(0.5)
        self.hidden_size = hidden_size

    def forward(self, encoder_out, captions):
        batch_size = encoder_out.size(0)
        vocab_size = self.fc.out_features
        embeddings = self.embedding(captions)

        h, c = torch.zeros(batch_size, self.hidden_size).to(encoder_out.device), torch.zeros(batch_size, self.hidden_size).to(encoder_out.device)
        outputs = torch.zeros(batch_size, captions.size(1), vocab_size).to(encoder_out.device)

        for t in range(captions.size(1)):
            context, _ = self.attention(encoder_out.unsqueeze(1), h)
            input_lstm = torch.cat([embeddings[:, t], context], dim=1)
            h, c = self.lstm(input_lstm, (h, c))
            h = self.dropout(h)
            outputs[:, t] = self.fc(h)

        return outputs

# Combined Model
class ImageCaptionModel(nn.Module):
    def __init__(self, encoded_size=256, embed_size=256, hidden_size=512, vocab_size=len(vocab), attention_dim=256):
        super().__init__()
        self.encoder = EncoderCNN(encoded_size=encoded_size)
        self.decoder = DecoderRNN(embed_size=embed_size, hidden_size=hidden_size, vocab_size=vocab_size, 
                                 encoder_dim=encoded_size, attention_dim=attention_dim)

    def forward(self, images, captions):
        features = self.encoder(images)
        outputs = self.decoder(features, captions)
        return outputs

    @torch.no_grad()
    def generate_caption(self, image, vocab, idx_to_word, max_length=20, beam_width=5):
        self.eval()
        image = image.unsqueeze(0).to(image.device)
        features = self.encoder(image).unsqueeze(1)  # [1, 1, encoded_size]

        # Beam search
        sequences = [[[], 0.0, [features, None]]]  # [sequence, score, [features, (h, c)]]
        for _ in range(max_length):
            all_candidates = []
            for seq, score, state in sequences:
                if len(seq) > 0 and seq[-1] == vocab["</s>"]:
                    all_candidates.append([seq, score, state])
                    continue

                if len(seq) == 0:
                    token = torch.tensor([[vocab["<s>"]]], device=image.device)
                else:
                    token = torch.tensor([[seq[-1]]], device=image.device)

                embed = self.decoder.embedding(token)
                context, _ = self.decoder.attention(state[0], state[1][0] if state[1] else torch.zeros(1, self.decoder.hidden_size).to(image.device))
                input_lstm = torch.cat([embed.squeeze(1), context], dim=1)
                if state[1] is None:
                    h, c = self.decoder.lstm(input_lstm)
                else:
                    h, c = self.decoder.lstm(input_lstm, state[1])

                output = self.decoder.fc(h)
                output = torch.softmax(output, dim=-1)

                top_probs, top_indices = output.topk(beam_width)
                for i in range(beam_width):
                    next_seq = seq + [top_indices[0, i].item()]
                    next_score = score + torch.log(top_probs[0, i]).item()
                    all_candidates.append([next_seq, next_score, [features, (h, c)]])

            sequences = sorted(all_candidates, key=lambda x: x[1], reverse=True)[:beam_width]

        captions = []
        for seq, score, _ in sequences:
            caption = [idx_to_word.get(idx, "<unk>") for idx in seq if idx not in [vocab["<s>"], vocab["</s>"]]]
            captions.append(" ".join(caption))

        return captions

## Evaluation

In [5]:
# Data transforms
transform = transforms.Compose([
    transforms.Resize((224, 224)),
    transforms.ToTensor(),
    transforms.Normalize(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225])
])

# Load test dataset
test_dataset = ImageCaptionDataset('../../data/test/images', 
                                   '../../data/test/captions.txt', 
                                   vocab, transform)
test_loader = DataLoader(test_dataset, batch_size=1, shuffle=False, collate_fn=collate_fn)

# Model setup
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
model = ImageCaptionModel(encoded_size=256, embed_size=256, hidden_size=512, vocab_size=len(vocab)).to(device)

# Load trained model
model.load_state_dict(torch.load('../../models/best_base_resnet50_lstm.pth', weights_only=True, map_location=device))
model.eval()

# Initialize PhoBERT for BERTScore
phobert = AutoModel.from_pretrained("vinai/phobert-base")
tokenizer = AutoTokenizer.from_pretrained("vinai/phobert-base")

# Register PhoBERT model with bert_score
from bert_score.utils import model2layers
if "vinai/phobert-base" not in model2layers:
    model2layers["vinai/phobert-base"] = 12  # PhoBERT-base has 12 layers

# Initialize ROUGE scorer
from rouge_score import rouge_scorer
scorer = rouge_scorer.RougeScorer(['rougeL'], use_stemmer=False)

# Evaluation function
def evaluate_model(model, test_loader, vocab, idx_to_word):
    bleu4_scores = []
    rouge_l_scores = []
    bert_p_scores = []
    bert_r_scores = []
    bert_f1_scores = []

    smoothing = SmoothingFunction().method1  # For BLEU smoothing

    for images, _, img_names, raw_tokens in tqdm(test_loader, desc="Evaluating"):
        image = images[0].to(device)
        ground_truth = [' '.join(tokens) for tokens in raw_tokens]  # List of reference captions

        # Generate captions (take the top-1 caption from beam search)
        generated_captions = model.generate_caption(image, vocab, idx_to_word, max_length=20, beam_width=5)
        generated_caption = generated_captions[0]  # Use the highest-scoring caption

        # BLEU-4
        reference_tokens = [tokens for tokens in raw_tokens]
        generated_tokens = generated_caption.split()
        bleu4 = sentence_bleu(reference_tokens, generated_tokens, weights=(0.25, 0.25, 0.25, 0.25), smoothing_function=smoothing)
        bleu4_scores.append(bleu4)

        # ROUGE-L
        rouge_scores = scorer.score(ground_truth[0], generated_caption)
        rouge_l_scores.append(rouge_scores['rougeL'].fmeasure)

        # BERTScore with PhoBERT
        P, R, F1 = bert_score(
            [generated_caption],
            ground_truth,
            model_type="vinai/phobert-base",
            lang="vi",
            device=device,
            use_fast_tokenizer=True
        )
        bert_p_scores.append(P.item())
        bert_r_scores.append(R.item())
        bert_f1_scores.append(F1.item())

    # Compute average scores
    avg_bleu4 = np.mean(bleu4_scores)
    avg_rouge_l = np.mean(rouge_l_scores)
    avg_bert_p = np.mean(bert_p_scores)
    avg_bert_r = np.mean(bert_r_scores)
    avg_bert_f1 = np.mean(bert_f1_scores)

    print(f"\nEvaluation Results:")
    print(f"Average BLEU-4: {avg_bleu4:.4f}")
    print(f"Average ROUGE-L: {avg_rouge_l:.4f}")
    print(f"Average BERTScore Precision: {avg_bert_p:.4f}")
    print(f"Average BERTScore Recall: {avg_bert_r:.4f}")
    print(f"Average BERTScore F1: {avg_bert_f1:.4f}")

    return {
        "BLEU-4": avg_bleu4,
        "ROUGE-L": avg_rouge_l,
        "BERTScore_P": avg_bert_p,
        "BERTScore_R": avg_bert_r,
        "BERTScore_F1": avg_bert_f1
    }

  return self.fget.__get__(instance, owner)()
Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.


In [6]:
# Run evaluation
print(f"Evaluating on device: {device}")
results = evaluate_model(model, test_loader, vocab, inv_vocab)

Evaluating on device: cpu


Evaluating:   0%|          | 0/1395 [00:00<?, ?it/s]Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Evaluating:   0%|          | 1/1395 [00:06<2:21:13,  6.08s/it]Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Evaluating:   0%|          | 2/1395 [00:10<2:01:49,  5.25s/it]Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Evaluating:   0%|          | 3/1395 [00:15<1:59:20,  5.14s/it]Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Evaluating:   0%|          | 4/1395 [00:19<1:48:08,  4.66s/it]Special tokens have been added in the vocabulary, make sure the associated word embeddings are fine-tuned or trained.
Evaluating:   0%|          | 5/1395 [00:24<1:45:52,  4.57s/it]Special tokens have been added in the vocabulary


Evaluation Results:
Average BLEU-4: 0.0373
Average ROUGE-L: 0.3536
Average BERTScore Precision: 0.4772
Average BERTScore Recall: 0.4680
Average BERTScore F1: 0.4722



