# Train and fine-tune an OCR model

This notebook follows this guide: https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/TrOCR/Fine_tune_TrOCR_on_IAM_Handwriting_Database_using_Seq2SeqTrainer.ipynb

How to build my image dataset: https://huggingface.co/docs/datasets/en/image_dataset

In [None]:
from datasets import Dataset
import pandas as pd

# Table describing the training images; later cells read its "file_name"
# and "text" columns.
images_df = pd.read_csv(filepath_or_buffer="images/metadata.csv")

# Same rows wrapped as a Hugging Face Dataset for the preprocessing step.
images_ds = Dataset.from_pandas(df=images_df)

# Last expression of the cell: display the table.
images_df

In [None]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

# Pretrained checkpoint that is fine-tuned in this notebook.
model_name = "microsoft/trocr-base-printed"

# The processor and the model are loaded from the same checkpoint name.
processor = TrOCRProcessor.from_pretrained(
    pretrained_model_name_or_path=model_name,
)
model = VisionEncoderDecoderModel.from_pretrained(
    pretrained_model_name_or_path=model_name,
)

In [None]:
from PIL import Image
from pathlib import Path
import torch

def create_image_and_process_text(item: dict) -> dict:
    """Turn one metadata row into the model inputs used for training.

    Args:
        item: Row with a ``"file_name"`` entry (resolved relative to the
            ``images`` directory) and a ``"text"`` entry (the transcription
            of that image).

    Returns:
        A dict with two entries: ``"pixel_values"``, the processor output
        for the RGB image with its leading batch axis removed, and
        ``"labels"``, the token ids of the text as a ``torch`` tensor.
    """
    file_name = item["file_name"]
    text = item["text"]

    file_path = Path("images") / file_name

    # Release the file handle deterministically once the RGB copy exists,
    # instead of relying on the image object being garbage-collected.
    with Image.open(file_path) as raw_image:
        image = raw_image.convert("RGB")

    pixel_values = processor(image, return_tensors="pt").pixel_values
    # Drop only the leading (batch) axis; a bare squeeze() would also
    # collapse any other axis that happens to have size 1.
    pixel_values = pixel_values.squeeze(0)

    # NOTE(review): no truncation is requested, so texts longer than 16
    # tokens presumably produce label sequences longer than 16 - confirm
    # this is intended.
    # NOTE(review): padded positions keep the tokenizer's pad id. If they are
    # ever masked to -100 for the loss, any code that decodes the labels must
    # map -100 back to the pad id first.
    labels = processor.tokenizer(text, padding="max_length", max_length=16).input_ids
    labels = torch.tensor(labels)

    return {"pixel_values": pixel_values, "labels": labels}

In [None]:
# Preprocess every row and drop the raw columns, so the resulting dataset
# holds only what the mapping function returns.
inout_images_ds = images_ds.map(
    function=create_image_and_process_text,
    remove_columns=["file_name", "text"],
)

In [None]:
from operator import itemgetter

# Split with the library defaults; no size or seed is passed here.
train_test_ds = inout_images_ds.train_test_split()

train_dataset = train_test_ds["train"]
eval_dataset = train_test_ds["test"]

# Last expression of the cell: display both splits.
train_dataset, eval_dataset

In [None]:
# Special-token ids on the model config, copied from the processor's tokenizer.
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.sep_token_id
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id

# Mirror the decoder's vocabulary size on the top-level config.
model.config.vocab_size = model.config.decoder.vocab_size

# Generation settings, currently disabled:
# model.config.max_length = 64
# model.config.early_stopping = True
# model.config.no_repeat_ngram_size = 3
# model.config.length_penalty = 2.0
# model.config.num_beams = 4

In [None]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    # Where the trainer writes its output: the current directory.
    output_dir="./",
    # Schedule: ten epochs, one example per step, evaluation after each epoch.
    num_train_epochs=10,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    evaluation_strategy="epoch",
    # Evaluate on generated sequences rather than raw logits.
    predict_with_generate=True,
    # No external experiment tracker.
    report_to="none",
    # Options currently disabled:
    # fp16=True,
    # logging_steps=2,
    # save_steps=1000,
    # eval_steps=200,
)

In [None]:
from evaluate import load

# Character error rate metric, looked up by name via `evaluate.load`.
cer_metric = load(path="cer")

In [None]:
def compute_metrics(pred):
    """Compute the character error rate for one evaluation pass.

    Args:
        pred: Evaluation output exposing ``label_ids`` and ``predictions``.
            Both are assumed to be integer arrays of token ids (the case
            when the trainer evaluates with generation) - TODO confirm if
            the trainer configuration changes.

    Returns:
        A dict with a single ``"character_error_rate"`` entry holding the
        value returned by the CER metric.
    """
    import numpy as np  # local import: numpy is not imported at notebook level

    pad_token_id = processor.tokenizer.pad_token_id

    # -100 is the padding/ignore sentinel that can appear in these arrays
    # (e.g. when sequences of different lengths are concatenated across
    # batches, or when labels are masked for the loss). It is not a real
    # token id, so decoding it fails or yields garbage; map it back to the
    # pad id, which skip_special_tokens then removes.
    labels_ids = np.where(pred.label_ids != -100, pred.label_ids, pad_token_id)
    pred_ids = np.where(pred.predictions != -100, pred.predictions, pad_token_id)

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    # Show decoded predictions next to their references on every evaluation.
    print(pred_str, label_str)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"character_error_rate": cer}

In [None]:
from transformers import default_data_collator

# Assemble the trainer from the objects built in the earlier cells.
trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    # Data
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
    # Evaluation
    compute_metrics=compute_metrics,
    # The processor's image feature extractor is passed in the `tokenizer` slot.
    tokenizer=processor.feature_extractor,
)

In [None]:
# Start fine-tuning; as the last expression of the cell, the return value of
# train() is displayed by the notebook.
trainer.train()