<a href="https://colab.research.google.com/github/deviincture/layoutlmv3/blob/main/OCR_XML_algorithm.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

In [None]:
!pip install transformers@git+https://github.com/monuminu/transformers.git &> /dev/null
!pip install seqeval &> /dev/null

In [None]:
from google.colab import drive
drive.mount('/content/drive')

Mounted at /content/drive


In [None]:
# Quote the requirement: unquoted, the shell parses `>=5.1,` as an output
# redirection (to a file named "=5.1,") and the version constraint is lost.
!pip install "pyyaml>=5.1" bs4

In [None]:
!pip install beautifulsoup4 &> /dev/null

In [None]:
!pip install lxml  &> /dev/null

In [None]:
!pip install PyPDF2 &> /dev/null

In [None]:
import numpy as np
import os
import numpy as np
import pandas as pd
import torch
from transformers import LayoutLMv2Tokenizer, LayoutLMv2ForTokenClassification, LayoutLMv2Config
from torchvision.transforms import ToTensor
from torch.utils.data import Dataset, DataLoader
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
import warnings
warnings.filterwarnings('ignore')
from PIL import Image



def normalize_box(box, width, height):
    """Scale a pixel-space box ``[x0, y0, x1, y1]`` onto a 0-1000 grid.

    ``width`` and ``height`` are coerced with ``int()`` first.  The x
    coordinates are divided by the width, the y coordinates by the height,
    and every scaled value is truncated to an integer.
    """
    page_w, page_h = int(width), int(height)
    # Even positions hold x coordinates, odd positions hold y coordinates.
    extents = (page_w, page_h, page_w, page_h)
    return [int(1000 * (box[i] / extents[i])) for i in range(4)]

def resize_and_align_bounding_box(bbox, original_image, target_size):
    """Map a box from the original image onto a ``target_size`` square image.

    Each axis is scaled independently by ``target_size / original dimension``
    (dimensions come from ``original_image.size``), the scaled corners are
    rounded to the nearest integer, and the box is then grown by half a pixel
    on every side.
    """
    orig_w, orig_h = original_image.size
    scale_x = target_size / orig_w
    scale_y = target_size / orig_h
    left, top, right, bottom = tuple(bbox)

    def _scaled(value, factor):
        return int(np.round(value * factor))

    return [
        _scaled(left, scale_x) - 0.5,
        _scaled(top, scale_y) - 0.5,
        _scaled(right, scale_x) + 0.5,
        _scaled(bottom, scale_y) + 0.5,
    ]

