<a href="https://colab.research.google.com/github/jinhopark8345/FormUnderstanding/blob/main/notebooks/Fine_tuning_bros_spade_on_FUNSD_entity_linking_dataset.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

### Fine-tuning BROS with Spade head on FUNSD entity linking dataset (Entity Linking Task)


This notebook is a rewrite of the fine-tuning code from https://github.com/clovaai/bros.

In this notebook, we fine-tune BROS with a Spade head on the FUNSD dataset for the entity linking (EL) task.

### Set-up Environment
- Install necessary packages
- Import libraries

In [None]:
!pip install git+https://github.com/huggingface/transformers  # install from source: BROS models were not yet in the latest transformers release (4.33.2) when this notebook was written
!pip install pytorch-lightning omegaconf overrides datasets

Collecting git+https://github.com/huggingface/transformers
  Cloning https://github.com/huggingface/transformers to /tmp/pip-req-build-um89dhef
  Running command git clone --filter=blob:none --quiet https://github.com/huggingface/transformers /tmp/pip-req-build-um89dhef
  Resolved https://github.com/huggingface/transformers to commit 382ba670ed2376a9454c3c841fae4819118ec4f5
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done


In [None]:
import datetime
import itertools
import json
import math
import os
import random
import re
import time
from copy import deepcopy
from pathlib import Path
from pprint import pprint

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn as nn
import yaml
from omegaconf import OmegaConf
from omegaconf.dictconfig import DictConfig
from overrides import overrides
from pytorch_lightning.callbacks import (
    EarlyStopping,
    LearningRateMonitor,
    ModelCheckpoint,
    ModelSummary,
    TQDMProgressBar,
)
from pytorch_lightning.loggers.tensorboard import TensorBoardLogger
from pytorch_lightning.plugins import CheckpointIO
from pytorch_lightning.utilities import rank_zero_only
from torch.optim import SGD, Adam, AdamW
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data.dataloader import DataLoader
from torch.utils.data.dataset import Dataset
from tqdm import tqdm
from transformers import (
    AutoTokenizer,
    BrosConfig,
    BrosSpadeELForTokenClassification,
    BrosProcessor,
)

from collections import defaultdict
from datasets import load_dataset, load_from_disk



### Define FUNSD Dataset for EL task


In [None]:
class FUNSDSpadeELDataset(Dataset):
    """FUNSD entity linking (EL) dataset for the Spade head

    FUNSD : Form Understanding in Noisy Scanned Documents

    """

    def __init__(
        self,
        dataset,
        tokenizer,
        max_seq_length=512,
        split="train",
    ):
        self.dataset = dataset
        self.tokenizer = tokenizer
        self.max_seq_length = max_seq_length
        self.split = split

        self.pad_token_id = self.tokenizer.pad_token_id
        self.cls_token_id = self.tokenizer.cls_token_id
        self.sep_token_id = self.tokenizer.sep_token_id
        self.unk_token_id = self.tokenizer.unk_token_id

        self.examples = load_dataset(self.dataset)[split]

        self.class_names = ["other", "header", "question", "answer"]
        self.out_class_name = "other"
        self.class_idx_dic = {
            cls_name: idx for idx, cls_name in enumerate(self.class_names)
        }
        self.pad_token = self.tokenizer.pad_token
        self.ignore_label_id = -100

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, idx):
        sample = self.examples[idx]

        entity_labels = sample["labels"]
        entities = sample["words"]
        linkings = sample["linkings"]
        assert len(entity_labels) == len(entities)

        width, height = sample["img"].size
        cls_bbs = [0] * 4  # bbox for first token
        sep_bbs = [width, height] * 2  # bbox for last token

        # make placeholders
        padded_input_ids = np.ones(self.max_seq_length, dtype=int) * self.pad_token_id
        padded_bboxes = np.zeros((self.max_seq_length, 4), dtype=np.float32)
        attention_mask = np.zeros(self.max_seq_length, dtype=int)
        are_box_first_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)
        are_box_end_tokens = np.zeros(self.max_seq_length, dtype=np.bool_)
        itc_labels = np.zeros(self.max_seq_length, dtype=int)
        stc_labels = np.ones(self.max_seq_length, dtype=np.int64) * self.max_seq_length
        el_labels = np.ones(self.max_seq_length, dtype=int) * self.max_seq_length


        # filter empty entities
        entities = [[e for e in l if e['text'].strip() != ""] for l in entities]

        # convert linkings from "entity_idx to entity_idx" to "text_box to text_box"
        from_text_box2to_text_box_set = defaultdict(set)
        for linking in linkings:
            if not linking:
                continue

            for link in linking:
                from_entity_idx, to_entity_idx = link

                # discard the link if either entity is empty
                if len(entities[from_entity_idx]) == 0 or len(entities[to_entity_idx]) == 0:
                    continue

                if entity_labels[from_entity_idx] == "other" or entity_labels[to_entity_idx] == "other":
                    continue

                from_text_box, to_text_box = entities[from_entity_idx][0], entities[to_entity_idx][0]
                from_text_box = tuple([from_text_box["text"], tuple(from_text_box["box"])])
                to_text_box = tuple([to_text_box["text"], tuple(to_text_box["box"])])
                from_text_box2to_text_box_set[from_text_box].add(to_text_box)

        """
        links that touch an entity labeled "other" are skipped above

        """

        """
        in the beginning, entities are like below

            [
                [{'box': [147, 148, 213, 168], 'text': 'Attorney'},
                    {'box': [216, 151, 275, 168], 'text': 'General'},
                    {'box': [148, 172, 187, 190], 'text': 'Betty'},
                    {'box': [191, 169, 206, 187], 'text': 'D.'},
                    {'box': [211, 170, 305, 191], 'text': 'Montgomery'}],
                [{'box': [275, 249, 377, 267], 'text': 'CONFIDENTIAL'},
                    {'box': [380, 250, 457, 267], 'text': 'FACSIMILE'},
                    {'box': [264, 267, 369, 281], 'text': 'TRANSMISSION'},
                    {'box': [369, 267, 422, 281], 'text': 'COVER'},
                    {'box': [420, 267, 467, 281], 'text': 'SHEET'}],
                [{'box': [352, 297, 383, 314], 'text': '(614)'},
                    {'box': [384, 296, 405, 313], 'text': '466-'},
                    {'box': [406, 297, 438, 312], 'text': '5087'}]
                ...
            ]

        "entities" and "entity_labels" are aligned by index

        1. filter out "text_box" items with empty text
        2. convert entities into input_ids and bboxes
            2-1. convert each entity into a "list of text_box"
                2-1-1. convert each "text_box" into a "list of tokens"

        as a result we will have:

            - input_ids & bboxes, aligned by index

            - text_box_idx2token_indices : List[List[int]]
                    -> text_box_idx to token_indices (of the corresponding text)

            - label2text_box_indices_list : Dict[str, List[List[int]]]
                    -> label (class_name) to the list of text_box_indices belonging to it

            - text_box2text_box_idx : Dict[tuple, int]
                    -> (text, box) tuple of a text_box to its text_box_idx,
                       used together with "from_text_box2to_text_box_set" (built from the linking ground truth)
                       to get linkings between text_box_indices
        """

        # 1. filter out "text_box" items with empty text
        entity_and_label_list = []
        for entity, label in zip(entities, entity_labels):
            cur_entity_and_label = []
            for e in entity:
                if e["text"].strip() != "":
                    cur_entity_and_label.append(e)
            if cur_entity_and_label:
                entity_and_label_list.append((cur_entity_and_label, label))


        # 2. convert entities into input_ids and bboxes
        text_box_idx = 0
        cum_token_idx = 0
        input_ids = []
        bboxes = []
        text_box_idx2token_indices = []
        label2text_box_indices_list = {cls_name: [] for cls_name in self.class_names}
        text_box2text_box_idx = {}
        for entity_idx, (entity, label) in enumerate(entity_and_label_list):
            text_box_indices = []

            # 2-1. convert entity into "list of text_box"
            for text_and_box in entity:
                text_box_indices.append(text_box_idx)

                text, box = text_and_box["text"], text_and_box["box"]
                text_box2text_box_idx[tuple([text, tuple(box)])] = text_box_idx
                this_text_box_token_indices = []

                if text.strip() == "":
                    continue

                # 2-1-1. convert "text_box" into "list of tokens"
                this_input_ids = self.tokenizer.encode(text, add_special_tokens=False)
                input_ids += this_input_ids
                this_bboxes = [box for _ in range(len(this_input_ids))]
                bboxes += this_bboxes

                for _ in this_input_ids:
                    cum_token_idx += 1
                    this_text_box_token_indices.append(cum_token_idx)

                text_box_idx2token_indices.append(this_text_box_token_indices)
                text_box_idx += 1

            label2text_box_indices_list[label].append(text_box_indices)


        # convert linkings from "text_box to list of text_box" to "text_box idx to text_box idx"
        from_text_box_idx2to_text_box_idx = []
        for from_text_box, to_text_box_set in from_text_box2to_text_box_set.items():
            for to_text_box in to_text_box_set:
                from_text_box_idx2to_text_box_idx.append(
                    (text_box2text_box_idx[from_text_box], text_box2text_box_idx[to_text_box])
                )

        tokens_length_list = [len(indices) for indices in text_box_idx2token_indices]
        # consider [CLS] token that will be added to input_ids, shift "end token indices" 1 to the right
        et_indices = np.array(list(itertools.accumulate(tokens_length_list))) + 1

        # since we subtract original length from shifted indices, "start token indices" are aligned as well
        st_indices = et_indices - np.array(tokens_length_list)

        # the last position is reserved for the [SEP] token, so drop indices that do not fit before it
        st_indices = st_indices[st_indices < self.max_seq_length - 1]
        et_indices = et_indices[et_indices < self.max_seq_length - 1]

        # keep st_indices and et_indices paired in case the cut above removed an end index but not its start
        min_len = min(len(st_indices), len(et_indices))
        st_indices = st_indices[:min_len]
        et_indices = et_indices[:min_len]
        assert len(st_indices) == len(et_indices)

        are_box_first_tokens[st_indices] = True
        are_box_end_tokens[et_indices] = True

        from_text_box_idx2to_text_box_idx = sorted(from_text_box_idx2to_text_box_idx, key=lambda e: (e[0], e[1]))
        for from_idx, to_idx in from_text_box_idx2to_text_box_idx:

            if from_idx >= len(text_box_idx2token_indices) or to_idx >= len(text_box_idx2token_indices):
                continue

            if (
                text_box_idx2token_indices[from_idx][0] >= self.max_seq_length
                or text_box_idx2token_indices[to_idx][0] >= self.max_seq_length
            ):
                continue

            from_token_idx = text_box_idx2token_indices[from_idx][0]
            to_token_idx = text_box_idx2token_indices[to_idx][0]
            el_labels[to_token_idx] = from_token_idx


        # For [CLS] and [SEP]
        input_ids = (
            [self.cls_token_id]
            + input_ids[: self.max_seq_length - 2]
            + [self.sep_token_id]
        )
        if len(bboxes) == 0:
            # no text boxes left after filtering (no OCR result)
            bboxes = [cls_bbs] + [sep_bbs]
        else:
            bboxes = [cls_bbs] + bboxes[: self.max_seq_length - 2] + [sep_bbs]
        bboxes = np.array(bboxes)

        # fill the padded input_ids, attention_mask, and bboxes
        len_ori_input_ids = len(input_ids)
        padded_input_ids[:len_ori_input_ids] = input_ids
        # padded_labels[:len_ori_input_ids] = np.array(labels)
        attention_mask[:len_ori_input_ids] = 1
        padded_bboxes[:len_ori_input_ids, :] = bboxes


        # Normalize bbox -> 0 ~ 1
        padded_bboxes[:, [0, 2]] = padded_bboxes[:, [0, 2]] / width
        padded_bboxes[:, [1, 3]] = padded_bboxes[:, [1, 3]] / height
        # padded_bboxes = padded_bboxes[:, [0, 1, 2, 1, 2, 3, 0, 3]]
        # padded_bboxes[:, [0, 2, 4, 6]] = padded_bboxes[:, [0, 2, 4, 6]] / width
        # padded_bboxes[:, [1, 3, 5, 7]] = padded_bboxes[:, [1, 3, 5, 7]] / height

        # convert to tensor
        padded_input_ids = torch.from_numpy(padded_input_ids)
        padded_bboxes = torch.from_numpy(padded_bboxes)
        attention_mask = torch.from_numpy(attention_mask)
        are_box_first_tokens = torch.from_numpy(are_box_first_tokens)
        are_box_end_tokens = torch.from_numpy(are_box_end_tokens)
        itc_labels = torch.from_numpy(itc_labels)
        stc_labels = torch.from_numpy(stc_labels)
        el_labels = torch.from_numpy(el_labels)

        return_dict = {
            "filename": sample["filename"],
            "input_ids": padded_input_ids,
            "bbox": padded_bboxes,
            "attention_mask": attention_mask,
            "are_box_first_tokens": are_box_first_tokens,
            "el_labels": el_labels,
            "itc_labels": itc_labels,
            "stc_labels": stc_labels,
        }

        return return_dict
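
The index bookkeeping in `__getitem__` is easier to follow on a toy input. The following is a minimal, pure-Python sketch of two steps from the class above: turning per-box token counts into start/end token indices shifted by one for `[CLS]`, and storing a link as `el_labels[first token of the target box] = first token of the source box`, with `max_seq_length` acting as the "no link" value. The helper names `box_token_spans` and `encode_el_labels` are hypothetical and not part of the notebook.

```python
from itertools import accumulate


def box_token_spans(tokens_per_box, max_seq_length):
    """Return (start, end) token indices for each text box.

    Index 0 is reserved for [CLS], so the first box starts at 1.
    `end` is exclusive, mirroring `et_indices` in the dataset class.
    """
    ends = [total + 1 for total in accumulate(tokens_per_box)]
    starts = [end - n for end, n in zip(ends, tokens_per_box)]
    # The last position is reserved for [SEP]; drop boxes that do not fit.
    return [(s, e) for s, e in zip(starts, ends) if e < max_seq_length - 1]


def encode_el_labels(links, spans, max_seq_length):
    """Store each (from_box, to_box) link at the first token of `to_box`."""
    labels = [max_seq_length] * max_seq_length  # max_seq_length means "no link"
    for from_box, to_box in sorted(links):
        if from_box >= len(spans) or to_box >= len(spans):
            continue  # one side of the link was truncated away
        labels[spans[to_box][0]] = spans[from_box][0]
    return labels


# Three text boxes tokenized into 2, 3 and 1 tokens; box 0 links to boxes 1 and 2.
spans = box_token_spans([2, 3, 1], max_seq_length=10)
el_labels = encode_el_labels([(0, 1), (0, 2)], spans, max_seq_length=10)
print(spans)      # [(1, 3), (3, 6), (6, 7)]
print(el_labels)  # [10, 10, 10, 1, 10, 10, 1, 10, 10, 10]
```

Because each target position stores a single source index, a box that is the target of several links keeps only the last one written.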



### Define PL Data Module

In [None]:
class BROSDataPLModule(pl.LightningDataModule):
    def __init__(self, cfg):
        super().__init__()
        self.cfg = cfg
        self.train_batch_size = self.cfg.train.batch_size
        self.val_batch_size = self.cfg.val.batch_size
        self.train_dataset = None
        self.val_dataset = None

    def train_dataloader(self):
        loader = DataLoader(
            dataset=self.train_dataset,
            batch_size=self.train_batch_size,
            num_workers=self.cfg.train.num_workers,
            pin_memory=True,
            shuffle=True,
        )

        return loader

    def val_dataloader(self):
        loader = DataLoader(
            dataset=self.val_dataset,
            batch_size=self.val_batch_size,
            num_workers=self.cfg.val.num_workers,
            shuffle=False,
            pin_memory=True,
            drop_last=False,
        )

        return loader

    @overrides
    def transfer_batch_to_device(self, batch, device, dataloader_idx):
        for k in batch.keys():
            if isinstance(batch[k], torch.Tensor):
                batch[k] = batch[k].to(device)
        return batch


### Define PL Model Module

In [None]:
def eval_el_spade_batch(
    pr_el_labels,
    gt_el_labels,
    are_box_first_tokens,
    dummy_idx,
):
    n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel = 0, 0, 0

    bsz = pr_el_labels.shape[0]
    for example_idx in range(bsz):
        n_gt_rel, n_pr_rel, n_correct_rel = eval_el_spade_example(
            pr_el_labels[example_idx],
            gt_el_labels[example_idx],
            are_box_first_tokens[example_idx],
            dummy_idx,
        )

        n_batch_gt_rel += n_gt_rel
        n_batch_pr_rel += n_pr_rel
        n_batch_correct_rel += n_correct_rel

    return n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel


def eval_el_spade_example(pr_el_label, gt_el_label, box_first_token_mask, dummy_idx):
    gt_relations = parse_relations(gt_el_label, box_first_token_mask, dummy_idx)
    pr_relations = parse_relations(pr_el_label, box_first_token_mask, dummy_idx)

    n_gt_rel = len(gt_relations)
    n_pr_rel = len(pr_relations)
    n_correct_rel = len(gt_relations & pr_relations)

    return n_gt_rel, n_pr_rel, n_correct_rel

def parse_relations(el_label, box_first_token_mask, dummy_idx):
    valid_el_labels = el_label * box_first_token_mask
    valid_el_labels = valid_el_labels.cpu().numpy()
    el_label_np = el_label.cpu().numpy()

    valid_token_indices = np.where(
        ((valid_el_labels != dummy_idx) * (valid_el_labels != 0))
    )
    link_map_tuples = []
    for token_idx in valid_token_indices[0]:
        link_map_tuples.append((el_label_np[token_idx], token_idx))

    return set(link_map_tuples)

def linear_scheduler(optimizer, warmup_steps, training_steps, last_epoch=-1):
    """linear_scheduler with warmup from huggingface"""

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return float(current_step) / float(max(1, warmup_steps))
        return max(
            0.0,
            float(training_steps - current_step)
            / float(max(1, training_steps - warmup_steps)),
        )

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def cosine_scheduler(
    optimizer, warmup_steps, training_steps, cycles=0.5, last_epoch=-1
):
    """Cosine LR scheduler with warmup from huggingface"""

    def lr_lambda(current_step):
        if current_step < warmup_steps:
            return current_step / max(1, warmup_steps)
        progress = current_step - warmup_steps
        progress /= max(1, training_steps - warmup_steps)
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * cycles * 2 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)


def multistep_scheduler(optimizer, warmup_steps, milestones, gamma=0.1, last_epoch=-1):
    def lr_lambda(current_step):
        if current_step < warmup_steps:
            # calculate a warmup ratio
            return current_step / max(1, warmup_steps)
        else:
            # calculate a multistep lr scaling ratio
            idx = np.searchsorted(milestones, current_step)
            return gamma**idx

    return LambdaLR(optimizer, lr_lambda, last_epoch)


class BROSModelPLModule(pl.LightningModule):
    def __init__(self, cfg, tokenizer):
        super().__init__()
        self.cfg = cfg
        self.model = None
        self.optimizer_types = {
            "sgd": SGD,
            "adam": Adam,
            "adamw": AdamW,
        }
        self.loss_func = nn.CrossEntropyLoss()
        self.class_names = None
        self.tokenizer = tokenizer
        self.dummy_idx = None

        self.validation_step_outputs = []

    def training_step(self, batch, batch_idx, *args):
        # unpack batch
        input_ids = batch["input_ids"]
        bbox = batch["bbox"]
        attention_mask = batch["attention_mask"]
        are_box_first_tokens = batch["are_box_first_tokens"]
        labels = batch["el_labels"]

        # forward pass (the model computes the loss because labels are passed)
        prediction = self.model(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            bbox_first_token_mask=are_box_first_tokens,
            labels=labels,
        )

        loss = prediction.loss
        self.log_dict({"train_loss": loss}, sync_dist=True, prog_bar=True)

        return loss

    def validation_step(self, batch, batch_idx, *args):
        # unpack batch
        input_ids = batch["input_ids"]
        bbox = batch["bbox"]
        attention_mask = batch["attention_mask"]
        are_box_first_tokens = batch["are_box_first_tokens"]
        labels = batch["el_labels"]

        # forward pass (the model computes the loss because labels are passed)
        prediction = self.model(
            input_ids=input_ids,
            bbox=bbox,
            attention_mask=attention_mask,
            bbox_first_token_mask=are_box_first_tokens,
            labels=labels,
        )

        self.log_dict({"val_loss": prediction.loss}, sync_dist=True, prog_bar=True)

        pr_el_labels = torch.argmax(prediction.logits, -1)
        n_batch_gt_rel, n_batch_pr_rel, n_batch_correct_rel = eval_el_spade_batch(
            pr_el_labels,
            labels,
            are_box_first_tokens,
            self.dummy_idx,
        )

        step_out = {
            "loss": prediction.loss,
            "n_batch_gt_rel": n_batch_gt_rel,
            "n_batch_pr_rel": n_batch_pr_rel,
            "n_batch_correct_rel": n_batch_correct_rel,
        }

        self.validation_step_outputs.append(step_out)

        return prediction.loss

    def on_validation_epoch_end(self):

        all_preds = self.validation_step_outputs
        n_total_gt_rel, n_total_pred_rel, n_total_correct_rel = 0, 0, 0

        for step_out in all_preds:
            n_total_gt_rel += step_out["n_batch_gt_rel"]
            n_total_pred_rel += step_out["n_batch_pr_rel"]
            n_total_correct_rel += step_out["n_batch_correct_rel"]

        precision = 0.0 if n_total_pred_rel == 0 else n_total_correct_rel / n_total_pred_rel
        recall = 0.0 if n_total_gt_rel == 0 else n_total_correct_rel / n_total_gt_rel
        f1 = (
            0.0
            if recall * precision == 0
            else 2.0 * recall * precision / (recall + precision)
        )

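        self.log_dict(
            {
                "precision": precision,
                "recall": recall,
                "f1": f1,
            },
            sync_dist=True,
        )

        # reset the per-step buffer so counts do not accumulate across validation epochs
        self.validation_step_outputs.clear()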

    def configure_optimizers(self):
        optimizer = self._get_optimizer()
        scheduler = self._get_lr_scheduler(optimizer)
        scheduler = {
            "scheduler": scheduler,
            "name": "learning_rate",
            "interval": "step",
        }
        return [optimizer], [scheduler]

    def _get_optimizer(self):
        opt_cfg = self.cfg.train.optimizer
        method = opt_cfg.method.lower()

        if method not in self.optimizer_types:
            raise ValueError(f"Unknown optimizer method={method}")

        kwargs = dict(opt_cfg.params)
        kwargs["params"] = self.model.parameters()
        optimizer = self.optimizer_types[method](**kwargs)

        return optimizer

    def _get_lr_scheduler(self, optimizer):
        cfg_train = self.cfg.train
        lr_schedule_method = cfg_train.optimizer.lr_schedule.method
        lr_schedule_params = cfg_train.optimizer.lr_schedule.params

        if lr_schedule_method is None:
            scheduler = LambdaLR(optimizer, lr_lambda=lambda _: 1)
        elif lr_schedule_method == "step":
            scheduler = multistep_scheduler(optimizer, **lr_schedule_params)
        elif lr_schedule_method == "cosine":
            total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch
            total_batch_size = cfg_train.batch_size * self.trainer.world_size
            max_iter = total_samples / total_batch_size
            scheduler = cosine_scheduler(
                optimizer, training_steps=max_iter, **lr_schedule_params
            )
        elif lr_schedule_method == "linear":
            total_samples = cfg_train.max_epochs * cfg_train.num_samples_per_epoch
            total_batch_size = cfg_train.batch_size * self.trainer.world_size
            max_iter = total_samples / total_batch_size
            scheduler = linear_scheduler(
                optimizer, training_steps=max_iter, **lr_schedule_params
            )
        else:
            raise ValueError(f"Unknown lr_schedule_method={lr_schedule_method}")

        return scheduler

    @rank_zero_only
    def on_save_checkpoint(self, checkpoint):
        save_path = Path(self.cfg.workspace) / self.cfg.exp_name / self.cfg.exp_version
        model_save_path = (
            Path(self.cfg.workspace)
            / self.cfg.exp_name
            / self.cfg.exp_version
            / "huggingface_model"
        )
        tokenizer_save_path = (
            Path(self.cfg.workspace)
            / self.cfg.exp_name
            / self.cfg.exp_version
            / "huggingface_tokenizer"
        )
        self.model.save_pretrained(model_save_path)
        self.tokenizer.save_pretrained(tokenizer_save_path)
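
To make the evaluation logic in `parse_relations` and `on_validation_epoch_end` concrete, here is a small pure-Python sketch that applies the same rules to plain lists instead of tensors. The function names `parse_relations_py` and `precision_recall_f1` and the toy labels are illustrative assumptions, not part of the notebook.

```python
def parse_relations_py(el_label, first_token_mask, dummy_idx):
    """Collect (source_token, target_token) pairs from an EL label sequence.

    A position counts only if it is the first token of a text box and its
    label is neither the dummy index (no link) nor 0 (the [CLS] position).
    """
    return {
        (source, target)
        for target, (source, is_first) in enumerate(zip(el_label, first_token_mask))
        if is_first and source != dummy_idx and source != 0
    }


def precision_recall_f1(n_gt, n_pred, n_correct):
    """Same zero-safe formulas as `on_validation_epoch_end`."""
    precision = 0.0 if n_pred == 0 else n_correct / n_pred
    recall = 0.0 if n_gt == 0 else n_correct / n_gt
    f1 = 0.0 if precision * recall == 0 else 2.0 * precision * recall / (precision + recall)
    return precision, recall, f1


dummy_idx = 8  # equals max_seq_length in this toy example
first_token_mask = [False, True, False, True, False, False, True, False]
gt_labels = [8, 8, 8, 1, 8, 8, 1, 8]  # tokens 3 and 6 both link back to token 1
pr_labels = [8, 8, 5, 1, 8, 8, 3, 8]  # position 2 is not a first token, so it is ignored

gt = parse_relations_py(gt_labels, first_token_mask, dummy_idx)
pr = parse_relations_py(pr_labels, first_token_mask, dummy_idx)
print(sorted(gt), sorted(pr))  # [(1, 3), (1, 6)] [(1, 3), (3, 6)]
print(precision_recall_f1(len(gt), len(pr), len(gt & pr)))  # (0.5, 0.5, 0.5)
```

A predicted link counts as correct only when both its source and its target token index match the ground truth exactly.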


### Define train function

In [None]:
def train(cfg):
    cfg.save_weight_dir = os.path.join(cfg.workspace, "checkpoints")
    cfg.tensorboard_dir = os.path.join(cfg.workspace, "tensorboard_logs")
    cfg.exp_version = datetime.datetime.now().strftime("%Y%m%d_%H%M%S")

    # pprint cfg
    print(OmegaConf.to_yaml(cfg))

    # set env
    os.environ["TOKENIZERS_PARALLELISM"] = "false"  # prevent deadlock with tokenizer
    pl.seed_everything(cfg.seed)

    # Load tokenizer (used in the dataset to convert texts to input_ids)
    tokenizer = BrosProcessor.from_pretrained(cfg.tokenizer_path).tokenizer

    # prepare FUNSD dataset
    train_dataset = FUNSDSpadeELDataset(
        dataset=cfg.dataset,
        tokenizer=tokenizer,
        max_seq_length=cfg.model.max_seq_length,
        split="train",
    )

    val_dataset = FUNSDSpadeELDataset(
        dataset=cfg.dataset,
        tokenizer=tokenizer,
        max_seq_length=cfg.model.max_seq_length,
        split="test",
    )

    # make data module & update data_module train and val dataset
    data_module = BROSDataPLModule(cfg)
    data_module.train_dataset = train_dataset
    data_module.val_dataset = val_dataset

    # Load BROS config & pretrained model
    ## update config
    bros_config = BrosConfig.from_pretrained(cfg.model.pretrained_model_name_or_path)

    bros_config.id2label = {
        i: label for i, label in enumerate(train_dataset.class_names)
    }
    bros_config.label2id = {
        label: i for i, label in enumerate(train_dataset.class_names)
    }
    bros_config.num_labels = len(train_dataset.class_names)

    ## load pretrained model
    bros_model = BrosSpadeELForTokenClassification.from_pretrained(
        cfg.model.pretrained_model_name_or_path, config=bros_config
    )

    # model module setting
    model_module = BROSModelPLModule(cfg, tokenizer=tokenizer)
    model_module.model = bros_model
    model_module.class_names = train_dataset.class_names
    model_module.dummy_idx = cfg.model.max_seq_length

    # model_module.bioes_class_names = train_dataset.bioes_class_names

    # define trainer logger, callbacks
    loggers = TensorBoardLogger(
        save_dir=cfg.workspace,
        name=cfg.exp_name,
        version=cfg.exp_version,
        default_hp_metric=False,
    )
    lr_callback = LearningRateMonitor(logging_interval="step")

    checkpoint_callback = ModelCheckpoint(
        dirpath=Path(cfg.workspace) / cfg.exp_name / cfg.exp_version / "checkpoints",
        filename="bros-funsd-{epoch:02d}-{val_loss:.2f}",
        monitor="val_loss",
        mode="min",
        save_top_k=1,  # with more than one checkpoint, the Lightning checkpoint and the
        # Hugging Face model are not guaranteed to match, because `on_save_checkpoint`
        # in `BROSModelPLModule` writes the model with `save_pretrained`
        # to the same directory every time a checkpoint is saved
    )

    model_summary_callback = ModelSummary(max_depth=5)
    early_stop_callback = EarlyStopping(
        monitor="f1", min_delta=0.00, patience=5, verbose=True, mode="max"
    )

    # define Trainer and start training
    trainer = pl.Trainer(
        accelerator=cfg.train.accelerator,
        num_nodes=cfg.get("num_nodes", 1),
        precision=16 if cfg.train.use_fp16 else 32,
        logger=loggers,
        callbacks=[
            lr_callback,
            checkpoint_callback,
            model_summary_callback,
            early_stop_callback,
        ],
        max_epochs=cfg.train.max_epochs,
        num_sanity_val_steps=3,
        gradient_clip_val=cfg.train.clip_gradient_value,
        gradient_clip_algorithm=cfg.train.clip_gradient_algorithm,
        log_every_n_steps=1,
    )

    trainer.fit(model_module, data_module, ckpt_path=cfg.train.get("ckpt_path", None))
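
The `linear` schedule selected through `cfg.train.optimizer.lr_schedule` scales the base learning rate by a factor that depends only on the step count. As a rough sketch, the multiplier from `linear_scheduler` can be reproduced without PyTorch. The function name `linear_lr_factor` is hypothetical, and the numbers plugged in (149 samples per epoch, batch size 8, 200 epochs, `warmup_steps` of 0) come from the configuration used in this notebook; a single device and no gradient accumulation are assumed.

```python
import math


def linear_lr_factor(step, warmup_steps, training_steps):
    """Multiplier applied to the base learning rate at a given optimizer step."""
    if step < warmup_steps:
        return step / max(1, warmup_steps)
    return max(0.0, (training_steps - step) / max(1, training_steps - warmup_steps))


max_epochs, samples_per_epoch, batch_size, world_size = 200, 149, 8, 1
# Same arithmetic as `_get_lr_scheduler`: a float, not a whole number of steps.
training_steps = max_epochs * samples_per_epoch / (batch_size * world_size)
# A DataLoader without drop_last yields ceil(149 / 8) batches per epoch.
steps_per_epoch = math.ceil(samples_per_epoch / batch_size)

print(training_steps)                             # 3725.0
print(steps_per_epoch * max_epochs)               # 3800
print(linear_lr_factor(0, 0, training_steps))     # 1.0
print(linear_lr_factor(3725, 0, training_steps))  # 0.0
```

Under these assumptions the factor reaches zero at step 3725, slightly before the 3800 optimizer steps that a full 200-epoch run would take, so the final epochs would train with a learning rate of zero unless early stopping ends the run first.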



### Set hyperparameters and run train function

In [None]:
# load training config
finetune_funsd_ee_bioes_config = {
    "workspace": "./finetune_funsd_el_spade",
    "exp_name": "finetune_funsd_el_spade__bros-base-uncased",
    "tokenizer_path": "naver-clova-ocr/bros-base-uncased",
    "dataset": "jinho8345/funsd",
    "task": "el",
    "seed": 1,
    "cudnn_deterministic": False,
    "cudnn_benchmark": True,
    "model": {
        "pretrained_model_name_or_path": "jinho8345/bros-base-uncased",
        "max_seq_length": 512,
    },
    "train": {
        "ckpt_path": None,  # set to a Lightning checkpoint path to resume training
        "batch_size": 8,
        "num_samples_per_epoch": 149,
        "max_epochs": 200,
        "use_fp16": True,
        "accelerator": "gpu",
        "clip_gradient_algorithm": "norm",
        "clip_gradient_value": 1.0,
        "num_workers": 4,
        "optimizer": {
            "method": "adamw",
            "params": {"lr": 5e-05},
            "lr_schedule": {"method": "linear", "params": {"warmup_steps": 0}},
        },
        "val_interval": 1,
    },
    "val": {"batch_size": 8, "num_workers": 8, "limit_val_batches": 1.0},
}

# convert dictionary to omegaconf and update config
cfg = OmegaConf.create(finetune_funsd_ee_bioes_config)
train(cfg)


INFO:lightning_fabric.utilities.seed:Global seed set to 1


workspace: ./finetune_funsd_el_spade
exp_name: finetune_funsd_el_spade__bros-base-uncased
tokenizer_path: naver-clova-ocr/bros-base-uncased
dataset: jinho8345/funsd
task: el
seed: 1
cudnn_deterministic: false
cudnn_benchmark: true
model:
  pretrained_model_name_or_path: jinho8345/bros-base-uncased
  max_seq_length: 512
train:
  ckpt_path: null
  batch_size: 8
  num_samples_per_epoch: 149
  max_epochs: 200
  use_fp16: true
  accelerator: gpu
  clip_gradient_algorithm: norm
  clip_gradient_value: 1.0
  num_workers: 4
  optimizer:
    method: adamw
    params:
      lr: 5.0e-05
    lr_schedule:
      method: linear
      params:
        warmup_steps: 0
  val_interval: 1
val:
  batch_size: 8
  num_workers: 8
  limit_val_batches: 1.0
save_weight_dir: ./finetune_funsd_el_spade/checkpoints
tensorboard_dir: ./finetune_funsd_el_spade/tensorboard_logs
exp_version: '20230920_060153'



Some weights of BrosSpadeELForTokenClassification were not initialized from the model checkpoint at jinho8345/bros-base-uncased and are newly initialized: ['entity_linker.key.bias', 'entity_linker.key.weight', 'entity_linker.query.bias', 'entity_linker.query.weight', 'entity_linker.dummy_node']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
INFO:pytorch_lightning.utilities.rank_zero:Using 16bit Automatic Mixed Precision (AMP)
INFO:pytorch_lightning.utilities.rank_zero:Trainer already configured with model summary callbacks: [<class 'pytorch_lightning.callbacks.model_summary.ModelSummary'>]. Skipping setting a default `ModelSummary` callback.
INFO:pytorch_lightning.utilities.rank_zero:GPU available: True (cuda), used: True
INFO:pytorch_lightning.utilities.rank_zero:TPU available: False, using: 0 TPU cores
INFO:pytorch_lightning.utilities.rank_zero:IPU available: False, using: 0 IPUs
INFO:pytorch_lightning.utilities.rank_zer

Sanity Checking: 0it [00:00, ?it/s]

Training: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved. New best score: 0.001


Validation: 0it [00:00, ?it/s]

Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.040 >= min_delta = 0.0. New best score: 0.041


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.056 >= min_delta = 0.0. New best score: 0.097


Validation: 0it [00:00, ?it/s]

INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.059 >= min_delta = 0.0. New best score: 0.156


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.058 >= min_delta = 0.0. New best score: 0.214


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.049 >= min_delta = 0.0. New best score: 0.263


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.041 >= min_delta = 0.0. New best score: 0.304


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.037 >= min_delta = 0.0. New best score: 0.342


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.030 >= min_delta = 0.0. New best score: 0.371


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.027 >= min_delta = 0.0. New best score: 0.398


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.019 >= min_delta = 0.0. New best score: 0.417


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.019 >= min_delta = 0.0. New best score: 0.437


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.018 >= min_delta = 0.0. New best score: 0.455


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.015 >= min_delta = 0.0. New best score: 0.470


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.014 >= min_delta = 0.0. New best score: 0.485


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.013 >= min_delta = 0.0. New best score: 0.498


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.011 >= min_delta = 0.0. New best score: 0.509


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.011 >= min_delta = 0.0. New best score: 0.520


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.010 >= min_delta = 0.0. New best score: 0.530


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.009 >= min_delta = 0.0. New best score: 0.539


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.007 >= min_delta = 0.0. New best score: 0.546


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.007 >= min_delta = 0.0. New best score: 0.554


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.006 >= min_delta = 0.0. New best score: 0.560


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.006 >= min_delta = 0.0. New best score: 0.566


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.006 >= min_delta = 0.0. New best score: 0.572


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.006 >= min_delta = 0.0. New best score: 0.578


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.005 >= min_delta = 0.0. New best score: 0.582


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.005 >= min_delta = 0.0. New best score: 0.587


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.005 >= min_delta = 0.0. New best score: 0.592


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.004 >= min_delta = 0.0. New best score: 0.596


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.004 >= min_delta = 0.0. New best score: 0.600


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.004 >= min_delta = 0.0. New best score: 0.603


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.003 >= min_delta = 0.0. New best score: 0.607


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.003 >= min_delta = 0.0. New best score: 0.610


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.003 >= min_delta = 0.0. New best score: 0.613


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.003 >= min_delta = 0.0. New best score: 0.616


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.003 >= min_delta = 0.0. New best score: 0.619


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.003 >= min_delta = 0.0. New best score: 0.622


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.003 >= min_delta = 0.0. New best score: 0.625


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.003 >= min_delta = 0.0. New best score: 0.627


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.002 >= min_delta = 0.0. New best score: 0.629


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.002 >= min_delta = 0.0. New best score: 0.631


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.002 >= min_delta = 0.0. New best score: 0.634


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.002 >= min_delta = 0.0. New best score: 0.636


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.002 >= min_delta = 0.0. New best score: 0.638


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.002 >= min_delta = 0.0. New best score: 0.640


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.002 >= min_delta = 0.0. New best score: 0.642


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.002 >= min_delta = 0.0. New best score: 0.644


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.645


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.002 >= min_delta = 0.0. New best score: 0.647


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.002 >= min_delta = 0.0. New best score: 0.649


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.002 >= min_delta = 0.0. New best score: 0.650


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.002 >= min_delta = 0.0. New best score: 0.652


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.653


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.654


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.655


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.657


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.658


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.659


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.660


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.661


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.662


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.663


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.664


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.665


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.665


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.666


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.667


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.668


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.669


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.669


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.670


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.670


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.671


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.671


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.672


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.673


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.673


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.673


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.674


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.674


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.675


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.675


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.676


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.676


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.677


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.677


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.678


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.678


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.679


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.679


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.680


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.680


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.681


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.681


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.682


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.682


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.001 >= min_delta = 0.0. New best score: 0.683


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.683


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.684


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.684


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.685


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.685


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.685


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.686


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.686


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.687


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.687


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.687


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.687


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.688


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.688


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.689


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.689


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.689


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.690


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.690


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.690


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.691


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.691


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.691


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.692


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.692


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.692


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.693


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.693


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.693


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.694


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.694


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.694


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.695


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.695


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.695


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.695


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.696


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.696


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.696


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.696


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.697


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.697


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.697


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.697


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.697


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.697


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.698


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.698


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.698


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.698


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.698


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.698


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.698


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.698


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.699


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.699


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.699


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.699


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.699


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.699


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.699


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.699


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.700


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.700


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.700


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.700


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.700


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.700


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.700


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.700


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.700


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.700


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.700


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.700


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.701


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.701


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.701


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.701


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.701


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.701


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.701


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.701


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.701


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.702


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.702


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.702


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.702


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.702


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.702


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.702


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.702


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.702


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.703


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.703


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.703


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.703


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.703


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.703


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.703


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.703


INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.000 >= min_delta = 0.0. New best score: 0.703
INFO:pytorch_lightning.utilities.rank_zero:`Trainer.fit` stopped: `max_epochs=200` reached.
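
The early-stopping messages above are the only record of validation F1 in this output, so it can be handy to pull the scores back out of the text. The snippet below is a minimal sketch that assumes the log has been saved as a string; the helper name `best_scores` and the variable `sample_log` are made up for illustration, and the three sample lines are copied from the log above.

```python
import re

# Matches the EarlyStopping messages printed by pytorch_lightning in the log above,
# e.g. "Metric f1 improved by 0.040 >= min_delta = 0.0. New best score: 0.041"
BEST_SCORE = re.compile(r"Metric (\w+) improved.*?New best score: ([0-9.]+)")


def best_scores(log_text):
    """Return the best scores reported by EarlyStopping, in log order."""
    return [float(m.group(2)) for m in BEST_SCORE.finditer(log_text)]


# First three early-stopping lines, copied from the training log above.
sample_log = """\
INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved. New best score: 0.001
INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.040 >= min_delta = 0.0. New best score: 0.041
INFO:pytorch_lightning.callbacks.early_stopping:Metric f1 improved by 0.056 >= min_delta = 0.0. New best score: 0.097
"""

scores = best_scores(sample_log)
# Gain between consecutive best scores, rounded to the log's 3 decimals.
gains = [round(b - a, 3) for a, b in zip(scores, scores[1:])]
print(scores)
print(gains)
```

Run over the full log, the gains shrink to 0.000 (at the log's three-decimal precision) for most of the later epochs, yet each one still counts as an improvement because `min_delta = 0.0`, and training only ends when `max_epochs=200` is reached. A positive `min_delta` on the `EarlyStopping` callback would likely end the run earlier, though that is not something this notebook tests.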
