# Create training data

Schema:
* id: (int) unique ID for each doc
* tokens: (list) words or groups of words 
* bboxes: (list) bounding box for each token (x1, y1, x2, y2)
* ner_tags: (list) the entity tag (int) corresponding to each token
* image: PIL Image

See [FUNSD dataset](https://huggingface.co/datasets/nielsr/funsd-layoutlmv3) as an example

Tutorials and notebooks:
* [Fine-tuning on custom dataset tutorial](https://medium.com/@matt.noe/tutorial-how-to-train-layoutlm-on-a-custom-dataset-with-hugging-face-cda58c96571c)
* [Fine-tuning LayoutLMv3 notebook](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/LayoutLMv3/Fine_tune_LayoutLMv3_on_FUNSD_(HuggingFace_Trainer).ipynb#scrollTo=cqcq7rzlVDOE)
* [Fine-tuning LayoutLMv2 notebook](https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/LayoutLMv2/FUNSD/Fine_tuning_LayoutLMv2ForTokenClassification_on_FUNSD_using_HuggingFace_Trainer.ipynb)

In [None]:
!pip install opencv-python

In [None]:
import json
from typing import Any, Optional, Union

import fitz  # PyMuPDF
import numpy as np
import pandas as pd
from PIL import Image
from pathlib import Path
import cv2
import matplotlib.pyplot as plt
from datasets import Dataset, ClassLabel, Features, Sequence, Value, Array2D, Array3D, load_metric
import torch
from transformers import AutoProcessor

from mozilla_sec_eia.utils import GCSArchive

# PDF text extraction utility functions

In [None]:
# copied from well gas project wellgas/features/extract_text.py
def extract_pdf_data_from_page(page: fitz.Page) -> dict[str, pd.DataFrame]:
    """Extract word, image and page-level tables from one PDF page.

    Args:
        page: the PyMuPDF page to parse.

    Returns:
        The tables produced by ``_parse_page_contents`` (``"pdf_text"`` and
        ``"image"``) plus ``"page"``, a one-row table of page metadata. Every
        non-empty content table gets a ``page_num`` column.
    """
    contents = _parse_page_contents(page)
    image_df = contents["image"]
    page_width = page.rect[2] - page.rect[0]
    page_height = page.rect[3] - page.rect[1]

    # Summed area of all image bounding boxes; overlapping images count twice.
    if image_df.empty:
        img_area = 0
    else:
        img_area = image_df.eval(
            "((bottom_right_x_pdf - top_left_x_pdf)"
            " * (bottom_right_y_pdf - top_left_y_pdf))"
        ).sum()

    page_record = {
        "rotation_degrees": [page.rotation],
        "origin_x_pdf_coord": [page.rect[0]],
        "origin_y_pdf_coord": [page.rect[1]],
        "width_pdf_coord": [page_width],
        "height_pdf_coord": [page_height],
        "has_images": [not image_df.empty],
        "has_text": [not contents["pdf_text"].empty],
        "page_num": [page.number],
        "image_area_frac": [np.float32(img_area / (page_width * page_height))],
    }
    page_dtypes = {
        "rotation_degrees": np.int16,
        "origin_x_pdf_coord": np.float32,
        "origin_y_pdf_coord": np.float32,
        "width_pdf_coord": np.float32,
        "height_pdf_coord": np.float32,
        "has_images": "boolean",
        "has_text": "boolean",
        "page_num": np.int16,
        "image_area_frac": np.float32,
    }
    meta_df = pd.DataFrame(page_record).astype(page_dtypes)

    # Add the page number as an ID field; empty tables are left without columns.
    for table in contents.values():
        if not table.empty:
            table["page_num"] = np.int16(page.number)
    return {**contents, "page": meta_df}


def _parse_page_contents(page: fitz.Page) -> dict[str, pd.DataFrame]:
    """Parse the words and image blocks of a page via ``fitz.TextPage``.

    Returns:
        ``{"pdf_text": <one row per word>, "image": <one row per image block>}``;
        either table is an empty DataFrame when the page has none.

    Raises:
        ValueError: if the page dict contains a block type other than 0 or 1.
    """
    textpage = page.get_textpage(flags=fitz.TEXTFLAGS_DICT)
    page_dict = textpage.extractDICT()
    word_tuples = textpage.extractWORDS()

    image_frames = []
    for block in page_dict["blocks"]:
        block_type = block["type"]
        if block_type == 1:
            image_frames.append(_parse_image_block(block))
        elif block_type != 0:
            raise ValueError(f"Unknown block type: {block['type']}")
        # type 0 is a text block: skipped, text is parsed word by word below

    word_frames = [
        frame for frame in map(_parse_word_block, word_tuples) if not frame.empty
    ]

    text = (
        pd.concat(word_frames, axis=0, ignore_index=True)
        if word_frames
        else pd.DataFrame()
    )
    images = (
        pd.concat(
            (pd.DataFrame(frame) for frame in image_frames),
            axis=0,
            ignore_index=True,
        )
        if image_frames
        else pd.DataFrame()
    )
    return {"pdf_text": text, "image": images}


def _parse_image_block(img_block: dict[str, Any]) -> pd.DataFrame:
    """Turn one image block of ``fitz.TextPage.extractDICT()`` into a one-row table.

    Args:
        img_block: block dict; reads the ``"bbox"``, ``"xres"``, ``"yres"`` and
            ``"number"`` keys.

    Returns:
        One-row DataFrame with the image number, dpi and bounding-box corners.
    """
    left, top, right, bottom = img_block["bbox"]
    # xres and yres should be equal; take the smaller one just in case
    resolution = min(img_block["xres"], img_block["yres"])

    # column name -> (value, dtype), in output column order
    spec = {
        "img_num": (img_block["number"], np.int16),
        "dpi": (resolution, np.int16),
        "top_left_x_pdf": (left, np.float32),
        "top_left_y_pdf": (top, np.float32),
        "bottom_right_x_pdf": (right, np.float32),
        "bottom_right_y_pdf": (bottom, np.float32),
    }
    frame = pd.DataFrame({name: [value] for name, (value, _) in spec.items()})
    return frame.astype({name: dtype for name, (_, dtype) in spec.items()})

def _parse_word_block(word_block: tuple) -> pd.DataFrame:
    """Turn one word tuple of ``fitz.TextPage.extractWORDS()`` into a one-row table.

    Args:
        word_block: tuple whose first eight items are read in order as the four
            bounding-box corners, the text, and the block / line / word numbers.

    Returns:
        One-row DataFrame with one typed column per tuple item.
    """
    # (column name, dtype) in the same order as the tuple items
    schema = (
        ("top_left_x_pdf", np.float32),
        ("top_left_y_pdf", np.float32),
        ("bottom_right_x_pdf", np.float32),
        ("bottom_right_y_pdf", np.float32),
        ("text", "string"),
        ("block_num", np.int16),
        ("line_num", np.int16),
        ("word_num", np.int16),
    )
    record = {
        name: [word_block[position]] for position, (name, _) in enumerate(schema)
    }
    return pd.DataFrame(record).astype(dict(schema))

def _frac_normal_ascii(text: Union[str, bytes]) -> float:
    """Return the fraction of characters in ``text`` that are normal ASCII.

    "Normal" means the range from space to tilde (code points 32-126) plus tab
    and newline; see https://www.asciitable.com/

    Args:
        text: the text to score. ``bytes`` input is decoded as UTF-8 first.

    Returns:
        A value between 0 and 1. Empty input contains no normal characters and
        yields 0.0 rather than raising ``ZeroDivisionError``.
    """
    if isinstance(text, bytes):
        text = text.decode("utf-8")
    if not text:
        return 0.0
    normal_count = sum(
        (32 <= ord(char) <= 126) or char in "\t\n" for char in text
    )
    return normal_count / len(text)


In [None]:

def _render_page(
    pg: fitz.Page, dpi=150, clip: Optional[fitz.Rect] = None
) -> Image.Image:
    """Rasterize a PDF page and return it as a PIL image.

    Args:
        pg (fitz.Page): a page of a PDF
        dpi (int, optional): image resolution in pixels per inch. Defaults to 150.
        clip (Optional[fitz.Rect], optional): region of the page to render, in PDF
            coordinates. Defaults to None, which renders the full page.

    Returns:
        Image.Image: the rendered page (or clipped region)
    """
    # 300 dpi is what tesseract recommends. PaddleOCR seems to do fine with half that.
    pixmap: fitz.Pixmap = pg.get_pixmap(dpi=dpi, clip=clip)  # type: ignore
    return _pil_img_from_pixmap(pixmap)


def _pil_img_from_pixmap(pix: fitz.Pixmap) -> Image.Image:
    """Build a PIL.Image from a rendered pyMuPDF Pixmap.

    pyMuPDF (aka fitz) can save images through PIL but exposes no function that
    returns a PIL object, so the mode selection below is adapted from its
    ``Pixmap.pil_save`` method (python3.9/site-packages/fitz/fitz.py), with
    "self" replaced by "pix".

    Args:
        pix (fitz.Pixmap): a rendered Pixmap

    Returns:
        Image: a PIL.Image object sharing the pixmap's width, height and samples
    """
    # number of colorspace components -> (mode without alpha, mode with alpha)
    modes_by_components = {1: ("L", "LA"), 3: ("RGB", "RGBA")}

    colorspace = pix.colorspace
    if colorspace is None:
        mode = "L"
    elif colorspace.n in modes_by_components:
        opaque_mode, alpha_mode = modes_by_components[colorspace.n]
        mode = opaque_mode if pix.alpha == 0 else alpha_mode
    else:
        mode = "CMYK"

    return Image.frombytes(mode, (pix.width, pix.height), pix.samples)

In [None]:
PDF_POINTS_PER_INCH = 72  # I believe this is standard for all PDFs

def pil_to_cv2(image: Image.Image) -> np.ndarray:
    """Convert a PIL Image to an OpenCV image (numpy array).

    Raises:
        ValueError: if the image's color mode is not one of the handled modes.
    """
    # copied from https://gist.github.com/panzi/1ceac1cb30bb6b3450aa5227c02eedd3
    # This covers the common modes, is not exhaustive.
    mode = image.mode

    # Single-channel modes need no color conversion.
    if mode == "1":
        new_image = np.array(image, dtype=np.uint8)
        new_image *= 255  # stretch 0/1 to 0/255
        return new_image
    if mode == "L":
        return np.array(image, dtype=np.uint8)

    # PIL mode -> cv2 conversion code applied directly to the pixel array
    direct_codes = {
        "RGB": cv2.COLOR_RGB2BGR,
        "RGBA": cv2.COLOR_RGBA2BGRA,
        "LAB": cv2.COLOR_LAB2BGR,
        "HSV": cv2.COLOR_HSV2BGR,
        # XXX: not sure if YCbCr == YCrCb
        "YCbCr": cv2.COLOR_YCrCb2BGR,
    }
    # PIL mode -> (mode PIL converts to first, cv2 conversion code)
    via_pil = {
        "LA": ("RGBA", cv2.COLOR_RGBA2BGRA),
        "La": ("RGBA", cv2.COLOR_RGBA2BGRA),
        "P": ("RGB", cv2.COLOR_RGB2BGR),
        "CMYK": ("RGB", cv2.COLOR_RGB2BGR),
        "PA": ("RGBA", cv2.COLOR_RGBA2BGRA),
        "Pa": ("RGBA", cv2.COLOR_RGBA2BGRA),
    }

    if mode in direct_codes:
        pixels = np.array(image, dtype=np.uint8)
        return cv2.cvtColor(pixels, direct_codes[mode])
    if mode in via_pil:
        intermediate_mode, code = via_pil[mode]
        pixels = np.array(image.convert(intermediate_mode), dtype=np.uint8)
        return cv2.cvtColor(pixels, code)
    raise ValueError(f"unhandled image color mode: {mode}")


def cv2_to_pil(img: np.ndarray) -> Image.Image:
    """Create a PIL Image from a numpy pixel array.

    2-D arrays are used as-is (single channel, AKA grayscale); anything else is
    assumed to be BGR and converted to RGB first.
    """
    is_grayscale = len(img.shape) == 2
    pixels = img if is_grayscale else cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
    return Image.fromarray(pixels)


def display_img_array(img: np.ndarray, figsize=(5, 5), **kwargs):
    """Plot an image array in a new matplotlib figure (for jupyter sessions).

    2-D arrays are shown as 0-255 grayscale; other arrays are treated as BGR and
    converted to RGB before plotting. Extra keyword arguments go to ``plt.imshow``,
    whose return value is passed back.
    """
    plt.figure(figsize=figsize)
    if len(img.shape) != 2:
        return plt.imshow(cv2.cvtColor(img, cv2.COLOR_BGR2RGB), **kwargs)
    return plt.imshow(img, cmap="gray", vmin=0, vmax=255, **kwargs)


def overlay_bboxes(
    img: np.ndarray, bboxes: np.ndarray, color=(255, 0, 0)
) -> np.ndarray:
    """Draw N x 4 bounding boxes (x0, y0, x1, y1) on a copy of an image.

    The input image is not modified; boxes are drawn one pixel thick.
    """
    canvas = img.copy()
    # cv2 needs integer pixel positions, so round the (possibly float) boxes
    pixel_boxes = np.round(bboxes, 0).astype(np.int32)
    for left, top, right, bottom in pixel_boxes:
        cv2.rectangle(canvas, (left, top), (right, bottom), color=color, thickness=1)
    return canvas


def pdf_coords_to_pixel_coords(coords: np.ndarray, dpi: int) -> np.ndarray:
    """Scale PDF coordinates to pixel coordinates at the given rendering dpi.

    No origin offset is applied: for arbitrary PDFs you would need to subtract
    the origin in PDF coordinates, but since you create these PDFs, you know the
    origin is (0, 0).
    """
    return coords * dpi / PDF_POINTS_PER_INCH

# Create PDFs of Ex. 21's

In [None]:
# Handle to the filings archive and the metadata table used below to select
# filings (presumably one row per filing — TODO confirm against GCSArchive).
archive = GCSArchive()
md = archive.get_metadata()

In [None]:
# Select the first filing whose company name contains "TUCSON". `na=False`
# counts missing company names as non-matches, so the boolean mask never holds
# NA values (which pandas refuses to index with).
tucson_md = md[md["Company Name"].str.contains("TUCSON", na=False)].iloc[[0]]
tucson_filing = archive.get_filings(tucson_md)

In [None]:
# Write the Ex. 21 of the selected filing to tucson_electric.pdf.
with open("tucson_electric.pdf", "wb") as file:
    tucson_filing[0].ex_21.save_as_pdf(file)

In [None]:
# get some random PDFs: sample three filings that have an Ex. 21 version recorded
# and write each one's Ex. 21 to test_<i>.pdf
sample = md[md.exhibit_21_version.notnull()].sample(3)
filings = archive.get_filings(sample)
for i, filing in enumerate(filings):
    with open(f"test_{i}.pdf", "wb") as file:
        filing.ex_21.save_as_pdf(file)

# Demo with one doc

In [None]:
# NOTE(review): no cell above writes wisconsin_electric.pdf (only
# tucson_electric.pdf and test_<i>.pdf are created), so this file must already
# exist in the working directory.
src_path = Path("./wisconsin_electric.pdf")
assert src_path.exists()

In [None]:
# from file: open the PDF by path and display `is_pdf` as a sanity check
doc = fitz.Document(str(src_path))
doc.is_pdf

In [None]:
# from bytes: alternative to opening by path; reads the file into memory and
# opens it as a stream. This rebinds `doc`.
_bytes = src_path.read_bytes()
from io import BytesIO
doc = fitz.open(stream=BytesIO(_bytes), filetype="pdf")
doc.is_pdf

### Extract Text Bboxes

In [None]:
# Only the first page of the document is used for the demo.
pg = doc[0]
extracted = extract_pdf_data_from_page(pg)
extracted.keys()

In [None]:
# Unpack the three tables: one row per word, per image block, and per page.
txt = extracted['pdf_text']
img_info = extracted['image']
pg_meta = extracted['page']
txt.shape, img_info.shape, pg_meta.shape

In [None]:
txt

### Label the entities

In [None]:
# do we actually need the headers?
# Each label's id is its position in the list (shown in the trailing comments).
id_to_label_full = dict(
    enumerate(
        [
            "O",  # 0
            "B-Header_Subsidiary",  # 1
            "I-Header_Subsidiary",  # 2
            "B-Body_Subsidiary",  # 3
            "I-Body_Subsidiary",  # 4
            "B-Header_Loc",  # 5
            "I-Header_Loc",  # 6
            "B-Body_Loc",  # 7
            "I-Body_Loc",  # 8
            "B-Header_Own_Per",  # 9
            "I-Header_Own_Per",  # 10
            "B-Body_Own_Per",  # 11
            "I-Body_Own_Per",  # 12
        ]
    )
)
id_to_label_small = dict(
    enumerate(
        [
            "O",  # 0
            "B-Subsidiary",  # 1
            "I-Subsidiary",  # 2
            "B-Loc",  # 3
            "I-Loc",  # 4
            "B-Own_Per",  # 5
            "I-Own_Per",  # 6
        ]
    )
)


In [None]:
label_col = "ner_tag"

In [None]:
# Start with every token labeled 0 ("O"); entity labels are filled in below.
txt.loc[:, label_col] = 0

In [None]:
pg_img = _render_page(pg, dpi=50)  # small dpi for notebook display
pg_img

In [None]:
full_pg_img = _render_page(pg)  # full image for dataset

In [None]:
image_filename = "wisconsin_electric.png"

In [None]:
full_pg_img.save(image_filename)

In [None]:
id_to_label_small

In [None]:
txt.iloc[70:90]

In [None]:
# Hand labels for this page: label id (see id_to_label_small) -> row labels in
# `txt` of the tokens that carry it. Tokens not listed keep label 0.
label_to_indices = {
    1: [62, 68, 75],
    2: [63, 64, 69, 70, 71, 76],
    3: [66, 73, 77],
    5: [67, 74, 78]
}

In [None]:
# Write the hand labels into the label column, then make sure it is integer typed.
for label, indices in label_to_indices.items():
    txt.loc[indices, label_col] = label
txt[label_col] = txt[label_col].astype(int)

In [None]:
# Box color for each label id when overlaying bounding boxes (same keys as
# id_to_label_small).
id_to_color = {
    0: (128, 128, 128),
    1: (0, 128, 0),
    2: (0, 255, 0),
    3: (255, 0, 255),
    4: (128, 0, 128),
    5: (128, 128, 0),
    6: (255, 0, 0)
}

In [None]:
# Overlay every token's bounding box on the rendered page, color coded by label.
coord_cols = ['top_left_x_pdf', 'top_left_y_pdf', 'bottom_right_x_pdf', 'bottom_right_y_pdf']
dpi = 70
# Convert to OpenCV's BGR layout first: display_img_array treats color arrays
# as BGR, so plotting the raw RGB array from PIL would swap the page's red and
# blue channels. The displayed box colors are unaffected by this change.
display_img = pil_to_cv2(_render_page(pg, dpi=dpi))
for tag in id_to_label_small:
    subset = txt[txt[label_col] == tag]
    bboxes = pdf_coords_to_pixel_coords(subset[coord_cols].values, dpi=dpi)
    display_img = overlay_bboxes(display_img, bboxes, color=id_to_color[tag])
display_img_array(display_img, figsize=(10, 10))

In [None]:
# Tag every token of the first doc with document id 0 and start the combined
# per-token table with it.
txt.loc[:, "id"] = 0
output_df = pd.DataFrame()
output_df = pd.concat([output_df, txt])

In [None]:
output_df

# Put demo doc into JSON format for Label Studio

In [None]:
pil_to_cv2(full_pg_img).shape

In [None]:
txt.iloc[62]

In [None]:
# Factors that turn PDF coordinates into percentages (0-100) of the page width
# and height, the units used in the annotation JSON below.
# The page metadata columns are float32; cast to built-in float so products with
# these factors stay plain Python floats, which `json.dump` can serialize
# (NumPy float32 scalars are not JSON serializable).
x_norm = 100 / float(pg_meta.width_pdf_coord.iloc[0])
y_norm = 100 / float(pg_meta.height_pdf_coord.iloc[0])
x_norm, y_norm

In [None]:
# Label Studio task for the demo page: one rectangle region and its
# transcription. Both result entries share the id "bb1".
image_width, image_height = full_pg_img.size  # pixel size of the rendered page
# Hard-coded box in PDF coordinates, scaled to percentages of the page size.
# NOTE(review): presumably the token displayed by `txt.iloc[62]` — confirm.
# float() keeps the values JSON serializable even if x_norm/y_norm are NumPy scalars.
bbox_value = {
    "x": float(31.34646 * x_norm),
    "y": float(349.576477 * y_norm),
    "width": float(20.000002 * x_norm),
    "height": float(13.73999 * y_norm),
    "rotation": 0,
}
annotation_json = {
    "data": {
        # same bucket path as before, built from the saved image's filename
        "ocr": f"gs://labeled-ex21-filings/{image_filename}"
    },
    "annotations": [],
    "predictions": [
        {
            "model_version": "v1.0",
            "result": [
                {
                    "original_width": image_width,
                    "original_height": image_height,
                    "image_rotation": 0,
                    "value": dict(bbox_value),
                    "id": "bb1",
                    "from_name": "bbox",
                    "to_name": "image",
                    "type": "rectangle",
                    "origin": "manual",
                },
                {
                    "original_width": image_width,
                    "original_height": image_height,
                    "image_rotation": 0,
                    "value": {**bbox_value, "text": ["ATC"]},
                    "id": "bb1",
                    "from_name": "transcription",
                    "to_name": "image",
                    "type": "textarea",
                    "origin": "manual",
                },
            ],
        }
    ],
}

In [None]:
# Save the annotation next to the notebook as wisconsin_electric.json.
with open('wisconsin_electric.json', 'w') as fp:
    json.dump(annotation_json, fp)

# Label Another

In [None]:
# Second document: the first of the randomly sampled PDFs written as test_<i>.pdf.
src_path = Path("./test_0.pdf")
assert src_path.exists()

In [None]:
# from file: open the PDF by path and display `is_pdf` as a sanity check
doc = fitz.Document(str(src_path))
doc.is_pdf

In [None]:
# from bytes: alternative to opening by path; reads the file into memory and
# opens it as a stream. This rebinds `doc`.
_bytes = src_path.read_bytes()
from io import BytesIO
doc = fitz.open(stream=BytesIO(_bytes), filetype="pdf")
doc.is_pdf

In [None]:
# Again only the first page is used; this rebinds `extracted`.
pg2 = doc[0]
extracted = extract_pdf_data_from_page(pg2)
extracted.keys()

In [None]:
# Word table of the second doc. NOTE: `img_info` and `pg_meta` are rebound here,
# so from this point on they describe the second doc, not the first.
txt2 = extracted['pdf_text']
img_info = extracted['image']
pg_meta = extracted['page']
txt2.shape, img_info.shape, pg_meta.shape

In [None]:
# Start with every token labeled 0 ("O"); entity labels are filled in below.
txt2.loc[:, label_col] = 0

In [None]:
pg_img2 = _render_page(pg2, dpi=50)  # small dpi for notebook display
pg_img2

In [None]:
full_pg_img2 = _render_page(pg2)  # full image for dataset

In [None]:
id_to_label_small

In [None]:
txt2

In [None]:
# Hand labels for the second doc: label id (see id_to_label_small) -> row labels
# in `txt2`. This rebinds the mapping used for the first doc.
label_to_indices = {
    1: [12, 17, 23, 29, 33],
    2: [13, 14, 18, 19, 20, 24, 25, 26, 30, 31, 34],
    3: [15, 22, 28, 32, 35],
    4: [16],
    5: [21, 27]
}

In [None]:
# Write the hand labels into the label column, then make sure it is integer typed.
for label, indices in label_to_indices.items():
    txt2.loc[indices, label_col] = label
txt2[label_col] = txt2[label_col].astype(int)

In [None]:
# Overlay every token's bounding box on the rendered page, color coded by label.
coord_cols = ['top_left_x_pdf', 'top_left_y_pdf', 'bottom_right_x_pdf', 'bottom_right_y_pdf']
dpi = 70
# Convert to OpenCV's BGR layout first: display_img_array treats color arrays
# as BGR, so plotting the raw RGB array from PIL would swap the page's red and
# blue channels. The displayed box colors are unaffected by this change.
display_img = pil_to_cv2(_render_page(pg2, dpi=dpi))
for tag in id_to_label_small:
    subset = txt2[txt2[label_col] == tag]
    bboxes = pdf_coords_to_pixel_coords(subset[coord_cols].values, dpi=dpi)
    display_img = overlay_bboxes(display_img, bboxes, color=id_to_color[tag])
display_img_array(display_img, figsize=(10, 10))

In [None]:
# Tag the second doc's tokens with document id 1 and append them to the
# combined per-token table.
txt2.loc[:, "id"] = 1
output_df = pd.concat([output_df, txt2])

# Combine into one dataframe

In [None]:
# for now make image a filename
# NOTE(review): this rebinds `output_df`, discarding the per-token table built
# above; the new frame has no "id", "text" or "ner_tag" columns.
output_df = pd.DataFrame(columns=["tokens", "bboxes", "ner_tags", "image"])

In [None]:
doc_id = 0

In [None]:
txt.loc[:, "id"] = doc_id

In [None]:
# One row per document id: the list of its tokens and the list of its label ids.
output_df["tokens"] = txt.groupby("id")["text"].apply(list)
output_df["ner_tags"] = txt.groupby("id")["ner_tag"].apply(list)

In [None]:
output_df.loc[[0]]

In [None]:
# Collect, per document id, an (n_tokens, 4) array of bounding-box coordinates.
bbox_cols = ["top_left_x_pdf", "top_left_y_pdf", "bottom_right_x_pdf", "bottom_right_y_pdf"]
matrices = {}
for _, group in txt.groupby('id'):
    bbox = group[bbox_cols].values
    matrix = np.reshape(bbox, (len(group), len(bbox_cols)))
    # key by the document id, read from the group's first row
    matrices[group.iloc[0]['id']] = matrix

In [None]:
# this isn't so nice, what does the model want as input?
# Attempts to store each document's whole coordinate array in a single cell.
for i, arr in matrices.items():
    output_df.loc[i, "bboxes"] = arr

In [None]:
# format bboxes to have x1, y1, x2, y2
# Each token's box becomes one "x1, y1, x2, y2" string in a new `bboxes` column.
txt["bboxes"] = txt[["top_left_x_pdf", "top_left_y_pdf", "bottom_right_x_pdf", "bottom_right_y_pdf"]].astype(str).agg(', '.join, axis=1)

In [None]:
# Replace the bboxes column with, per document id, the list of box strings.
output_df["bboxes"] = txt.groupby("id")["bboxes"].apply(list)

In [None]:
output_df

# Try vendoring DocAI
https://medium.com/@matt.noe/tutorial-how-to-train-layoutlm-on-a-custom-dataset-with-hugging-face-cda58c96571c

In [None]:
bbox_cols = ["top_left_x_pdf", "top_left_y_pdf", "bottom_right_x_pdf", "bottom_right_y_pdf"]

In [None]:
# convert dataframe/dictionary into NER format
# document_annotation_to_ner https://github.com/butlerlabs/docai/blob/main/docai/annotations/ner_utils.py
# complete dataset is a list of dicts, with one dict for each doc
#
# Built from the per-token tables of the labeled docs rather than from
# `output_df`: the "Combine into one dataframe" section rebinds `output_df` to a
# one-row-per-doc frame without "id", "text" or "ner_tag" columns, so reading it
# here fails when the notebook is run top to bottom. The position in this list
# is the document id (txt carries id 0, txt2 carries id 1).
# NOTE(review): the boxes are still raw PDF coordinates at this point.
labeled_docs = [(txt, full_pg_img), (txt2, full_pg_img2)]
n_docs = len(labeled_docs)
images = [doc_image for _, doc_image in labeled_docs]
ner_annotations = []
for i, (doc_tokens, doc_image) in enumerate(labeled_docs):
    annotation = {
        "id": i,
        "tokens": doc_tokens["text"].tolist(),
        # store label names; they are mapped back to ids when encoding
        "ner_tags": [id_to_label_small[n] for n in doc_tokens["ner_tag"]],
        "bboxes": doc_tokens[bbox_cols].values.tolist(),
        "image": doc_image,
    }
    ner_annotations.append(annotation)

In [None]:
# normalize NER notation for LayoutLM
# https://github.com/butlerlabs/docai/blob/main/docai/annotations/layoutlm_utils.py
def normalize_bounding_box(bbx):
    """Scale every coordinate of a box by 1000 and truncate it to an int.

    Returns a new list; the input sequence is left untouched.
    """
    return [int(coordinate * 1000) for coordinate in bbx]

def normalize_ner_annotation_for_layoutlm(annotation):
    """
    Return a copy of one annotation dict whose bounding boxes are scaled by 1000
    to match the bounding box format LayoutLM expects.

    Only the "id", "tokens", "bboxes", "ner_tags" and "image" keys are carried
    over; all but "bboxes" are passed through unchanged.
    """
    scaled_boxes = [normalize_bounding_box(box) for box in annotation["bboxes"]]
    normalized = {}
    for key in ("id", "tokens", "bboxes", "ner_tags", "image"):
        # Normalize NER bounding boxes by 1000 as LayoutLM expects
        normalized[key] = scaled_boxes if key == "bboxes" else annotation[key]
    return normalized

In [None]:
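# Not sure yet whether we need to normalize. Hugging Face expects bboxes between
# 0 and 1000, and ours appear to already be in that range, but they are raw PDF
# coordinates: LayoutLM expects each box scaled to 0-1000 relative to the page
# width and height (TODO confirm and rescale). normalize_bounding_box multiplies
# by 1000, so it only fits coordinates expressed as fractions of the page size.
# norm_ner_annotations = [normalize_ner_annotation_for_layoutlm(a) for a in ner_annotations]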

In [None]:
# One dataset row per document (a dict of id, tokens, ner_tags, bboxes, image).
dataset = Dataset.from_list(ner_annotations)

In [None]:
dataset

# Fine-tune LayoutLM model

In [None]:
label_list = list(id_to_label_small.values())
label_list

In [None]:
column_names = dataset.column_names

In [None]:
id_to_label_small

In [None]:
# Two-way mapping between label ids (positions in label_list) and label names.
id2label = dict(enumerate(label_list))
label2id = {name: idx for idx, name in id2label.items()}

In [None]:
class_label = ClassLabel(names=label_list)

In [None]:
# update this split size when there are more than 2 docs
# With two docs, 0.5 puts one in "train" and one in "test". No seed is passed,
# so which doc lands where can differ between runs.
dataset = dataset.train_test_split(test_size=0.5)

In [None]:
# processor = LayoutLMv3Processor.from_pretrained("microsoft/layoutlmv3-base-uncased")
# we'll use the Auto API here - it will load LayoutLMv3Processor behind the scenes,
# based on the checkpoint we provide from the hub
# apply_ocr=False: words and boxes are passed in explicitly by prepare_dataset
# (extracted from the PDF) instead of coming from the processor's own OCR.
processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

def convert_ner_tags_to_id(ner_tags):
    """Translate a sequence of label names into their integer ids via ``label2id``.

    Raises ``KeyError`` for a name that is not in ``label2id``.
    """
    return [label2id[label_name] for label_name in ner_tags]

# This function is used to put the Dataset in its final format for training LayoutLM
def prepare_dataset(annotations):
    """Encode a batch of annotations with the module-level ``processor``.

    Args:
        annotations: batch with 'image', 'tokens', 'bboxes' and 'ner_tags'
            columns, where 'ner_tags' holds one list of label names per doc.

    Returns:
        The processor's encoding, truncated and padded to max length.
    """
    page_images = annotations['image']
    token_lists = annotations['tokens']
    box_lists = annotations['bboxes']
    # the processor wants numeric label ids, not label names
    label_id_lists = [
        convert_ner_tags_to_id(doc_tags) for doc_tags in annotations['ner_tags']
    ]

    return processor(
        page_images,
        token_lists,
        boxes=box_lists,
        word_labels=label_id_lists,
        truncation=True,
        padding="max_length",
    )

In [None]:
# Define features for use training the model
# Explicit schema of the encoded examples returned by prepare_dataset.
# NOTE(review): 512 presumably matches the padded sequence length and 224 the
# processor's image size — confirm against the processor config.
features = Features({
    'pixel_values': Array3D(dtype="float32", shape=(3, 224, 224)),
    'input_ids': Sequence(feature=Value(dtype='int64')),
    'attention_mask': Sequence(Value(dtype='int64')),
    'bbox': Array2D(dtype="int64", shape=(512, 4)),
    'labels': Sequence(feature=Value(dtype='int64')),
})

# Prepare our train & eval dataset


def _encode_split(split_name):
    """Encode one split of `dataset`, replacing the raw columns with model inputs."""
    return dataset[split_name].map(
        prepare_dataset,
        batched=True,
        remove_columns=column_names,
        features=features,
    )


train_dataset = _encode_split("train")
eval_dataset = _encode_split("test")

In [None]:
example = train_dataset[0]
processor.tokenizer.decode(example["input_ids"])

In [None]:
train_dataset.set_format("torch")

In [None]:
example = train_dataset[0]
for k,v in example.items():
    print(k,v.shape)

In [None]:
processor.tokenizer.decode(eval_dataset[0]["input_ids"])

In [None]:
# seqeval metric, used by compute_metrics below.
# NOTE(review): `datasets.load_metric` is deprecated in favor of the separate
# `evaluate` package (`evaluate.load("seqeval")`) — confirm the installed
# `datasets` version still provides it.
metric = load_metric("seqeval")

In [None]:
return_entity_level_metrics = False

def compute_metrics(p):
    """Score predictions with the module-level seqeval ``metric``.

    Args:
        p: a (predictions, labels) pair; the class is taken as the argmax of
            predictions over axis 2, and positions labeled -100 are skipped.

    Returns:
        Overall precision / recall / f1 / accuracy, or, when
        ``return_entity_level_metrics`` is set, every result with nested dicts
        flattened to ``"<key>_<name>"`` entries.
    """
    scores, labels = p
    predicted_ids = np.argmax(scores, axis=2)

    true_predictions = []
    true_labels = []
    for predicted_row, label_row in zip(predicted_ids, labels):
        # Remove ignored index (special tokens)
        kept_pairs = [
            (predicted_id, label_id)
            for predicted_id, label_id in zip(predicted_row, label_row)
            if label_id != -100
        ]
        true_predictions.append(
            [label_list[predicted_id] for predicted_id, _ in kept_pairs]
        )
        true_labels.append([label_list[label_id] for _, label_id in kept_pairs])

    results = metric.compute(predictions=true_predictions, references=true_labels)
    if not return_entity_level_metrics:
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

    # Unpack nested dictionaries
    final_results = {}
    for key, value in results.items():
        if isinstance(value, dict):
            for name, score in value.items():
                final_results[f"{key}_{name}"] = score
        else:
            final_results[key] = value
    return final_results

In [None]:
from transformers import LayoutLMv3ForTokenClassification

# Load the pretrained LayoutLMv3 base checkpoint for token classification,
# passing our label mappings.
model = LayoutLMv3ForTokenClassification.from_pretrained("microsoft/layoutlmv3-base",
                                                         id2label=id2label,
                                                         label2id=label2id)

In [None]:
from transformers import TrainingArguments, Trainer

# Train for 1000 steps, evaluating every 100; at the end reload the checkpoint
# with the best "f1" (one of the keys returned by compute_metrics).
training_args = TrainingArguments(output_dir="test",
                                  max_steps=1000,
                                  per_device_train_batch_size=2,
                                  per_device_eval_batch_size=2,
                                  learning_rate=1e-5,
                                  evaluation_strategy="steps",
                                  eval_steps=100,
                                  load_best_model_at_end=True,
                                  metric_for_best_model="f1")

In [None]:
from transformers.data.data_collator import default_data_collator

# Initialize our Trainer
# The processor is passed in the `tokenizer` slot; examples were already padded
# to max length in prepare_dataset, so the default collator is used.
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    tokenizer=processor,
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
)

In [None]:
trainer.train()