class InvoiceDataSet(Dataset):
    """LayoutLM dataset with visual features.

    Each row of ``df`` must provide ``imageFilename``, ``imageWidth``,
    ``imageHeight``, ``words``, ``bbox`` and ``label``.  ``__getitem__``
    returns the tokenizer encoding extended with ``bbox``, ``labels``,
    ``resized_image`` and ``resized_and_aligned_bounding_boxes``, with every
    value converted to a tensor.

    Args:
        df: DataFrame with one document per row.
        tokenizer: tokenizer used both per word and for the joined text.
        max_length: fixed sequence length of every returned field.
        target_size: side length of the square resized image.
        train: stored on the instance; not read by this class.
        base_path: directory holding the image files.  When ``None`` the
            module-level ``base_path`` global is used, as before.
    """

    def __init__(self, df, tokenizer, max_length, target_size, train=True, base_path=None):
        self.df = df
        self.tokenizer = tokenizer
        self.max_seq_length = max_length
        self.target_size = target_size
        self.pad_token_box = [0, 0, 0, 0]
        self.train = train
        self.base_path = base_path

    def __len__(self):
        return self.df.shape[0]

    def __getitem__(self, idx):
        item = self.df.iloc[idx, :].to_dict()
        # Prefer the explicit constructor argument; otherwise fall back to the
        # module-level ``base_path`` the original code relied on.
        image_dir = base_path if self.base_path is None else self.base_path
        original_image = Image.open(os.path.join(image_dir, item["imageFilename"])).convert("RGB")
        # resize to target size (to be provided to the pre-trained backbone)
        resized_image = original_image.resize((self.target_size, self.target_size))

        # first, read in annotations at word-level (words, bounding boxes, labels)
        words = item["words"]
        unnormalized_word_boxes = item["bbox"]
        word_labels = item["label"]
        width = item["imageWidth"]
        height = item["imageHeight"]
        normalized_word_boxes = [normalize_box(bbox, width, height) for bbox in unnormalized_word_boxes]
        assert len(words) == len(normalized_word_boxes)

        # next, transform to token-level: every sub-token inherits its word's box
        token_boxes = []
        unnormalized_token_boxes = []
        token_labels = []
        for word, unnormalized_box, box, label in zip(
            words, unnormalized_word_boxes, normalized_word_boxes, word_labels
        ):
            n_tokens = len(self.tokenizer.tokenize(word))
            if n_tokens == 0:
                continue
            unnormalized_token_boxes.extend([unnormalized_box] * n_tokens)
            token_boxes.extend([box] * n_tokens)
            # first token of a word is B-label (beginning), the rest are I-label (inside)
            token_labels.append("B-" + label)
            token_labels.extend(["I-" + label] * (n_tokens - 1))

        # Truncate, leaving room for the two special tokens (cls + sep).
        max_tokens = self.max_seq_length - 2
        token_boxes = token_boxes[:max_tokens]
        unnormalized_token_boxes = unnormalized_token_boxes[:max_tokens]
        token_labels = token_labels[:max_tokens]

        # add bounding boxes and labels of cls + sep tokens
        sep_token_box = [1000, 1000, 1000, 1000]
        token_boxes = [self.pad_token_box] + token_boxes + [sep_token_box]
        unnormalized_token_boxes = [self.pad_token_box] + unnormalized_token_boxes + [sep_token_box]
        token_labels = [-100] + token_labels + [-100]

        # Pass max_length explicitly so padding/truncation follow this
        # dataset's sequence length rather than the tokenizer's default.
        encoding = self.tokenizer(
            " ".join(words),
            padding="max_length",
            truncation=True,
            max_length=self.max_seq_length,
        )
        # The attention mask marks the real (non-padding) tokens, so its sum
        # is the unpadded length; no second tokenizer call is needed.
        padding_length = self.max_seq_length - sum(encoding["attention_mask"])
        token_boxes += [self.pad_token_box] * padding_length
        unnormalized_token_boxes += [self.pad_token_box] * padding_length
        token_labels += [-100] * padding_length
        encoding["bbox"] = token_boxes
        # convert labels from string to indices; -100 entries pass through unchanged
        encoding["labels"] = [
            label if label == -100 else data_config.label2id[label] for label in token_labels
        ]

        for key in ("input_ids", "attention_mask", "token_type_ids", "bbox", "labels"):
            assert len(encoding[key]) == self.max_seq_length, key

        encoding["resized_image"] = ToTensor()(resized_image)
        # rescale and align the bounding boxes to match the resized image size
        encoding["resized_and_aligned_bounding_boxes"] = [
            resize_and_align_bounding_box(bbox, original_image, self.target_size)
            for bbox in unnormalized_token_boxes
        ]

        # finally, convert everything to PyTorch tensors
        for key in list(encoding.keys()):
            encoding[key] = torch.as_tensor(encoding[key])
        return encoding

In [None]:
import PyPDF2

def pdf_to_text(pdf_path):
    """Return the extracted text of every page of a PDF, concatenated in order.

    Args:
        pdf_path: path of the PDF file to read.

    Returns:
        One string with the page texts joined without a separator.
    """
    # The context manager closes the file even if extraction raises.
    with open(pdf_path, 'rb') as pdf_file:
        pdf_reader = PyPDF2.PdfReader(pdf_file)
        # `or ''` is defensive: tolerate a page that yields no text object.
        return ''.join(page.extract_text() or '' for page in pdf_reader.pages)

def text_to_xml(text, xml_path):
    """Write ``text`` to ``xml_path`` as ``<root>`` with one ``<line>`` per line.

    The text is split on ``'\\n'``.  Each line is XML-escaped so that ``&``,
    ``<`` and ``>`` in the content cannot produce malformed markup.

    Args:
        text: the text to convert.
        xml_path: destination file; it is overwritten.
    """
    from xml.sax.saxutils import escape

    # Explicit encoding keeps the output independent of the platform default.
    with open(xml_path, 'w', encoding='utf-8') as xml_file:
        xml_file.write('<root>\n')
        xml_file.writelines(
            f'  <line>{escape(line)}</line>\n' for line in text.split('\n')
        )
        xml_file.write('</root>')

