# Train and fine-tune an OCR model

This notebook follows Niels Rogge's TrOCR fine-tuning tutorial: https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/TrOCR/Fine_tune_TrOCR_on_IAM_Handwriting_Database_using_Seq2SeqTrainer.ipynb

Reference for building the image dataset (an `images/` folder plus a `metadata.csv`): https://huggingface.co/docs/datasets/en/image_dataset

In [1]:
from datasets import Dataset
import pandas as pd

# metadata.csv holds one row per image: its "file_name" (relative to images/)
# and the ground-truth "text". Wrap the table as a Hugging Face Dataset.
metadata_csv = "images/metadata.csv"
images_df = pd.read_csv(metadata_csv)
images_ds = Dataset.from_pandas(df=images_df)

In [2]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

model_name = "microsoft/trocr-small-printed"
# Load the processor (image preprocessing + tokenizer) and the encoder-decoder
# model from the same checkpoint so their vocabularies agree.
processor, model = (
    TrOCRProcessor.from_pretrained(model_name),
    VisionEncoderDecoderModel.from_pretrained(model_name),
)

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-small-printed and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [3]:
from PIL import Image
from pathlib import Path
import torch

def create_image_and_process_text(item, image_dir="images", max_length=128):
    """Turn one metadata row into TrOCR model inputs.

    Args:
        item: A dataset row with a ``"file_name"`` (path relative to
            ``image_dir``) and the ground-truth ``"text"`` for that image.
        image_dir: Directory that contains the image files.
        max_length: Number of label token ids every example is padded *and*
            truncated to.

    Returns:
        A dict with ``"pixel_values"`` (the processed image tensor with the
        batch axis removed) and ``"labels"`` (token ids of ``text``, with
        padding positions replaced by -100, the conventional ignore index for
        the loss).
    """
    file_path = Path(image_dir) / item["file_name"]

    # Image.open is lazy and keeps the file handle open; the context manager
    # closes it as soon as the RGB copy has been made.
    with Image.open(file_path) as img:
        image = img.convert("RGB")

    # return_tensors="pt" adds a leading batch axis of size 1; drop only that axis.
    pixel_values = processor(image, return_tensors="pt").pixel_values.squeeze(0)

    # truncation=True guarantees every label row has exactly max_length ids.
    # Without it, a long transcription produces a longer row, and
    # default_data_collator can no longer stack the batch.
    token_ids = processor.tokenizer(
        item["text"],
        padding="max_length",
        max_length=max_length,
        truncation=True,
    ).input_ids
    pad_id = processor.tokenizer.pad_token_id
    labels = torch.tensor([tok if tok != pad_id else -100 for tok in token_ids])

    return {"pixel_values": pixel_values, "labels": labels}

In [4]:
# Load each image and tokenize its transcription, one row at a time; the
# returned "pixel_values" and "labels" become new dataset columns.
images_ds = images_ds.map(
    function=create_image_and_process_text,
    batched=False,
)

Map:   0%|          | 0/8 [00:00<?, ? examples/s]

In [5]:
from operator import itemgetter

# Hold out 25% of the examples for evaluation (the datasets default, made
# explicit). A fixed seed makes the shuffle reproducible, so CER numbers from
# different runs are measured on the same evaluation images.
train_test_ds = images_ds.train_test_split(test_size=0.25, seed=42)
train_dataset, eval_dataset = itemgetter("train", "test")(train_test_ds)
train_dataset, eval_dataset

(Dataset({
     features: ['file_name', 'text', 'pixel_values', 'labels'],
     num_rows: 6
 }),
 Dataset({
     features: ['file_name', 'text', 'pixel_values', 'labels'],
     num_rows: 2
 }))

In [6]:
# Special tokens used when building decoder_input_ids from the labels during
# training; the model reads these from model.config.
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id
model.config.eos_token_id = processor.tokenizer.sep_token_id
# make sure vocab size is set correctly
model.config.vocab_size = model.config.decoder.vocab_size

