In [None]:

import sys

# Make the project's `src/` directory importable from this notebook (presumably where
# the local `dataset` module lives — confirm). The path is relative to the kernel's
# working directory. The membership guard keeps the cell idempotent: re-running it
# does not append duplicate entries to `sys.path`.
if "../src" not in sys.path:
    sys.path.append("../src")

In [None]:

from statistics import mean

import evaluate
import torch
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from dataset import HMDataset

In [None]:
# Use the GPU when PyTorch can see one; otherwise fall back to the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda")
else:
    device = torch.device("cpu")

In [None]:
# The tokenizer is taken from the base BART model, not from the fine-tuned checkpoint.
# NOTE(review): assumes fine-tuning did not add or change tokens — confirm.
tokenizer = AutoTokenizer.from_pretrained("facebook/bart-base")

# Fine-tuned seq2seq weights from the training run, placed on the selected device.
model = AutoModelForSeq2SeqLM.from_pretrained("./logs/checkpoint-100").to(device)

In [None]:
# Test split, iterated in a deterministic (unshuffled) order.
test_dataset = HMDataset("./data", "test", tokenizer)
test_loader = DataLoader(
    test_dataset,
    batch_size=4,
    shuffle=False,
    num_workers=0,
    # Pinned host memory only helps host-to-GPU copies. On a CPU-only machine it is
    # useless and PyTorch warns about it, so enable it only when running on CUDA.
    pin_memory=device.type == "cuda",
)

In [None]:
metric = evaluate.load("bertscore")

# Beam-search decoding settings, gathered in one place so they are easy to find and tune.
NUM_BEAMS = 3
MAX_GEN_LENGTH = 128
LENGTH_PENALTY = 0.6

# Be explicit about inference mode so this cell stays correct even if the model
# was touched by another cell (e.g. put back into training mode).
model.eval()

with torch.no_grad():
    for batch in tqdm(test_loader):
        input_ids = batch["input_ids"].to(device)
        attention_mask = batch["attention_mask"].to(device)

        generated_ids = model.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            num_beams=NUM_BEAMS,
            max_length=MAX_GEN_LENGTH,
            early_stopping=True,
            length_penalty=LENGTH_PENALTY,
        )

        # batch_decode returns a list of strings; it takes no `return_tensors` option.
        decoded_items = tokenizer.batch_decode(
            generated_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        # NOTE(review): assumes `labels` is an integer tensor and that ignored positions
        # may be padded with -100 (the usual HF seq2seq convention — confirm against
        # HMDataset). The tokenizer cannot decode -100, so map those positions to the
        # pad id; skip_special_tokens then drops them. This is a no-op if there is no -100.
        label_ids = batch["labels"]
        label_ids = label_ids.masked_fill(label_ids == -100, tokenizer.pad_token_id)
        decoded_labels = tokenizer.batch_decode(
            label_ids,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )

        metric.add_batch(predictions=decoded_items, references=decoded_labels)

score = metric.compute(lang="en")

In [None]:
# Average each list of scores returned by `metric.compute` and round to 4 decimals.
precision, recall, f1 = (
    round(mean(score[key]), 4) for key in ("precision", "recall", "f1")
)

print("----------")
print(f"  BERTScore: precision: {precision} recall: {recall} f1: {f1}")
print("----------\n\n")