# pdf_path = '/content/dov.pdf'
# xml_path = '/content/drive/My Drive/ColabNotebooks/xml_file.xml'

# pdf_text = pdf_to_text(pdf_path)
# text_to_xml(pdf_text, xml_path)

# print(f'PDF content has been converted to XML and saved to {xml_path}')


In [None]:
#test
from bs4 import BeautifulSoup
# Read the whole XML file into the module-level string `data`.
# No encoding is passed to open(), so the platform default is used.
with open('/content/drive/MyDrive/ColabNotebooks/xml_file.xml', 'r') as f:
    data = f.read()

# Parse the text with BeautifulSoup's "xml" parser and keep the resulting
# document object in `Bs_data`.
Bs_data = BeautifulSoup(data, "xml")
# print(Bs_data)

In [None]:
import glob
from xml.etree import ElementTree as ET
# Discover the images and the two kinds of annotation files in the Drive folder.
files_tif = glob.glob(f"/content/drive/MyDrive/ImageAndXML_Data/*tif")
files_gt_xml= glob.glob(f"/content/drive/MyDrive/ImageAndXML_Data/*_gt.xml")
files_ocr_xml= glob.glob(f"/content/drive/MyDrive/ImageAndXML_Data/*_ocr.xml")
# print("number of files: {}".format(len(files)))

# The "*tif" pattern also matches names that merely end in "tif"; keep only
# paths with a real ".tif" extension.
tif_files = [path for path in files_tif if path.endswith(".tif")]
# The annotation lists are plain copies of the glob results.
gt_xml_files = list(files_gt_xml)
ocr_xml_files = list(files_ocr_xml)


def get_get_bbox(bbox):
    """Return ``[x_min, y_min, x_max, y_max]`` for a ``points`` string.

    ``bbox`` is a whitespace-separated list of ``x,y`` pairs such as
    ``"10,20 110,20 110,60 10,60"``.  Coordinates are parsed as floats and
    truncated to ints.  The extremes are taken over all pairs, so the result
    does not depend on how many points there are or on their order.

    Raises:
        ValueError: if the string holds no pairs or a pair is malformed.
    """
    xs, ys = [], []
    for pair in bbox.split():
        x_text, y_text = pair.split(",")
        xs.append(int(float(x_text)))
        ys.append(int(float(y_text)))
    if not xs:
        raise ValueError(f"no coordinate pairs found in {bbox!r}")
    return [min(xs), min(ys), max(xs), max(ys)]

# get_words_bbox(gt_xml_files)

# Collect one {"text", "bbox"} record per <Word> element of every OCR XML file.
word_list = []
for xml_file_path in ocr_xml_files:
    with open(xml_file_path, encoding="utf8") as f:
        xml_data = f.read()
    soup = BeautifulSoup(xml_data, 'xml')
    page = soup.find_all('Page')
    words = soup.find_all('Word')
    page_attrs = page[0].attrs
    # print(page_attrs)

    for word in words:
        unicode_tag = word.find("Unicode")
        coords_tag = word.find("Coords")
        # A word without text or coordinates cannot be used downstream.
        if unicode_tag is None or coords_tag is None:
            continue
        # Append exactly once per word.  Looping over word.contents here would
        # add the same record once per child node (whitespace included), and
        # bs4 nodes are never ElementTree elements, so no extra attributes
        # were ever merged in.
        word_dict = {
            "text": unicode_tag.get_text(),
            "bbox": get_get_bbox(coords_tag['points']),
        }
        word_list.append(word_dict)
    # print(page_attrs)
    # print(sorted(word_list, key=lambda x: [x["bbox"][1], x["bbox"][0]]))




In [None]:
# from bs4 import BeautifulSoup
# from xml.etree import ElementTree as ET

# files_tif = glob.glob(f"/content/drive/MyDrive/ImageAndXML_Data/*tif")
# files_gt_xml= glob.glob(f"/content/drive/MyDrive/ImageAndXML_Data/*_gt.xml")
# files_ocr_xml= glob.glob(f"/content/drive/MyDrive/ImageAndXML_Data/*_ocr.xml")
# print("number of files: {}".format(len(files)))
# tif_files=[]
# gt_xml_files=[]
# ocr_xml_files=[]
# for file in files_tif:
#   # print(file)
#   if file.endswith(".tif"):
#     tif_files.append(file)
# for file in files_gt_xml:
#   gt_xml_files.append(file)
# for file in files_ocr_xml:
#   ocr_xml_files.append(file)


