# Train T5 (ByT5) for OCR post-correction

Sources
* https://huggingface.co/docs/transformers/tasks/summarization

## Data

* Input: sequence with OCR errors
* Output: sequence without OCR errors
* Task prefix: `correct: `
    * Not sure this is necessary (the prefix is currently not added in `preprocess_function`).

In [None]:
# Mount Google Drive at /mntDrive; the DVC config is copied from there below.
from google.colab import drive
drive.mount('/mntDrive')

In [None]:
# Clone the notebooks repository into the Colab working directory.
! git clone https://github.com/jvdzwaan/ocrpostcorrection-notebooks.git

In [None]:
# Work from inside the cloned repository from now on.
%cd /content/ocrpostcorrection-notebooks

In [None]:
# Switch to the `mT5-experiment` ref (presumably a branch — confirm).
! git checkout mT5-experiment

In [None]:
# Poetry is used to install the project's dependencies.
! pip install poetry

In [None]:
# Configure poetry to create virtual environments in the project folder
! poetry config virtualenvs.in-project true

In [None]:
# Install the project dependencies into ./.venv (see the config cell above).
! poetry install --no-ansi

In [None]:
# Add poetry virtual environment to python path so that all installed dependencies can be found by the python interpreter
# The site-packages directory name contains the interpreter's major.minor
# version. Derive it instead of hard-coding python3.10 so this cell keeps
# working when the Colab runtime's Python version changes.
# NOTE(review): assumes poetry built the venv with the same interpreter as
# the notebook kernel.

import sys
sys.path.append(
    "/content/ocrpostcorrection-notebooks/.venv/lib/"
    f"python{sys.version_info.major}.{sys.version_info.minor}/site-packages"
)

In [None]:
# Copy the local DVC config (kept on Google Drive, outside the repository).
! cp /mntDrive/MyDrive/ocrpostcorrection-config/config.local .dvc/config.local

In [None]:
# Fetch the DVC-tracked data from the `gdrive` remote.
! poetry run dvc pull -r gdrive

In [None]:
# Materialize the DVC-tracked files in the working tree.
! poetry run dvc checkout

In [None]:
from pathlib import Path

import pandas as pd
from datasets import Dataset, DatasetDict
from loguru import logger
from ocrpostcorrection.icdar_data import get_intermediate_data

In [None]:
# local
# Paths relative to the notebook directory when running outside Colab.
raw_dataset = Path('..')/'data'/'raw'/'ICDAR2019_POCR_competition_dataset.zip'

train_split, val_split, test_split = (
    Path('..')/'data'/'splits'/f'{split_name}.csv'
    for split_name in ('train', 'val', 'test')
)

In [None]:
# colab
# Paths relative to the cloned repository root (the current directory on Colab).
raw_dataset = Path('data')/'raw'/'ICDAR2019_POCR_competition_dataset.zip'

train_split, val_split, test_split = (
    Path('data')/'splits'/f'{split_name}.csv'
    for split_name in ('train', 'val', 'test')
)

In [None]:
# Parse the raw ICDAR 2019 zip. Only the 1st and 3rd return values are kept.
# `data` is used below for the train/val windows and `data_test` for the test
# windows, both looked up by file name. The other two values are discarded.
# NOTE(review): confirm what the discarded values contain.
data, _, data_test, _ = get_intermediate_data(raw_dataset)

10it [00:18,  1.85s/it]
10it [00:04,  2.02it/s]


In [None]:
# Load the split definitions; each row names a text via its `file_name`
# column. The first CSV column is used as the index.
X_train, X_val, X_test = (
    pd.read_csv(split_csv, index_col=0)
    for split_csv in (train_split, val_split, test_split)
)

In [None]:
from typing import List, Dict, Tuple
from tqdm.notebook import tqdm

import edlib
from ocrpostcorrection.icdar_data import window, normalized_ed, Text


