# Train and fine-tune an OCR model

I will be using this guide here: https://colab.research.google.com/github/NielsRogge/Transformers-Tutorials/blob/master/TrOCR/Fine_tune_TrOCR_on_IAM_Handwriting_Database_using_Seq2SeqTrainer.ipynb

How to build my image dataset: https://huggingface.co/docs/datasets/en/image_dataset

I read this guide on setting the training parameters (see Seq2SeqTrainer part of the guide as well as the notebook) https://github.com/philschmid/document-ai-transformers/blob/main/training/donut_sroie.ipynb

In [1]:
!pip install evaluate jiwer

Collecting evaluate
  Downloading evaluate-0.4.1-py3-none-any.whl (84 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m84.1/84.1 kB[0m [31m5.9 MB/s[0m eta [36m0:00:00[0m
[?25hCollecting jiwer
  Downloading jiwer-3.0.3-py3-none-any.whl (21 kB)
Collecting rapidfuzz<4,>=3
  Downloading rapidfuzz-3.8.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl (3.4 MB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m3.4/3.4 MB[0m [31m24.3 MB/s[0m eta [36m0:00:00[0m00:01[0m00:01[0m
Installing collected packages: rapidfuzz, jiwer, evaluate
Successfully installed evaluate-0.4.1 jiwer-3.0.3 rapidfuzz-3.8.1
[0m

In [2]:
from datasets import Dataset
import pandas as pd

images_df = pd.read_csv("images/metadata.csv")
images_ds = Dataset.from_pandas(images_df)
images_df

Unnamed: 0,file_name,text
0,addition_application.png,addition application coffee trade environment ...
1,case_week.png,case week group recognition transportation com...
2,combination_school.png,combination school responsibility average year...
3,concept_administration.png,concept administration sign device reaction so...
4,dealer_passenger.png,dealer passenger tension intention responsibil...
5,discussion_guidance.png,discussion guidance language interest light re...
6,efficiency_answer.png,efficiency answer chance ability intention tim...
7,engine_paper.png,engine paper intention category reaction incom...
8,hearing_winner.png,hearing winner book building importance improv...
9,manufacturer_care.png,manufacturer care page charity piano nothing r...


In [3]:
from transformers import TrOCRProcessor, VisionEncoderDecoderModel

model_name = "microsoft/trocr-base-printed"
processor = TrOCRProcessor.from_pretrained(model_name)
model = VisionEncoderDecoderModel.from_pretrained(model_name)

Downloading preprocessor_config.json:   0%|          | 0.00/228 [00:00<?, ?B/s]

Downloading tokenizer_config.json:   0%|          | 0.00/1.09k [00:00<?, ?B/s]

Downloading vocab.json:   0%|          | 0.00/878k [00:00<?, ?B/s]

Downloading merges.txt:   0%|          | 0.00/446k [00:00<?, ?B/s]

Downloading special_tokens_map.json:   0%|          | 0.00/772 [00:00<?, ?B/s]

Downloading config.json:   0%|          | 0.00/4.03k [00:00<?, ?B/s]

Downloading pytorch_model.bin:   0%|          | 0.00/1.24G [00:00<?, ?B/s]

Some weights of VisionEncoderDecoderModel were not initialized from the model checkpoint at microsoft/trocr-base-printed and are newly initialized: ['encoder.pooler.dense.bias', 'encoder.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.


In [4]:
from PIL import Image
from pathlib import Path
import torch

def create_image_and_process_text(item):
    file_name = item["file_name"]
    text = item["text"]

    file_path = Path("images") / file_name

    image = Image.open(file_path).convert("RGB")
    
    pixel_values = processor(image, return_tensors="pt").pixel_values
    pixel_values = pixel_values.squeeze()

    labels = processor.tokenizer(text, padding="max_length", max_length=16).input_ids
    labels = torch.tensor(labels)

    encoding = { "pixel_values": pixel_values, "labels": labels }
    return encoding

In [5]:
inout_images_ds = images_ds.map(create_image_and_process_text, remove_columns=["file_name", "text"]) 

  0%|          | 0/24 [00:00<?, ?ex/s]

In [6]:
from operator import itemgetter

train_test_ds = inout_images_ds.train_test_split()
train_dataset, eval_dataset = itemgetter("train", "test")(train_test_ds)
train_dataset, eval_dataset

(Dataset({
     features: ['pixel_values', 'labels'],
     num_rows: 18
 }),
 Dataset({
     features: ['pixel_values', 'labels'],
     num_rows: 6
 }))

In [7]:
model.config.decoder_start_token_id = processor.tokenizer.cls_token_id
model.config.pad_token_id = processor.tokenizer.pad_token_id

model.config.vocab_size = model.config.decoder.vocab_size
model.config.eos_token_id = processor.tokenizer.sep_token_id

# model.config.max_length = 16
# model.config.early_stopping = True
# model.config.no_repeat_ngram_size = 3
# model.config.length_penalty = 2.0
# model.config.num_beams = 4

In [8]:
from transformers import Seq2SeqTrainer, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    predict_with_generate=True,
    num_train_epochs=16,
    learning_rate=2e-5,
    weight_decay=0.01,
    # logging_steps=100,
    evaluation_strategy="epoch",
    per_device_train_batch_size=1,
    per_device_eval_batch_size=1,
    fp16=True,
    output_dir="./",
    # logging_steps=2,
    # save_steps=1000,
    # eval_steps=200,
    report_to="none",
)

In [9]:
from evaluate import load

cer_metric = load("cer")

Downloading builder script:   0%|          | 0.00/5.60k [00:00<?, ?B/s]

In [10]:
def compute_metrics(pred):
    labels_ids = pred.label_ids
    pred_ids = pred.predictions
    
    pred_str = processor.batch_decode(pred_ids, skip_special_tokens=True)
    label_str = processor.batch_decode(labels_ids, skip_special_tokens=True)

    print(pred_str, label_str)

    cer = cer_metric.compute(predictions=pred_str, references=label_str)

    return { "character_error_rate": cer }

In [11]:
from transformers import default_data_collator

# instantiate trainer
trainer = Seq2SeqTrainer(
    model=model,
    tokenizer=processor.feature_extractor,
    args=training_args,
    compute_metrics=compute_metrics,
    train_dataset=train_dataset,
    eval_dataset=eval_dataset,
    data_collator=default_data_collator,
)

Using cuda_amp half precision backend


In [12]:
trainer.train()

***** Running training *****
  Num examples = 18
  Num Epochs = 16
  Instantaneous batch size per device = 1
  Total train batch size (w. parallel, distributed & accumulation) = 1
  Gradient Accumulation steps = 1
  Total optimization steps = 288


Epoch,Training Loss,Validation Loss,Character Error Rate
1,No log,4.881782,0.809249
2,No log,3.932956,0.728324
3,No log,3.328248,0.630058
4,No log,2.075616,0.549133
5,No log,2.274999,0.479769
6,No log,2.693207,0.462428
7,No log,1.408315,0.358382
8,No log,2.297039,0.381503
9,No log,1.063546,0.381503
10,No log,1.191752,0.213873


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


['wayway environment', ' environment environment environment', ' environment environment environment environment', 'way environment', 'way', 'way environment environment environment'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


['add opinion opinion company company company company', 'sponsationstream session sign sign sign sign', 'sponsategy experience importance importance importance', 'stratory intention responsibility intention intention time', 'personer responsibility responsibility notice room room', 'handhandhand responsibility session'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


['addination opinion importance advice advice', 'discessoressor game game game right sign', 'supermarket chance chance chance chance', 'seREAD knowledge knowledge way way time', 'POSibility time agreement agreement', 'manufact manufacturer change change change'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


['add opinion application offer environment commitment competition competition', 'disc opinion game process interest sign region mall', 'super chant reaction research performance impression impression impression', 'stARD knowledge knowledge wind responsibility time penalty', 'presibility meet agreement combination organization force mob', ' manufacturer game charity plan no'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


['addination application offer offer environment competition competition', 'discOSED game language language sign sign sign mall', 'superishment chance chance chance year reference impression', 'standardstrANDer knowledge wind time year', 'ppANY agreement organization morphology', 'manufacturerurer game charity plan conditioning resolution'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


['addition opinion difference trade environment', 'discussion dimension language interest soil', 'supermemark contractendant importance impression', 'standardhand knowledge wind intention penalty', 'personality answer agreement foundation holding', 'manufacturerurerurer category machine resolution'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


['addition opinion difference trade environment composition', 'discussion guide language interest right region stall', 'supermarket chance chance chance feed back impression', 'stander boarder knowledge winder', 'personality meet argument foundation notice room time', 'manufacturer car page charity plano noBUBION'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


['addition opinion trade environment composition', 'discOSED language interest impression mall', 'super chant chance research performance impression', 'SAND knowledge trade theory term line penalty', 'personality organization foundation force mob', 'manufacturer care charity plan no obligation'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


[' addition application offer trade environment composition', 'discOSED language interest intention mall', 'supermarket chance tract feedback period performance impression', 'hand knowledge trade winner wind wind responsibility time', 'personality manner organization form notice holding', 'manufacturer care page charity plano nocluding'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


['addition application coffee trade environment composition', 'discussion guide language interest right right region', 'super chant chance insert performance impression', 'standard knowledge trade video theory term time penalty', 'personality meat argument foundation force nothing money', 'manufact manufacturer care page charity plan no soil'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


['addition application coffee trade environment competition', 'discussion guidance language interest light region soil', 'supermarket chance insert feedback impression', 'standard knowledge trainer video theory term time penalty', 'POSality meat argument foundation force nothing honey', 'manufacturer care page charity plan no soil'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


['addition application coffee trade environment competition', 'discessor guidance language interest right region mall', 'supermarket chance insert feedback preference impression', 'stARD knowledge trainer video theory term time penalty', 'posality meat argument foundation force nothing honey', 'manufacturer care page charity plan no obligation'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


['addition application coffee trade environment competition', 'discessor guidance language interest right region soil', 'supermarket chance insert feedback impression', 'standard knowledge trainer video theory term time penalty', 'personality meat argument foundation nothing honey', 'manufacturer care page charity plan no obligation'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


['addition application coffee trade environment competition', 'discussion guidance language interest light region mall', 'supermarket chance insert feedback preference impression', 'standard knowledge trainer video theory term time penalty', 'personality meat argument foundation force nothing honey', 'manufacturer care page charity plan no obligation'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


['addition application coffee trade environment competition', 'discussion guidance language interest light region mall su', 'supermarket chance insert feedback preference impression', 'standard knowledge trainer video theory term time penalty', 'personality meat argument foundation force nothing honey', 'manufacturer care page charity plan no obligation'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']


***** Running Evaluation *****
  Num examples = 6
  Batch size = 1


['addition application coffee trade environment competition', 'discussion guidance language interest light region mall su', 'supermarket chance insert feedback preference impression', 'standard knowledge trainer video theory term time penalty', 'personality meat argument foundation force nothing honey', 'manufacturer care page charity plan noHD election'] ['addition application coffee trade environment competition', 'discussion guidance language interest light region mall situation', 'supermarket chance insect feedback preference impression', 'standard knowledge trainer video theory term line penalty', 'possibility meat argument foundation force nothing honey', 'manufacturer care page charity piano nothing resolution']




Training completed. Do not forget to share your model on huggingface.co/models =)




TrainOutput(global_step=288, training_loss=1.3577116860283747, metrics={'train_runtime': 511.9958, 'train_samples_per_second': 0.563, 'train_steps_per_second': 0.563, 'total_flos': 2.155061350068388e+17, 'train_loss': 1.3577116860283747, 'epoch': 16.0})

In [15]:
save_model = False
if save_model:
    model.save_pretrained("my-trocr-model", from_pt=True)
    processor.save_pretrained("my-trocr-model", from_pt=True)

Configuration saved in my-trocr-model/config.json
Model weights saved in my-trocr-model/pytorch_model.bin
Feature extractor saved in my-trocr-model/preprocessor_config.json
tokenizer config file saved in my-trocr-model/tokenizer_config.json
Special tokens file saved in my-trocr-model/special_tokens_map.json