def get_bbox(bbox):
    """Return ``[x_min, y_min, x_max, y_max]`` for a ``points`` string.

    ``bbox`` is a whitespace-separated list of ``x,y`` pairs such as
    ``"10,20 110,20 110,60 10,60"``.  Coordinates are parsed as floats and
    truncated to ints.  The extremes are taken over all pairs, so the result
    does not depend on how many points there are or on their order.

    Raises:
        ValueError: if the string holds no pairs or a pair is malformed.
    """
    xs, ys = [], []
    for pair in bbox.split():
        x_text, y_text = pair.split(",")
        xs.append(int(float(x_text)))
        ys.append(int(float(y_text)))
    if not xs:
        raise ValueError(f"no coordinate pairs found in {bbox!r}")
    return [min(xs), min(ys), max(xs), max(ys)]


def get_words_bbox(ocr_xml_files):
    """Extract the words and their boxes from OCR XML files.

    Args:
        ocr_xml_files: iterable of paths to OCR XML files.

    Returns:
        ``(page_attrs, words)``.  ``page_attrs`` is the attribute dict of the
        first ``<Page>`` element of the LAST file processed (``{}`` when no
        file is given).  ``words`` holds one ``{"text", "bbox"}`` dict per
        ``<Word>`` element across all files, sorted by the box's second then
        first coordinate.
    """
    word_list = []
    page_attrs = {}
    for xml_file_path in ocr_xml_files:
        with open(xml_file_path, encoding="utf8") as f:
            soup = BeautifulSoup(f.read(), 'xml')
        # NOTE: overwritten on every iteration, so only the last file's page
        # attributes are returned while words accumulate over all files.
        page_attrs = soup.find_all('Page')[0].attrs

        for word in soup.find_all('Word'):
            unicode_tag = word.find("Unicode")
            coords_tag = word.find("Coords")
            # A word without text or coordinates cannot be used downstream.
            if unicode_tag is None or coords_tag is None:
                continue
            # Exactly one record per word: iterating word.contents would
            # append a duplicate for every child node of the element.
            word_list.append({
                "text": unicode_tag.get_text(),
                "bbox": get_bbox(coords_tag['points']),
            })

    return page_attrs, sorted(word_list, key=lambda x: [x["bbox"][1], x["bbox"][0]])

# # Example usage:
# gt_xml_files = ["path/to/your/xml/file1.xml", "path/to/your/xml/file2.xml"]
# page_attrs, sorted_word_list = get_words_bbox(ocr_xml_files)
# print("Page Attributes:", page_attrs)
# print("Sorted Word List:", sorted_word_list)


In [None]:
def get_label_bbox(gt_xml_files):
    """Extract the labelled regions from ground-truth XML files.

    For every ``<TextRegion>`` the attributes of all its direct child tags
    are merged into one dict.  The merged ``points`` attribute is replaced
    by a ``bbox`` entry computed with ``get_bbox``.

    Args:
        gt_xml_files: iterable of paths to ground-truth XML files.

    Returns:
        The region dicts of ALL files, sorted by the box's second then first
        coordinate; an empty list when no file is given.

    Raises:
        KeyError: if a region has no child tag carrying ``points``.
    """
    # Imported here so the function does not depend on another notebook cell
    # having imported ``element`` first.
    from bs4 import element

    word_list = []
    for gt_xml_path in gt_xml_files:
        with open(gt_xml_path, encoding="utf8") as f:
            xml_data = f.read()
        soup = BeautifulSoup(xml_data, 'xml')

        for region in soup.find_all('TextRegion'):
            word_dict = {}
            for content in region.contents:
                # Skip text nodes (e.g. whitespace); only tags carry attributes.
                if isinstance(content, element.Tag):
                    word_dict.update(content.attrs)
            word_dict["bbox"] = get_bbox(word_dict.pop("points"))
            word_list.append(word_dict)
    # Return after the loop so every file contributes, not just the first.
    return sorted(word_list, key=lambda x: [x["bbox"][1], x["bbox"][0]])