def _process_sequence(
    key: str,
    i: int,
    res,
    ocr_sents: List[str],
    gs_sents: List[str],
    keys: List[str],
    start_tokens: List[int],
    scores: List[float],
    languages: List[str],
) -> Tuple[
    List[str], List[str], List[str], List[int], List[float], List[str]
]:
    """Turn one window of tokens into an (ocr, gs) sample and record it.

    The OCR strings of the window's tokens are joined with single spaces, as
    are the non-empty gold-standard strings. If both joined strings are
    non-empty, the sample is appended to the accumulator lists (mutated in
    place) together with its text key, window index, normalized edit distance
    and language. The language is the first two characters of the key.
    Otherwise the empty sample is logged and nothing is appended.

    Args:
        key: file name of the text the window comes from.
        i: index of the window, stored as the sample's start token id.
        res: the window; an iterable of token objects exposing string
            attributes ``ocr`` and ``gs``.
        ocr_sents, gs_sents, keys, start_tokens, scores, languages:
            accumulator lists that the sample's fields are appended to.

    Returns:
        The same accumulator lists, in argument order.
    """
    ocr_str = " ".join(t.ocr for t in res)
    # Skip tokens without a gold standard so the target text does not get
    # repeated spaces.
    gs_str = " ".join(t.gs for t in res if t.gs != "")

    if ocr_str and gs_str:
        # Align only the samples that are kept. Computing the alignment
        # before the emptiness check wasted an edlib call on every
        # discarded window.
        ed = edlib.align(ocr_str, gs_str)["editDistance"]
        score = normalized_ed(ed, ocr_str, gs_str)

        ocr_sents.append(ocr_str)
        gs_sents.append(gs_str)
        keys.append(key)
        start_tokens.append(i)
        scores.append(score)
        languages.append(key[:2])
    else:
        logger.info(f'Empty sample for text "{key}"')
        logger.info(f"ocr_str: {ocr_str}")
        logger.info(f"gs_str: {gs_str}")
        logger.info(f"start token: {i}")

    return (ocr_sents, gs_sents, keys, start_tokens, scores, languages)


def generate_sentences(
    df: pd.DataFrame, data: Dict[str, Text], size: int = 15, step: int = 10
) -> pd.DataFrame:
    """Generate sequences of a certain length and possible overlap.

    For every row of ``df``, the tokens of ``data[row.file_name]`` are cut
    into windows of ``size`` tokens with ``window``. Every ``step``-th window,
    starting with the first, becomes a sample. The last window is always
    added as well so the end of each text is covered. Windows whose OCR or
    gold-standard text is empty are skipped by ``_process_sequence``.

    Args:
        df: split definition with a ``file_name`` column.
        data: mapping from file name to parsed text.
        size: number of tokens per window.
        step: keep every ``step``-th window.

    Returns:
        DataFrame with columns key, start_token_id, score, ocr, gs and
        language, with at most one row per (key, start_token_id).
    """
    ocr_sents: List[str] = []
    gs_sents: List[str] = []
    keys: List[str] = []
    start_tokens: List[int] = []
    scores: List[float] = []
    languages: List[str] = []

    # iterrows() has no len(), so pass total= for a meaningful progress bar.
    for _, row in tqdm(df.iterrows(), total=len(df)):
        key = row.file_name
        tokens = data[key].input_tokens

        # Reset per text. If `window` yields nothing for this text, `i` and
        # `res` would otherwise still hold the previous text's last window,
        # which would then be re-added under this key (or a NameError would
        # be raised for the very first text).
        i = None
        res = None
        for i, res in enumerate(window(tokens, size=size)):
            if i % step == 0:
                (
                    ocr_sents,
                    gs_sents,
                    keys,
                    start_tokens,
                    scores,
                    languages,
                ) = _process_sequence(
                    key, i, res, ocr_sents, gs_sents, keys, start_tokens, scores, languages
                )

        if i is None:
            logger.info(f'No windows for text "{key}"')
            continue

        # Add final sequence, so the end of the text is always covered
        (ocr_sents, gs_sents, keys, start_tokens, scores, languages) = _process_sequence(
            key, i, res, ocr_sents, gs_sents, keys, start_tokens, scores, languages
        )

    output = pd.DataFrame(
        {
            "key": keys,
            "start_token_id": start_tokens,
            "score": scores,
            "ocr": ocr_sents,
            "gs": gs_sents,
            "language": languages,
        }
    )

    # Adding the final sequence may lead to duplicate rows. Remove those
    output.drop_duplicates(
        subset=["key", "start_token_id"], keep="first", inplace=True, ignore_index=True
    )

    return output