# Decoding / beam-search parameters. generate() (and therefore Seq2SeqTrainer
# with predict_with_generate=True) reads them from generation_config; setting
# them on model.config is the deprecated path.
gen_config = model.generation_config
gen_config.decoder_start_token_id = model.config.decoder_start_token_id
gen_config.pad_token_id = model.config.pad_token_id
gen_config.eos_token_id = model.config.eos_token_id
# Match the label length used when tokenizing (max_length=128). A smaller cap
# cuts off long transcriptions and inflates the character error rate.
gen_config.max_length = 128
gen_config.early_stopping = True
gen_config.no_repeat_ngram_size = 3
gen_config.length_penalty = 2.0
gen_config.num_beams = 4

In [19]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="./",
    num_train_epochs=10,
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    # Evaluate once per epoch, decoding with generate() so compute_metrics
    # receives generated token ids. eval_steps is deliberately not set: it only
    # applies to the "steps" evaluation strategy.
    evaluation_strategy="epoch",
    predict_with_generate=True,
    logging_steps=2,
    # NOTE(review): with the current 6-example train split at batch size 1 for
    # 10 epochs (~60 steps), this threshold is never reached, so no
    # intermediate checkpoint is written.
    save_steps=1000,
    # fp16=True,  # uncomment to train in mixed precision on a GPU
)

In [20]:
from datasets import load_metric

# The "cer" metric is loaded from a script. Opt in explicitly, as the datasets
# warning for this call asks (it says the argument will become mandatory).
# NOTE(review): datasets.load_metric is deprecated; evaluate.load("cer") from
# the `evaluate` library is its replacement if that package is available.
cer_metric = load_metric("cer", trust_remote_code=True)

You can avoid this message in future by passing the argument `trust_remote_code=True`.
Passing `trust_remote_code=True` will be mandatory to load this metric from the next major release of `datasets`.


In [21]:
def compute_metrics(pred):
    """Compute the character error rate (CER) of generated transcriptions.

    Called by the trainer at evaluation time with an object that exposes
    ``predictions`` (generated token ids) and ``label_ids`` (reference token
    ids, where ignored positions hold -100).

    Returns:
        ``{"character_error_rate": cer}``
    """
    pad_id = processor.tokenizer.pad_token_id

    # -100 is not a real token id and cannot be decoded. Swap it for the pad id
    # (which skip_special_tokens drops). Work on copies so the arrays owned by
    # the trainer are not mutated. Predictions can also carry -100 if the
    # evaluation loop pads batches of different lengths.
    labels_ids = pred.label_ids.copy()
    labels_ids[labels_ids == -100] = pad_id
    pred_ids = pred.predictions.copy()
    pred_ids[pred_ids == -100] = pad_id

    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    # Debug output: decoded predictions followed by their references.
    print(pred_str, label_str)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return {"character_error_rate": cer}

In [22]:
from transformers import default_data_collator

# Instantiate the trainer. Every example already has fixed-size pixel_values
# and labels padded to max_length, so default_data_collator (no dynamic
# padding) is sufficient.
trainer = Seq2SeqTrainer(
    model=model,
    # image_processor is the current name of the deprecated feature_extractor alias.
    tokenizer=processor.image_processor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)


In [None]:
# Run fine-tuning. Bind the value returned by train() rather than hiding its
# notebook display with a trailing semicolon.
train_result = trainer.train()

Epoch,Training Loss,Validation Loss,Character Error Rate
1,8.6693,8.483727,0.83685
2,7.8811,8.267612,0.818565
3,7.8995,8.262975,0.821378
4,7.3805,8.245132,0.734177
5,7.2591,8.258827,0.857947
6,7.2209,8.238214,0.842475


['portion invention invention amount amount amount method amount amount role amount amount', 'portion invention amount amount amount method amount amount'] ['combination oven stock elevator law variety revolution orange inside insect signature championship sense memory promotion class magazine meat importance practice elevator aunt university equipment issue confusion distribution speech recording medium comparison tradition movie volume home case addition table boss college top guide while night storage', 'language metal sympathy boss map studio assumption cash bath chocolate work city philosophy percentage activity community attitude celebration university machine society sun customer video interest agency egg salad surgery boss director instance contact football criticism course hotel bottom boat construction hair amount period story condition engine library']
['information vehicle vehicle vehicle hair hair hair vehicle hair vehicle vehicle school hair hair school hair vehicle schoo