In [None]:
  #test

  from bs4 import BeautifulSoup, element
  # Scratch run of the label extraction over every ground-truth XML file.
  # `word_list` is rebuilt for each file, so it ends up holding the last one.
  for gt_xml_path in gt_xml_files:
    with open(gt_xml_path, encoding="utf8") as f:
      xml_data = f.read()
    soup = BeautifulSoup(xml_data, 'xml')
    words = soup.find_all('TextRegion')
    # print(words)

    word_list = []
    for word in words:
      # Merge the attributes of every child tag into one record.
      word_dict = {}
      for content in word.contents:
        if not isinstance(content, element.Tag):
          continue
        word_dict.update(content.attrs)
      # Swap the raw "points" string for the parsed box.
      word_dict["bbox"] = get_bbox(word_dict.pop("points"))
      word_list.append(word_dict)
    # print(sorted(word_list, key=lambda x : [x["bbox"][1], x["bbox"][0]]))

In [None]:
import pandas as pd

def is_word_bbox_in_label_bbox(word_bbox, label_bbox):
    """Tell whether the word box lies strictly inside the label box.

    Both boxes are ``[x1, y1, x2, y2]``.  All four comparisons are strict, so
    a word box that touches an edge of the label box is not inside it.
    """
    word_left, word_top, word_right, word_bottom = word_bbox
    label_left, label_top, label_right, label_bottom = label_bbox
    inside_horizontally = label_left < word_left and word_right < label_right
    inside_vertically = label_top < word_top and word_bottom < label_bottom
    return bool(inside_horizontally and inside_vertically)

def assign_lable_to_word(words_bbox_list, word_label_list):
    """Attach a label to every word whose box falls inside a label box.

    Every (word, label) pair is tested with ``is_word_bbox_in_label_bbox``.
    Each match yields one output row: the word's fields plus ``label``, taken
    from the label's ``value``.  A word inside several label boxes therefore
    appears once per match.

    Returns:
        A DataFrame of the matched rows (empty when nothing matches).
    """
    df_label = pd.DataFrame(word_label_list)
    df_words = pd.DataFrame(words_bbox_list)
    matches = [
        {**row_word.to_dict(), "label": row_label["value"]}
        for _, row_word in df_words.iterrows()
        for _, row_label in df_label.iterrows()
        if is_word_bbox_in_label_bbox(row_word["bbox"], row_label["bbox"])
    ]
    return pd.DataFrame(matches)






# Discover the images and the two kinds of annotation files in the Drive folder.
files_tif = glob.glob(f"/content/drive/MyDrive/ImageAndXML_Data/*tif")
files_gt_xml= glob.glob(f"/content/drive/MyDrive/ImageAndXML_Data/*_gt.xml")
files_ocr_xml= glob.glob(f"/content/drive/MyDrive/ImageAndXML_Data/*_ocr.xml")
# print("number of files: {}".format(len(files)))

# The "*tif" pattern also matches names that merely end in "tif"; keep only
# paths with a real ".tif" extension.
tif_files = [path for path in files_tif if path.endswith(".tif")]
# The annotation lists are plain copies of the glob results.
gt_xml_files = list(files_gt_xml)
ocr_xml_files = list(files_ocr_xml)

# Run the extraction pipeline and package the result as a single record.
lst_output = []
page_attrs, words_bbox_list = get_words_bbox(ocr_xml_files)
word_label_list = get_label_bbox(gt_xml_files)
df_word_lable = assign_lable_to_word(words_bbox_list, word_label_list)

# Store the matched words, boxes and labels as parallel lists on the page record.
page_attrs["words"] = df_word_lable.text.tolist()
page_attrs["bbox"] = df_word_lable.bbox.tolist()
page_attrs["label"] = df_word_lable.label.tolist()
lst_output.append(page_attrs)


In [None]:
df = pd.DataFrame(lst_output)[["imageFilename","imageHeight", "imageWidth", "words", "bbox", "label"]]

In [None]:
df.head(3)

In [None]:
df.to_pickle("/content/data.pkl")

In [None]:

data = pd.read_pickle("/content/data.pkl")
data.head()

Unnamed: 0,imageFilename,imageHeight,imageWidth,words,bbox,label
0,2023591606_2023591608.tif,1000,777,"[3/, 3/, 3/, 3/, 3/, 3, 3, 3, 3, 3, y, y, y, y...","[[686, 17, 700, 27], [686, 17, 700, 27], [686,...","[invoice_info, invoice_info, invoice_info, inv..."


In [None]:
print(len(data))

1


