In [13]:
%load_ext autoreload
%autoreload 2
import torch
import random
import time
import datetime
import numpy as np

from torch.optim import AdamW
from torch.nn import CrossEntropyLoss
from torch.utils.data import DataLoader, RandomSampler
from data.token_classification import TokenClassificationDataset
from transformers import LayoutLMv3Tokenizer

from src.modeling.docpolarbert.modeling_docpolarbert import DocPolarBERTForTokenClassification
from src.modeling.docpolarbert.train_utils import train_step_token_classification, eval_token_classification

The autoreload extension is already loaded. To reload it, use:
  %reload_ext autoreload


In [17]:
# Load FUNSD dataset
DATA_DIR = 'data/funsd/bio'
pad_token_label_id = CrossEntropyLoss().ignore_index
SEED = 3
NUM_EPOCHS = 30
LEARNING_RATE = 5e-5
BATCH_SIZE = 16
# Set all seeds
random.seed(SEED)
np.random.seed(SEED)
torch.manual_seed(SEED)

g = torch.Generator()
g.manual_seed(SEED)

tokenizer = LayoutLMv3Tokenizer.from_pretrained('microsoft/layoutlmv3-base')
train_dataset = TokenClassificationDataset(
    data_dir=DATA_DIR,
    tokenizer=tokenizer,
    pad_token_label_id=pad_token_label_id,
    mode='train')

val_dataset = TokenClassificationDataset(
    data_dir=DATA_DIR,
    tokenizer=tokenizer,
    pad_token_label_id=pad_token_label_id,
    mode='test')

train_sampler = RandomSampler(train_dataset)

train_sampler = torch.utils.data.RandomSampler(train_dataset, generator=g)

train_dataloader = DataLoader(dataset=train_dataset,
                              batch_size=BATCH_SIZE,
                              sampler=train_sampler)

val_sampler = torch.utils.data.RandomSampler(val_dataset, generator=g)

val_dataloader = DataLoader(dataset=val_dataset,
                            batch_size=BATCH_SIZE,
                            sampler=val_sampler)



In [None]:
DOCPOLARBERT_DIR = 'docpolarbert-base'

model = DocPolarBERTForTokenClassification.from_pretrained(DOCPOLARBERT_DIR, num_labels=len(train_dataset.idx2label))
model.to("cpu" if torch.backends.mps.is_available() else "cuda" if torch.cuda.is_available() else "cpu")
optimizer = AdamW(model.parameters(), lr=LEARNING_RATE)


with torch.no_grad():
    avg_eval_loss, precision, recall, f1 = eval_token_classification(
    model=model,
    eval_dataloader=val_dataloader,
    idx2label=val_dataset.idx2label,
    print_results=True)
    print(f"Epoch [{0}/{NUM_EPOCHS}], Average validation loss: {avg_eval_loss:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}.")


In [19]:
for epoch in range(NUM_EPOCHS):
    start_epoch = time.time()
    print(f'Starting epoch {epoch + 1}/{NUM_EPOCHS}')
    avg_loss  = train_step_token_classification(
            model=model,
            train_dataloader=train_dataloader,
            idx2label=train_dataset.idx2label,
            optimizer=optimizer,
            )
    # ---------------------------------------- Validation ---------------------------------------- #
    model.eval()
    with torch.no_grad():
        avg_eval_loss, precision, recall, f1 = eval_token_classification(
            model=model,
            eval_dataloader=val_dataloader,
            idx2label=val_dataset.idx2label,
            print_results=True)
        print(f"Epoch [{epoch + 1}/{NUM_EPOCHS}], Average training loss: {avg_loss:.4f}, Average validation loss: {avg_eval_loss:.4f}, Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}.")
        print(f"Time taken for epoch {epoch + 1}: {datetime.timedelta(seconds=int(time.time() - start_epoch))}")
    model.train()

Starting epoch 1/30
Epoch [1/30], Average training loss: 1.4694, Average validation loss: 0.9239, Precision: 0.4983, Recall: 0.5302, F1: 0.5138.
Time taken for epoch 1: 0:06:01
Starting epoch 2/30
Epoch [2/30], Average training loss: 0.8953, Average validation loss: 0.8240, Precision: 0.5630, Recall: 0.7191, F1: 0.6316.
Time taken for epoch 2: 0:05:59
Starting epoch 3/30
Epoch [3/30], Average training loss: 0.7078, Average validation loss: 0.7931, Precision: 0.6251, Recall: 0.7628, F1: 0.6871.
Time taken for epoch 3: 0:06:02
Starting epoch 4/30
Epoch [4/30], Average training loss: 0.5676, Average validation loss: 0.7118, Precision: 0.6738, Recall: 0.7395, F1: 0.7051.
Time taken for epoch 4: 0:05:37
Starting epoch 5/30
Epoch [5/30], Average training loss: 0.4875, Average validation loss: 0.7030, Precision: 0.7058, Recall: 0.7847, F1: 0.7431.
Time taken for epoch 5: 0:05:34
Starting epoch 6/30
Epoch [6/30], Average training loss: 0.3804, Average validation loss: 0.7678, Precision: 0.7339

In [20]:
# Run evaluation on the validation set
model.eval()
with torch.no_grad():
    avg_eval_loss, precision, recall, f1 = eval_token_classification(
        model=model,
        eval_dataloader=val_dataloader,
        idx2label=val_dataset.idx2label,
        print_results=True)
print(f"Precision: {precision:.4f}, Recall: {recall:.4f}, F1: {f1:.4f}.")

Precision: 0.7757, Recall: 0.7902, F1: 0.7829.


In [27]:
model.config.id2label = val_dataset.idx2label
model.config.label2id = val_dataset.label2idx
# Save fine-tuned model
model.save_pretrained('models/docpolarbert-funsd')