In [131]:
%load_ext autoreload
%autoreload 2
import torch
import random
import numpy as np
import os

from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader
from transformers import LayoutLMv3Tokenizer

from data.token_classification import TokenClassificationDataset
from src.modeling.docpolarbert.modeling_docpolarbert import DocPolarBERTForTokenClassification
from src.modeling.docpolarbert.train_utils import eval_token_classification

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [132]:
# Test-set locations, keyed by dataset name. Every dataset lives in a folder of
# the same name under ../data; iteration order follows the tuple below.
DATA_DIRS = {
    name: f'../data/{name}'
    for name in ('funsd', 'sroie', 'cord', 'payslips', 'docile')
}

# Parent directory that holds the fine-tuned model checkpoints.
MODEL_DIRS = '../models/'

In [133]:
# Evaluate one fine-tuned DocPolarBERT checkpoint per dataset on that dataset's
# test split. Everything that does not depend on the dataset is set up once,
# before the loop, instead of being rebuilt on every iteration.
SEED = 1
BATCH_SIZE = 1

# Value handed to the dataset as `pad_token_label_id`: CrossEntropyLoss's
# default `ignore_index`.
pad_token_label_id = CrossEntropyLoss().ignore_index

# The same tokenizer is used for every dataset, so load it a single time.
tokenizer = LayoutLMv3Tokenizer.from_pretrained('microsoft/layoutlmv3-base')

# Preference order is unchanged: MPS first, then CUDA, then CPU.
DEVICE = "mps" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu"

# Metrics of every dataset, kept so later cells can tabulate or compare them
# (the loop variables below only hold the values of the last dataset).
eval_results = {}

for DATASET, DATA_DIR in DATA_DIRS.items():
    print(f"Evaluating {DATASET}")

    # Re-seed before each dataset so every evaluation starts from the same RNG
    # state, independent of which datasets were evaluated before it.
    random.seed(SEED)
    np.random.seed(SEED)
    torch.manual_seed(SEED)

    test_dataset = TokenClassificationDataset(
        data_dir=DATA_DIR,
        tokenizer=tokenizer,
        pad_token_label_id=pad_token_label_id,
        mode='test')

    # Sequential sampling keeps the evaluation order deterministic.
    test_sampler = torch.utils.data.SequentialSampler(test_dataset)

    test_dataloader = DataLoader(
        dataset=test_dataset,
        batch_size=BATCH_SIZE,
        sampler=test_sampler,
        num_workers=0,
        # Pinned host memory only helps CUDA transfers; on MPS/CPU PyTorch
        # merely warns and does not pin, so only request it for CUDA.
        pin_memory=(DEVICE == "cuda"),
    )

    # Checkpoints are expected at <MODEL_DIRS>/docpolarbert-<dataset name>.
    model = DocPolarBERTForTokenClassification.from_pretrained(
        os.path.join(MODEL_DIRS, f"docpolarbert-{DATASET}"),
        num_labels=len(test_dataset.idx2label))
    model.to(DEVICE)

    model.eval()
    with torch.no_grad():
        avg_eval_loss, precision, recall, f1 = eval_token_classification(
            model=model,
            eval_dataloader=test_dataloader,
            idx2label=test_dataset.idx2label,
            print_results=True)

    eval_results[DATASET] = {
        'loss': avg_eval_loss,
        'precision': precision,
        'recall': recall,
        'f1': f1,
    }
    print(f"{DATASET} - Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}.")



Evaluating funsd
funsd - Precision: 0.7735, Recall: 0.7892, F1: 0.7813.
Evaluating sroie
sroie - Precision: 0.9641, Recall: 0.9727, F1: 0.9684.
Evaluating cord
cord - Precision: 0.9533, Recall: 0.9621, F1: 0.9577.
Evaluating payslips
payslips - Precision: 0.7675, Recall: 0.8146, F1: 0.7904.
Evaluating docile
docile - Precision: 0.6072, Recall: 0.6255, F1: 0.6162.