In [None]:
# Window length in tokens, and keep every `step`-th window. With equal values,
# consecutive kept windows do not overlap (assuming `window` slides one token
# at a time — confirm).
size = step = 25

In [None]:
logger.info(f"Generating sentences (size: {size}, step: {step})")

train_data = generate_sentences(X_train, data, size=size, step=step)
val_data = generate_sentences(X_val, data, size=size, step=step)
# NOTE(review): the test split keeps every `size`-th window (step=size)
# rather than using `step`. This is identical while size == step; confirm it
# is intended before changing either value.
test_data = generate_sentences(X_test, data_test, size=size, step=size)

num_train = train_data.shape[0]
num_val = val_data.shape[0]
num_test = test_data.shape[0]
# Fixed: the message previously ended with an unbalanced ")".
logger.info(f"# samples train: {num_train}, val: {num_val}, test: {num_test}")

2023-12-24 20:13:28.993 | INFO     | __main__:<module>:1 - Generating sentences (size: 25, step: 25)


0it [00:00, ?it/s]

2023-12-24 20:13:29.025 | INFO     | __main__:_process_sequence:39 - Empty sample for text "FR/FR3/883.txt"
2023-12-24 20:13:29.025 | INFO     | __main__:_process_sequence:40 - ocr_str: # mmm mm mm mm mm mm mm mm mm mm mm mm mm mm mmm mm mm mm mm mm mm mmm mm mm
2023-12-24 20:13:29.026 | INFO     | __main__:_process_sequence:41 - gs_str: 
2023-12-24 20:13:29.026 | INFO     | __main__:_process_sequence:42 - start token: 50
2023-12-24 20:13:29.090 | INFO     | __main__:_process_sequence:39 - Empty sample for text "FR/FR1/904.txt"
2023-12-24 20:13:29.093 | INFO     | __main__:_process_sequence:40 - ocr_str: Philippus naces lit regie dengnitatis, chemencam e ilionlibenter extenimus et seiu eorum liberali promonemus affectu, qui not serviciorm exnbicine grata pervenint et virtutum comis suffragis,
2023-12-24 20:13:29.096 | INFO     | __main__:_process_sequence:41 - gs_str: 
2023-12-24 20:13:29.098 | INFO     | __main__:_process_sequence:42 - start token: 0
2023-12-24 20:13:29.100 | INFO    

0it [00:00, ?it/s]

2023-12-24 20:13:37.363 | INFO     | __main__:_process_sequence:39 - Empty sample for text "FR/FR3/320.txt"
2023-12-24 20:13:37.364 | INFO     | __main__:_process_sequence:40 - ocr_str: * ** vw "*•* ** ama mm ma* mm mm* «m mm nm mm *• mm mm mm cms -mm mtn wm. oms w& *•
2023-12-24 20:13:37.364 | INFO     | __main__:_process_sequence:41 - gs_str: 
2023-12-24 20:13:37.365 | INFO     | __main__:_process_sequence:42 - start token: 250
2023-12-24 20:13:37.399 | INFO     | __main__:_process_sequence:39 - Empty sample for text "NL/NL1/51.txt"
2023-12-24 20:13:37.405 | INFO     | __main__:_process_sequence:40 - ocr_str: Bar. d’) Doorn Utrecht 1865 BeECK VOLLENHOVEN. (H. VAN) Amsterdam Noordholland 1865 Lid der Huishoudelijke Commissie. Beerenbroek. (L. F. H.) Roermond Limburg 1865 Beken Pasteel.
2023-12-24 20:13:37.414 | INFO     | __main__:_process_sequence:41 - gs_str: 
2023-12-24 20:13:37.415 | INFO     | __main__:_process_sequence:42 - start token: 50
2023-12-24 20:13:37.416 | INFO     | __

