# LayoutLMv3 Training for SER

In this notebook, we train LayoutLMv3 on a custom SER (semantic entity recognition) dataset: specifically, the dataset you labeled with Label Studio.

## Dataset Preparation

First, prepare the dataset: copy the annotation results exported from Label Studio, as shown below.

In [None]:
# Copy the labeled dataset (annotation export + image folders) from Google
# Drive into the working directory. Later cells read it from
# /content/handwriting, so this assumes the working directory is /content
# (the Colab default) and that Drive is already mounted at /content/drive.
!cp -r "/content/drive/MyDrive/Public/Dibimbing/25 - OCR/Assignment/handwriting" "./handwriting"

In [None]:
# Install the training stack; seqeval backs the "seqeval" metric used to
# score predictions in the Training section.
# NOTE(review): versions are unpinned, but the Training section relies on
# `datasets.load_metric`, which was removed in datasets 3.0 — pin a
# compatible `datasets` version if that cell fails to import.
!pip install accelerate datasets transformers seqeval

In [None]:
import json
from pathlib import Path
from typing import List

from torch.utils.data import Dataset
from transformers import AutoProcessor
from PIL import Image

# Label set for the SER task. A tag's position in LABEL_LIST is the class id
# the model predicts, so both lookup tables are derived from that one list.
# NOTE(review): "NUM" carries no B-/I- prefix, unlike the other entity tags;
# seqeval (used for the metrics) expects prefixed tags — confirm how it
# scores NUM.
LABEL_LIST = ["O", "NUM", "B-QUESTION", "I-QUESTION", "B-ANSWER", "I-ANSWER"]
LBL2ID = {label: idx for idx, label in enumerate(LABEL_LIST)}
ID2LBL = dict(enumerate(LABEL_LIST))

In [None]:
def get_image_name(ls_image_path: Path) -> str:
    """
    Recover the original image file name from a Label Studio upload path.

    Label Studio renames uploads to '{random_id}-{original_image_name}'.
    Everything after the first '-' of the file name is returned, which is
    the name of the image we actually have on disk. A file name containing
    no '-' is returned unchanged.
    """
    # maxsplit=1 keeps any further dashes that belong to the original name.
    return ls_image_path.name.split("-", 1)[-1]

def load_annotation_json(annotation, train_dir, val_dir):
    """
    Load a Label Studio JSON export and split its tasks into train / val.

    Each task's "ocr" entry is rewritten in place to the original image file
    name (via get_image_name). A task goes to a split when a file with that
    name exists directly inside ``train_dir`` / ``val_dir`` (non-recursive).
    Tasks whose image is in neither directory are dropped, and a task
    appears in both lists if the same file name exists in both directories.

    Args:
        annotation: Path to the exported JSON file (a list of task dicts).
        train_dir: Directory (Path or str) holding the training images.
        val_dir: Directory (Path or str) holding the validation images.

    Returns:
        ``(train_anno, val_anno)``: lists of task dicts in export order.
    """
    # Explicit UTF-8 so transcriptions decode the same regardless of locale.
    with open(annotation, "r", encoding="utf-8") as f:
        data_raw = json.load(f)
    for d in data_raw:
        d["ocr"] = get_image_name(Path(d["ocr"]))
    # Sets make each membership check O(1) instead of a list scan per task.
    train_images = {p.name for p in Path(train_dir).glob("*")}
    val_images = {p.name for p in Path(val_dir).glob("*")}
    train_anno = [ann for ann in data_raw if ann["ocr"] in train_images]
    val_anno = [ann for ann in data_raw if ann["ocr"] in val_images]
    return train_anno, val_anno

def xywh2xyxy(xywh: List[float], img_width: int, img_height: int) -> List[int]:
    """
    Change bounding box format xywh normalized 0-100 to
    xyxy normalized 0-1000.

    Label Studio gives ``[x, y, width, height]`` as percentages of the image
    size. LayoutLMv3 expects ``[x0, y0, x1, y1]`` on a 0-1000 scale, so each
    value is scaled by 10. Because the input is already relative,
    ``img_width`` / ``img_height`` are not needed; they are accepted only so
    existing call sites keep working.

    Coordinates are truncated to int and clamped to [0, 1000]. A region
    drawn slightly past the image edge therefore cannot yield a position
    outside the range the model's layout embeddings accept.
    """
    x, y, w, h = (v * 10 for v in xywh)
    corners = (x, y, x + w, y + h)
    return [min(max(int(c), 0), 1000) for c in corners]