In [None]:
import numpy as np
class data_config:
    """Label vocabulary derived from the module-level ``data`` frame.

    ``labels`` lists, for every distinct word label found in ``data.label``
    (in sorted order), its ``B-`` tag followed by its ``I-`` tag.
    """

    labels = np.unique([tag for tags in data.label for tag in tags]).tolist()
    labels = [prefix + tag for tag in np.unique(labels) for prefix in ("B-", "I-")]
    num_labels = len(labels)
    # index -> tag and tag -> index lookups over the same ordering
    id2label = dict(enumerate(labels))
    label2id = {name: index for index, name in enumerate(labels)}

In [None]:
!pip install transformers



In [None]:
# Install dependencies
!pip install -U torch torchvision

# Install pycocotools
!pip install cython
!pip install 'git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI'




Collecting torch
  Downloading torch-2.1.1-cp310-cp310-manylinux1_x86_64.whl (670.2 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m670.2/670.2 MB[0m [31m1.1 MB/s[0m eta [36m0:00:00[0m
[0mCollecting torchvision
  Downloading torchvision-0.16.1-cp310-cp310-manylinux1_x86_64.whl (6.8 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m6.8/6.8 MB[0m [31m7.1 MB/s[0m eta [36m0:00:00[0m
Collecting nvidia-cuda-nvrtc-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_nvrtc_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (23.7 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m23.7/23.7 MB[0m [31m7.3 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-runtime-cu12==12.1.105 (from torch)
  Downloading nvidia_cuda_runtime_cu12-12.1.105-py3-none-manylinux1_x86_64.whl (823 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m823.6/823.6 kB[0m [31m8.1 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting nvidia-cuda-c

Collecting git+https://github.com/cocodataset/cocoapi.git#subdirectory=PythonAPI
  Cloning https://github.com/cocodataset/cocoapi.git to /tmp/pip-req-build-xhr9fmji
  Running command git clone --filter=blob:none --quiet https://github.com/cocodataset/cocoapi.git /tmp/pip-req-build-xhr9fmji
  Resolved https://github.com/cocodataset/cocoapi.git to commit 8c9bcc3cf640524c4c20a9c40e89cb6a2f2fa0e9
  Preparing metadata (setup.py) ... [?25l[?25hdone
Building wheels for collected packages: pycocotools
  Building wheel for pycocotools (setup.py) ... [?25l[?25hdone
  Created wheel for pycocotools: filename=pycocotools-2.0-cp310-cp310-linux_x86_64.whl size=375525 sha256=8001a61cf1c15710688b1dcf6e62602fbdde4ae48a43b7022fc9c090079f2e2f
  Stored in directory: /tmp/pip-ephem-wheel-cache-sbcz5gy2/wheels/39/61/b4/480fbddb4d3d6bc34083e7397bc6f5d1381f79acc68e9f3511
Successfully built pycocotools
Installing collected packages: pycocotools
  Attempting uninstall: pycocotools
    Found existing installa

In [None]:
# Install Detectron2
!git clone https://github.com/facebookresearch/detectron2.git
!pip install -e detectron2

Cloning into 'detectron2'...
remote: Enumerating objects: 15285, done.[K
remote: Counting objects: 100% (10/10), done.[K
remote: Compressing objects: 100% (10/10), done.[K
remote: Total 15285 (delta 2), reused 5 (delta 0), pack-reused 15275[K
Receiving objects: 100% (15285/15285), 6.18 MiB | 6.68 MiB/s, done.
Resolving deltas: 100% (11117/11117), done.
Obtaining file:///content/detectron2
  Preparing metadata (setup.py) ... [?25l[?25hdone
Collecting pycocotools>=2.0.2 (from detectron2==0.6)
  Downloading pycocotools-2.0.7-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (426 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m426.2/426.2 kB[0m [31m2.6 MB/s[0m eta [36m0:00:00[0m
Collecting yacs>=0.1.8 (from detectron2==0.6)
  Using cached yacs-0.1.8-py3-none-any.whl (14 kB)
Collecting fvcore<0.1.6,>=0.1.5 (from detectron2==0.6)
  Using cached fvcore-0.1.5.post20221221-py3-none-any.whl
Collecting iopath<0.1.10,>=0.1.7 (from detectron2==0.6)
  Using ca

In [None]:
# !python -m pip install pyyaml==5.1

# Note: This is a faster way to install detectron2 in Colab, but it does not include all functionalities (e.g. compiled operators).
# See https://detectron2.readthedocs.io/tutorials/install.html for full installation instructions
# !git clone 'https://github.com/facebookresearch/detectron2'
# dist = distutils.core.run_setup("./detectron2/setup.py")
# !python -m pip install {' '.join([f"'{x}'" for x in dist.install_requires])}
# sys.path.insert(0, os.path.abspath('./detectron2'))


In [None]:
import sys, os, distutils.core
import torch, detectron2
!nvcc --version
# First two dot-separated components of the torch version string.
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
# Text after the last "+" in the version string; when the string has no "+"
# this is the whole version string.
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
# NOTE(review): the recorded output of this cell ends with an AttributeError
# right after the torch line, so this print presumably failed. Possibly the
# freshly cloned ./detectron2 directory shadows the installed package - confirm.
print("detectron2:", detectron2.__version__)

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0
torch:  2.1 ; cuda:  cu118


AttributeError: ignored

In [None]:
import detectron2

In [None]:
# Some basic setup:
# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()

# import some common libraries
import numpy as np
import os, json, cv2, random
from google.colab.patches import cv2_imshow

# import some common detectron2 utilities
from detectron2 import model_zoo
from detectron2.engine import DefaultPredictor
from detectron2.config import get_cfg
from detectron2.utils.visualizer import Visualizer
from detectron2.data import MetadataCatalog, DatasetCatalog

In [None]:
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

In [None]:
import torch, detectron2
!nvcc --version
# First two dot-separated components of the torch version string.
TORCH_VERSION = ".".join(torch.__version__.split(".")[:2])
# Text after the last "+" in the version string; when the string has no "+"
# this is the whole version string.
CUDA_VERSION = torch.__version__.split("+")[-1]
print("torch: ", TORCH_VERSION, "; cuda: ", CUDA_VERSION)
# Reads the `__version__` attribute of the imported detectron2 module.
print("detectron2:", detectron2.__version__)
from transformers import LayoutLMv2Config, LayoutLMv2Tokenizer, LayoutLMv2ForTokenClassification



# Setup detectron2 logger
import detectron2
from detectron2.utils.logger import setup_logger
setup_logger()



In [None]:
# Checkpoint identifier shared by the config, the tokenizer and the model.
model_path = 'microsoft/layoutlmv2-base-uncased'
# Configure the classification head from the label vocabulary in data_config.
config = LayoutLMv2Config.from_pretrained(model_path, num_labels=data_config.num_labels, id2label = data_config.id2label, label2id = data_config.label2id)
tokenizer = LayoutLMv2Tokenizer.from_pretrained(model_path)
# Load the token-classification model with the customised config.
model = LayoutLMv2ForTokenClassification.from_pretrained(model_path, config = config)
# Move the model to the module-level `device`.
model.to(device)

In [None]:
from sklearn.model_selection import train_test_split
# Hold out 20% of the rows of `df` for validation.
# NOTE(review): `print(len(data))` recorded 1 earlier in the notebook; if `df`
# is also a single row, this split presumably cannot produce both a training
# and a validation set - confirm.
train, valid = train_test_split(df, test_size = 0.2)

# Both datasets use 512-token sequences and 224x224 resized images.
train_dataset = InvoiceDataSet(df = train, tokenizer = tokenizer, max_length = 512, target_size = 224, train=True)
# No `shuffle` argument is passed to either DataLoader.
train_dataloader = DataLoader(train_dataset, batch_size=5)

valid_dataset = InvoiceDataSet(df = valid, tokenizer = tokenizer, max_length = 512, target_size = 224, train=False)
valid_dataloader = DataLoader(valid_dataset, batch_size=5)


In [None]:
from transformers import AdamW
from tqdm.notebook import tqdm
import numpy as np
from seqeval.metrics import (
    classification_report,
    f1_score,
    precision_score,
    recall_score,
)
import torch

def train_fn(train_dataloader, model, optimizer):
    """Run one optimisation pass over ``train_dataloader``.

    Each batch is moved to the module-level ``device`` and fed to the model
    together with its labels; the returned loss is back-propagated and the
    optimizer is stepped once per batch.

    Args:
        train_dataloader: yields dict batches with ``input_ids``, ``bbox``,
            ``attention_mask``, ``token_type_ids``, ``labels`` and
            ``resized_image``.
        model: the model to train; called with keyword arguments.
        optimizer: optimizer over the model's parameters.
    """
    # eval_fn() switches the model to eval mode; switch back so that every
    # epoch (not just the first) trains with training-mode layers.
    model.train()
    tk0 = tqdm(train_dataloader, total=len(train_dataloader))
    for batch in tk0:
        outputs = model(
            image=batch['resized_image'].to(device),
            input_ids=batch['input_ids'].to(device),
            bbox=batch['bbox'].to(device),
            attention_mask=batch['attention_mask'].to(device),
            token_type_ids=batch['token_type_ids'].to(device),
            labels=batch['labels'].to(device),
        )
        loss = outputs.loss
        loss.backward()
        optimizer.step()
        optimizer.zero_grad()

def eval_fn(eval_dataloader, model):
    """Evaluate ``model`` on ``eval_dataloader`` and return sequence metrics.

    The model is put in eval mode and run without gradients.  Positions whose
    label id is ``-100`` are excluded.  The remaining label and prediction
    ids are mapped to tag strings through the module-level
    ``config.id2label`` and scored with seqeval.

    Args:
        eval_dataloader: yields dict batches with ``input_ids``, ``bbox``,
            ``attention_mask``, ``token_type_ids``, ``labels`` and
            ``resized_image``.
        model: the model to evaluate; called with keyword arguments.

    Returns:
        Dict with ``loss`` (mean batch loss), ``precision``, ``recall``, ``f1``.

    Raises:
        ValueError: if the dataloader yields no batches.
    """
    eval_loss = 0.0
    nb_eval_steps = 0
    logits_chunks = []
    label_chunks = []
    model.eval()
    tk0 = tqdm(eval_dataloader, total=len(eval_dataloader))
    for batch in tk0:
        with torch.no_grad():
            labels = batch['labels'].to(device)
            outputs = model(
                image=batch['resized_image'].to(device),
                input_ids=batch['input_ids'].to(device),
                bbox=batch['bbox'].to(device),
                attention_mask=batch['attention_mask'].to(device),
                token_type_ids=batch['token_type_ids'].to(device),
                labels=labels,
            )
        eval_loss += outputs.loss.item()
        nb_eval_steps += 1
        # Collect per-batch arrays and concatenate once at the end instead of
        # re-copying the growing array on every batch.
        logits_chunks.append(outputs.logits.detach().cpu().numpy())
        label_chunks.append(labels.detach().cpu().numpy())

    if nb_eval_steps == 0:
        raise ValueError("eval_dataloader yielded no batches")

    eval_loss = eval_loss / nb_eval_steps
    preds = np.argmax(np.concatenate(logits_chunks, axis=0), axis=2)
    out_label_ids = np.concatenate(label_chunks, axis=0)

    out_label_list = []
    preds_list = []
    for label_row, pred_row in zip(out_label_ids, preds):
        keep = label_row != -100  # drop special/padding positions
        out_label_list.append([config.id2label[i] for i in label_row[keep]])
        preds_list.append([config.id2label[i] for i in pred_row[keep]])

    results = {
        "loss": eval_loss,
        "precision": precision_score(out_label_list, preds_list),
        "recall": recall_score(out_label_list, preds_list),
        "f1": f1_score(out_label_list, preds_list),
    }
    return results

In [None]:
# Where the best checkpoint (state dict) is written.
# NOTE(review): this is a /kaggle/working path although the notebook mounts
# Google Drive under /content - presumably the directory must exist for
# torch.save to succeed; confirm.
MODEL_PATH ="/kaggle/working/pytorch_model.bin"
optimizer = AdamW(model.parameters(), lr=5e-5)
# `global_step` is assigned here but not referenced anywhere else in the
# visible notebook.
global_step = 0
best_f1_score = 0
# Train for 5 epochs, evaluating after each one.
for epoch in range(5):
    train_fn(train_dataloader, model, optimizer)
    # eval_fn returns a dict of metrics; only its "f1" entry is used below.
    current_f1_score = eval_fn(valid_dataloader, model)
    # Save the weights only when the validation F1 strictly improves.
    if current_f1_score["f1"] > best_f1_score:
        torch.save(model.state_dict(), MODEL_PATH)
        best_f1_score = current_f1_score["f1"]
    print("best_f1_score :", best_f1_score)