0it [00:00, ?it/s]

2023-12-24 20:13:38.257 | INFO     | __main__:_process_sequence:39 - Empty sample for text "BG/BG1/0.txt"
2023-12-24 20:13:38.258 | INFO     | __main__:_process_sequence:40 - ocr_str: заедно съ коня си, бп.тъ разнесенъ на пранета отъ неприятелската граната, казала: „Това е Божия воля нашия животъ е въ Божиите ржце, но азъ съмъ
2023-12-24 20:13:38.258 | INFO     | __main__:_process_sequence:41 - gs_str: 
2023-12-24 20:13:38.259 | INFO     | __main__:_process_sequence:42 - start token: 25
2023-12-24 20:13:38.259 | INFO     | __main__:_process_sequence:39 - Empty sample for text "BG/BG1/0.txt"
2023-12-24 20:13:38.259 | INFO     | __main__:_process_sequence:40 - ocr_str: благодарна на Спасителя, че този ми сннъ, тъй сжщо както и другите ми четире синове, можа да бжде нолезенъ за страната си, —- да се
2023-12-24 20:13:38.260 | INFO     | __main__:_process_sequence:41 - gs_str: 
2023-12-24 20:13:38.260 | INFO     | __main__:_process_sequence:42 - start token: 50
2023-12-24 20:13:38.261 | INF

In [None]:
# Inspect the generated training samples.
train_data