def extract_box(ls_box):
    """
    Convert one Label Studio region's geometry into an xyxy box.

    ``ls_box`` must provide "x", "y", "width", "height", "original_width"
    and "original_height". The conversion itself is done by xywh2xyxy.
    NOTE(review): only these axis-aligned fields are read; if regions can be
    rotated in the export, the rotation is ignored here — confirm.
    """
    coords = [ls_box[key] for key in ("x", "y", "width", "height")]
    width, height = ls_box["original_width"], ls_box["original_height"]
    return xywh2xyxy(coords, width, height)

def _as_list(value):
    """Wrap a single exported value in a list; lists pass through unchanged."""
    return value if isinstance(value, list) else [value]


def extract_annotation_data(annotation, images_dir):
    """
    Convert one Label Studio task into the inputs the LayoutLMv3 processor needs.

    Args:
        annotation: One task from the export. "ocr" must already hold the
            image file name, "label" the labeled regions and "transcription"
            one text per region, in the same order.
        images_dir: Directory (Path) containing the task's image.

    Returns:
        dict with "img_path", "boxes" (xyxy on a 0-1000 scale), "words" and
        "word_labels" (ids from LBL2ID), aligned index by index.

    Raises:
        ValueError: if the number of transcriptions differs from the number
            of labeled regions, which would otherwise misalign words, boxes
            and labels.
    """
    # Tolerate a lone value in place of a one-element list.
    regions = _as_list(annotation["label"])
    words = _as_list(annotation["transcription"])
    if len(words) != len(regions):
        raise ValueError(
            f"{annotation['ocr']}: {len(words)} transcriptions but "
            f"{len(regions)} labeled regions"
        )
    return {
        "img_path": images_dir / annotation["ocr"],
        "boxes": [extract_box(region) for region in regions],
        "words": words,
        # Only the first label of each region is used.
        "word_labels": [LBL2ID[region["labels"][0]] for region in regions],
    }

class CustomDataset(Dataset):
    """
    Map-style dataset of LayoutLMv3 processor encodings for annotated pages.

    Each annotation is a dict with "img_path", "words", "boxes" and
    "word_labels", as built by extract_annotation_data. Every annotation is
    encoded once in ``__init__`` and cached, so indexing just returns the
    precomputed encoding.
    """

    def __init__(self, annotations, processor):
        self.annotations = annotations
        self.processor = processor
        self.dataset = [self._create_feature(ann) for ann in self.annotations]

    def _create_feature(self, ann):
        """Encode one page; return a dict of tensors without the batch dim."""
        # Context manager so the file handle is released right after the
        # RGB copy is made.
        with Image.open(ann["img_path"]) as img:
            image = img.convert("RGB")
        encoding = self.processor(
            image,
            ann["words"],
            boxes=ann["boxes"],
            word_labels=ann["word_labels"],
            return_tensors="pt",
            truncation=True,
            padding="max_length",
        )
        # The processor returns a batch of one. Drop only that leading axis;
        # a bare squeeze() would also collapse any other size-1 dimension.
        return {key: value.squeeze(0) for key, value in encoding.items()}

    def __getitem__(self, idx):
        return self.dataset[idx]

    def __len__(self):
        return len(self.dataset)

In [None]:
# Locations of the Label Studio export and of the image folder for each split.
ann_file = Path("/content/handwriting/label-studio-anno.json")
train_dir = Path("/content/handwriting/training")
val_dir = Path("/content/handwriting/test")

# Split the export by image folder, then turn every task into processor inputs.
train_anns_ls, val_anns_ls = load_annotation_json(ann_file, train_dir, val_dir)
train_anns = [extract_annotation_data(task, train_dir) for task in train_anns_ls]
val_anns = [extract_annotation_data(task, val_dir) for task in val_anns_ls]

# Words and boxes come from the annotations, so the processor's built-in OCR
# is switched off.
processor = AutoProcessor.from_pretrained("microsoft/layoutlmv3-base", apply_ocr=False)

# Each dataset encodes all of its pages up front (see CustomDataset).
train_dataset = CustomDataset(train_anns, processor)
val_dataset = CustomDataset(val_anns, processor)
combined_dataset = CustomDataset(train_anns + val_anns, processor)

In [None]:
# Show which tensors the processor produced for one training sample.
train_dataset[0].keys()

In [None]:
# Sanity check: print every token of the first validation sample next to its
# label id (-100 marks positions ignored when scoring, e.g. special tokens).
# `token_id` avoids shadowing the builtin `id` for the rest of the session.
for token_id, label in zip(val_dataset[0]["input_ids"], val_dataset[0]["labels"]):
    print(processor.tokenizer.decode([token_id]), label.item())

## Training

In [None]:
# NOTE(review): `load_metric` was deprecated and then removed in `datasets`
# 3.0; with an unpinned `datasets` install this import may fail — pin
# `datasets<3` or switch to the `evaluate` library's `load("seqeval")`.
from datasets import load_metric

# seqeval scores the label-name sequences built in compute_metrics.
metric = load_metric("seqeval")

In [None]:
import numpy as np

# Set to True to also report seqeval's per-entity scores as flat keys.
return_entity_level_metrics = False

def compute_metrics(p):
    """
    Turn Trainer evaluation output into seqeval scores.

    ``p`` unpacks into ``(predictions, labels)``: per-token logits and the
    gold label ids. Logits are reduced with argmax over axis 2. Positions
    whose gold id is -100 (ignored index, e.g. special tokens) are dropped,
    and the remaining ids are mapped to tag names through LABEL_LIST.

    Returns the overall precision / recall / f1 / accuracy. When the
    module-level ``return_entity_level_metrics`` flag is true, it instead
    returns every value seqeval produced, with nested dicts flattened into
    ``"{key}_{name}"`` entries.
    """
    logits, label_ids = p
    pred_ids = np.argmax(logits, axis=2)

    true_predictions, true_labels = [], []
    for pred_row, label_row in zip(pred_ids, label_ids):
        kept = [(pred, gold) for pred, gold in zip(pred_row, label_row) if gold != -100]
        true_predictions.append([LABEL_LIST[pred] for pred, _ in kept])
        true_labels.append([LABEL_LIST[gold] for _, gold in kept])

    results = metric.compute(predictions=true_predictions, references=true_labels)

    if not return_entity_level_metrics:
        return {
            "precision": results["overall_precision"],
            "recall": results["overall_recall"],
            "f1": results["overall_f1"],
            "accuracy": results["overall_accuracy"],
        }

    flat = {}
    for key, value in results.items():
        if isinstance(value, dict):
            flat.update({f"{key}_{name}": score for name, score in value.items()})
        else:
            flat[key] = value
    return flat

In [None]:
from transformers import LayoutLMv3ForTokenClassification, TrainingArguments, Trainer
from transformers.data.data_collator import default_data_collator

# LayoutLMv3 with a token-classification head; the label maps size the
# classifier to the tag set and make predicted ids readable.
model = LayoutLMv3ForTokenClassification.from_pretrained(
    "microsoft/layoutlmv3-base",
    id2label=ID2LBL,
    label2id=LBL2ID,
)

training_args = TrainingArguments(
    # Checkpoints and logs go to ./test, relative to the working directory.
    output_dir="test",
    max_steps=100,
    per_device_train_batch_size=2,
    per_device_eval_batch_size=2,
    learning_rate=1e-5,
    # `eval_strategy` is the current name of the deprecated
    # `evaluation_strategy` keyword, which newer transformers releases reject.
    eval_strategy="steps",
    eval_steps=10,
    # compute_metrics returns an "f1" key (logged as "eval_f1").
    metric_for_best_model="f1",
)

trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=train_dataset,
    eval_dataset=val_dataset,
    # `processing_class` replaces the deprecated `tokenizer` argument.
    processing_class=processor,
    # Every feature is already padded to max_length, so the default
    # collator can stack them directly.
    data_collator=default_data_collator,
    compute_metrics=compute_metrics,
)

trainer.train()