Unnamed: 0,key,start_token_id,score,ocr,gs,language
0,DE/DE4/204.txt,0,0.111111,annis ad nauigandum non—erant ufi. Has reficie...,annis ad nauigandum non erant uſi. Has reficie...,DE
1,DE/DE4/204.txt,25,0.089109,receptus dare— fe Torius ferebat: milites aded...,receptus dare-ſe Torius ferebat: milites adeò ...,DE
2,DE/DE4/204.txt,50,0.110000,"Quibus literis acceptis, infolenti periclitare...","Quibus literis acceptis, inſolentipericlitaren...",DE
3,DE/DE4/204.txt,75,0.129213,"& ruinis effet deformata,‚ciues fuos primum ad...","& ruinis eſſet deformata, ciues ſuos primùm ad...",DE
4,DE/DE4/204.txt,100,0.182796,fe comotaretur: quéE Parzxtonio‚uel ä dextra a...,"ſe comoraretur: quẽParætonio, uel à dextra ab ...",DE
...,...,...,...,...,...,...
117748,DE/DE3/3649.txt,100,0.181818,"denen Bolfern. gedrucft rerden, mwarum nicht i...","Voͤlkern.gedruckt werden, warum nicht in Bayer...",DE
117749,DE/DE3/3649.txt,125,0.166667,"der Dauer biefer ben und gedendt wird, wenn er...","der Dauer dieſerben und gedruckt wird, wenn er...",DE
117750,DE/DE3/3649.txt,150,0.209945,"Siteratur befte— Sandrechts, fünftig auffßören...","Literatur beſte Landrechts, kuͤnftig aufhoͤren...",DE
117751,DE/DE3/3649.txt,175,0.200000,von 1785 bis 1813 (ein längerer Beit— männer b...,von 1785 bis 1815 (ein laͤngerer Zeit maͤnner ...,DE


In [None]:
# Show 5 random training samples (no random_state, so the rows differ per run).
train_data.sample(5)

Unnamed: 0,key,start_token_id,score,ocr,gs,language
97497,DE/DE3/5551.txt,100,0.251656,"bdie fi Sbnen der WGreibeit zu Brauen, da8ß ar...","die ſich Jhnender Freiheit zu brauen, das arom...",DE
77721,FR/FR1/701.txt,25,0.046243,"quadium et quaciens in regno nostro existent, ...","quamdiu et quociens in regno nostro existent, ...",FR
107611,DE/DE5/75.txt,200,0.23622,Simon. unfer ere wirt vil groffer Jen Jes Fepf...,Simon.vnſer ere wirt vil groſſer den des keyſe...,DE
29428,DE/DE3/5342.txt,75,0.186335,batte ja fogar im &erfer davon bufdichte Afuge...,hatte ja ſogar im Kerker davonbuſchichte Augen...,DE
114404,DE/DE2/46.txt,550,0.116129,man in Der ©reitfopßschen SJtieDerlage 6 Doppe...,man in der Breitkopﬁschen Niederlage 6 Doppel⸗...,DE


In [None]:
def add_len(text):
    """Return ``len(text)``, i.e. the number of characters of a string."""
    n_chars = len(text)
    return n_chars

# Store the character length of the OCR input and of the gold-standard target
# of every sample as extra columns.
for df in (train_data, test_data, val_data):
    df["len_ocr"] = df["ocr"].map(add_len)
    df["len_gs"] = df["gs"].map(add_len)

In [None]:
# Summary statistics of the sample lengths per split (train, test, val).
for df in (train_data, test_data, val_data):
    ocr_stats = df.len_ocr.describe()
    gs_stats = df.len_gs.describe()
    print(f"ocr: {ocr_stats}")
    print(f"gs: {gs_stats}")

ocr: count    117753.000000
mean        157.385425
std          23.178839
min          10.000000
25%         142.000000
50%         154.000000
75%         170.000000
max         350.000000
Name: len_ocr, dtype: float64
gs: count    117753.000000
mean        154.779946
std          37.659245
min           1.000000
25%         139.000000
50%         151.000000
75%         167.000000
max        5593.000000
Name: len_gs, dtype: float64
ocr: count    34698.000000
mean       156.618854
std         23.233312
min         23.000000
25%        141.000000
50%        154.000000
75%        168.000000
max        338.000000
Name: len_ocr, dtype: float64
gs: count    34698.000000
mean       154.423137
std         47.574638
min          1.000000
25%        139.000000
50%        151.000000
75%        166.000000
max       2718.000000
Name: len_gs, dtype: float64
ocr: count    13017.000000
mean       157.875163
std         23.141082
min         60.000000
25%        143.000000
50%        155.000000
75%    

In [None]:
# logger.info(f"Filtering train and val based on maximum edit distance of {max_ed}")
# train_data = train_data[train_data.score < max_ed]
# val_data = val_data[val_data.score < max_ed]

In [None]:
# The edit-distance score is not needed downstream; drop it before
# converting the splits to a Hugging Face DatasetDict.
for df in (train_data, val_data, test_data):
    df.drop(columns=["score"], inplace=True)

dataset = DatasetDict(
    {
        split_name: Dataset.from_pandas(frame)
        for split_name, frame in (
            ("train", train_data),
            ("val", val_data),
            ("test", test_data),
        )
    }
)

In [None]:
# Show the columns and number of rows of each split.
dataset

DatasetDict({
    train: Dataset({
        features: ['key', 'start_token_id', 'ocr', 'gs', 'language', 'len_ocr', 'len_gs'],
        num_rows: 117753
    })
    val: Dataset({
        features: ['key', 'start_token_id', 'ocr', 'gs', 'language', 'len_ocr', 'len_gs'],
        num_rows: 13017
    })
    test: Dataset({
        features: ['key', 'start_token_id', 'ocr', 'gs', 'language', 'len_ocr', 'len_gs'],
        num_rows: 34698
    })
})

In [None]:
from transformers import AutoTokenizer

# ByT5 is a byte-level T5 variant (per the Hugging Face ByT5 docs), so token
# counts correspond to UTF-8 bytes rather than characters or subwords.
checkpoint = "google/byt5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)

In [None]:
# Maximum number of tokens for both inputs and labels; longer ones are truncated.
# NOTE(review): for ByT5 this counts UTF-8 bytes (plus the EOS token), not
# characters. Multi-byte text (e.g. Cyrillic, or German with `ſ`) of typical
# sample length (~150 characters) is therefore likely to be truncated —
# confirm this value.
max_len = 200

def preprocess_function(examples):
    """Tokenize a batch of samples for seq2seq training.

    Args:
        examples: batch passed by ``Dataset.map(..., batched=True)``; its
            "ocr" and "gs" entries are sequences of strings.

    Returns:
        The tokenizer's encoding of the OCR texts, with the token ids of the
        gold-standard texts added under "labels".
    """
    # The `correct: ` task prefix from the notes is not applied here.
    model_inputs = tokenizer(
        list(examples["ocr"]), max_length=max_len, truncation=True
    )
    labels = tokenizer(
        text_target=list(examples["gs"]), max_length=max_len, truncation=True
    )

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
# Tokenize every split in batches. The original columns (key, ocr, gs, ...)
# are kept next to the new ones; the Trainer ignores them during training.
tokenized_dataset = dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/117753 [00:00<?, ? examples/s]

Map:   0%|          | 0/13017 [00:00<?, ? examples/s]

Map:   0%|          | 0/34698 [00:00<?, ? examples/s]

In [None]:
from transformers import DataCollatorForSeq2Seq

# Pads inputs and labels per batch.
# NOTE(review): `model` receives the checkpoint *name* (a str), not the model
# object, which is only created further down. The collator therefore cannot
# prepare decoder input ids from the labels; T5 presumably derives them from
# the labels itself. Consider passing `model=model` once it exists — confirm.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, model=checkpoint)

In [None]:
# Character error rate metric from the `evaluate` library.
from evaluate import load
cer = load("cer")

In [None]:
# Baseline: CER of the uncorrected OCR text against the gold standard on the
# validation split, i.e. what a model that changes nothing would score.
cer_score = cer.compute(
    predictions=list(val_data["ocr"]),
    references=list(val_data["gs"]),
)
print(cer_score)

0.21237264855586288


In [None]:
# Word error rate metric from the `evaluate` library.
wer = load("wer")

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

In [None]:
# Baseline: WER of the uncorrected OCR text against the gold standard on the
# validation split.
wer_score = wer.compute(
    predictions=list(val_data["ocr"]),
    references=list(val_data["gs"]),
)
print(wer_score)

0.635085134039459


In [None]:
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Pretrained weights from the same checkpoint as the tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

Downloading pytorch_model.bin:   0%|          | 0.00/1.20G [00:00<?, ?B/s]

Downloading generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [None]:
# NOTE(review): output_dir is a leftover name from the Hugging Face
# summarization tutorial; consider a project-specific directory name.
training_args = Seq2SeqTrainingArguments(
    output_dir="my_awesome_billsum_model",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=4,
    # NOTE(review): compute_metrics is commented out below, so generated
    # predictions are not scored during evaluation; consider disabling this
    # until metrics are added.
    predict_with_generate=True,
    # fp16=True,
    push_to_hub=False,
    # This is a boolean flag (it was the string "epoch"). It requires
    # save_strategy to match evaluation_strategy, which is the case here.
    load_best_model_at_end=True,
    save_strategy="epoch",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    eval_dataset=tokenized_dataset["val"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    # compute_metrics=compute_metrics,
)

In [None]:
# Run fine-tuning (the recorded run below was interrupted manually).
trainer.train()

The following columns in the training set don't have a corresponding argument in `T5ForConditionalGeneration.forward` and have been ignored: ocr, language, len_gs, len_ocr, start_token_id, gs, key. If ocr, language, len_gs, len_ocr, start_token_id, gs, key are not expected by `T5ForConditionalGeneration.forward`,  you can safely ignore this message.
***** Running training *****
  Num examples = 117753
  Num Epochs = 4
  Instantaneous batch size per device = 16
  Total train batch size (w. parallel, distributed & accumulation) = 16
  Gradient Accumulation steps = 1
  Total optimization steps = 29440
  Number of trainable parameters = 299637760


  0%|          | 0/29440 [00:00<?, ?it/s]

KeyboardInterrupt: 