# Train mT5 for OCR post-correction

Sources
* https://huggingface.co/docs/transformers/tasks/summarization

## Data

* Input: a text sequence containing OCR errors
* Output: the same sequence without OCR errors (the gold standard)
* Task prefix: `correct: `
    * Not sure this is necessary; it is currently not prepended to the inputs.

In [None]:
from google.colab import drive
# Mount Google Drive at /mntDrive; the local DVC configuration is copied from it further down.
drive.mount('/mntDrive')

In [None]:
# Get the project repository; the poetry environment and DVC setup used below live in it.
! git clone https://github.com/jvdzwaan/ocrpostcorrection-notebooks.git

In [None]:
# Run the following shell commands from the repository root.
%cd /content/ocrpostcorrection-notebooks

In [None]:
# Poetry is used to install the project's dependencies into a virtual environment.
! pip install poetry

In [None]:
# Configure poetry to create virtual environments in the project folder (.venv),
# so the site-packages directory is at a predictable path that can be added to
# sys.path below.
! poetry config virtualenvs.in-project true

In [None]:
# Install the project dependencies into .venv (--no-ansi keeps the cell output free of colour codes).
! poetry install --no-ansi

In [None]:
# Add poetry virtual environment to python path so that all installed dependencies
# can be found by the python interpreter.
# The site-packages directory name contains the interpreter version (e.g.
# python3.8), which differs between Colab runtimes, so it is derived from the
# running interpreter instead of being hard-coded.
# NOTE(review): assumes poetry created .venv with the same python version as
# the kernel — confirm if the runtime has several pythons installed.

import sys

venv_site_packages = (
    "/content/ocrpostcorrection-notebooks/.venv/lib/"
    f"python{sys.version_info.major}.{sys.version_info.minor}/site-packages"
)
# Avoid duplicate sys.path entries when the cell is re-run.
if venv_site_packages not in sys.path:
    sys.path.append(venv_site_packages)

In [None]:
# Copy the local DVC configuration (kept on Google Drive, outside the repository) into .dvc/.
# NOTE(review): presumably holds the settings for the `gdrive` remote used below — confirm.
! cp /mntDrive/MyDrive/ocrpostcorrection-config/config.local .dvc/config.local

In [None]:
# Download the DVC-tracked data from the `gdrive` remote, using the poetry environment's dvc.
! poetry run dvc pull -r gdrive

In [None]:
from pathlib import Path

import pandas as pd
from datasets import Dataset, DatasetDict
from loguru import logger
from ocrpostcorrection.icdar_data import get_intermediate_data

In [None]:
# Raw ICDAR 2019 post-OCR competition archive and the predefined
# train/val/test split files (relative to the working directory).
data_dir = Path('..') / 'data'

raw_dataset = data_dir / 'raw' / 'ICDAR2019_POCR_competition_dataset.zip'

splits_dir = data_dir / 'splits'
train_split, val_split, test_split = (
    splits_dir / f'{split_name}.csv' for split_name in ('train', 'val', 'test')
)

In [None]:
# Parse the raw competition archive. Only the 1st and 3rd return values are kept:
# `data` is used for the train/val splits and `data_test` for the test split.
# NOTE(review): from their use in generate_sentences, both look like dicts
# mapping file_name -> Text — confirm against get_intermediate_data.
data, _, data_test, _ = get_intermediate_data(raw_dataset)

10it [00:30,  3.09s/it]
10it [00:09,  1.08it/s]


In [None]:
# Split definitions: each row has a `file_name` column identifying a text;
# the first CSV column is used as the index.
X_train, X_val, X_test = (
    pd.read_csv(split_csv, index_col=0)
    for split_csv in (train_split, val_split, test_split)
)

In [None]:
from typing import List, Dict, Tuple
from tqdm.notebook import tqdm

import edlib
from ocrpostcorrection.icdar_data import window, normalized_ed, Text


def _process_sequence(
    key: str,
    i: int,
    res,
    ocr_sents: List[str],
    gs_sents: List[str],
    keys: List[str],
    start_tokens: List[int],
    scores: List[float],
    languages: List[str],
) -> Tuple[
    List[str], List[str], List[str], List[int], List[float], List[str]
]:
    """Convert one window of tokens into an (OCR, gold standard) sample.

    The OCR strings of all tokens in ``res`` are joined with spaces. The gold
    standard (GS) strings are joined the same way, skipping tokens whose GS is
    empty. If both joined strings are non-empty, the sample is appended to the
    accumulator lists together with its ``normalized_ed`` score. Otherwise it
    is logged and dropped.

    Args:
        key: file name of the text the window comes from (e.g.
            ``"DE/DE4/204.txt"``); its first two characters are stored as
            the language.
        i: index of the window within the text (stored as the start token).
        res: the window; an iterable of tokens with ``ocr`` and ``gs``
            string attributes.
        ocr_sents, gs_sents, keys, start_tokens, scores, languages:
            accumulator lists, extended in place.

    Returns:
        The six accumulator lists (the same objects that were passed in).
    """
    ocr_str = " ".join(t.ocr for t in res)
    gs_str = " ".join(t.gs for t in res if t.gs != "")

    if len(ocr_str) > 0 and len(gs_str) > 0:
        # Only kept samples need a score, so the alignment is not computed
        # for empty samples that are discarded anyway.
        ed = edlib.align(ocr_str, gs_str)["editDistance"]
        score = normalized_ed(ed, ocr_str, gs_str)

        ocr_sents.append(ocr_str)
        gs_sents.append(gs_str)
        keys.append(key)
        start_tokens.append(i)
        scores.append(score)
        languages.append(key[:2])
    else:
        logger.info(f'Empty sample for text "{key}"')
        logger.info(f"ocr_str: {ocr_str}")
        logger.info(f"gs_str: {gs_str}")
        logger.info(f"start token: {i}")

    return (ocr_sents, gs_sents, keys, start_tokens, scores, languages)


def generate_sentences(
    df: pd.DataFrame, data: Dict[str, Text], size: int = 15, step: int = 10
) -> pd.DataFrame:
    """Generate sequences of a certain length and possible overlap.

    For every text listed in ``df``, the windows produced by
    ``window(tokens, size=size)`` are enumerated. Every window whose index is
    a multiple of ``step`` becomes a sample, and the last window is always
    added so the end of the text is covered. Assuming ``window`` slides one
    token at a time, windows overlap when ``step < size``.

    Args:
        df: split definition; the ``file_name`` of each row is a key into
            ``data``.
        data: parsed texts by file name; their ``input_tokens`` are windowed.
        size: number of tokens per window.
        step: distance between the indices of sampled windows.

    Returns:
        DataFrame with columns ``key``, ``start_token_id``, ``score``,
        ``ocr``, ``gs`` and ``language``. It has one row per kept window,
        duplicates on (key, start_token_id) removed, and a fresh 0..n-1 index.
    """
    ocr_sents: List[str] = []
    gs_sents: List[str] = []
    keys: List[str] = []
    start_tokens: List[int] = []
    scores: List[float] = []
    languages: List[str] = []

    # total= lets tqdm render a progress bar instead of a bare item counter.
    for _, row in tqdm(df.iterrows(), total=len(df)):
        key = row.file_name
        tokens = data[key].input_tokens

        # Reset for every text. Without this, a text for which ``window``
        # yields nothing would re-add the previous text's last window under
        # this key (or raise NameError for the very first text).
        last_window = None
        for i, res in enumerate(window(tokens, size=size)):
            last_window = (i, res)
            if i % step == 0:
                (
                    ocr_sents,
                    gs_sents,
                    keys,
                    start_tokens,
                    scores,
                    languages,
                ) = _process_sequence(
                    key, i, res, ocr_sents, gs_sents, keys, start_tokens, scores, languages
                )
        # Add final sequence (removed again below if it was already added).
        if last_window is not None:
            final_i, final_res = last_window
            (ocr_sents, gs_sents, keys, start_tokens, scores, languages) = _process_sequence(
                key, final_i, final_res, ocr_sents, gs_sents, keys, start_tokens, scores, languages
            )

    output = pd.DataFrame(
        {
            "key": keys,
            "start_token_id": start_tokens,
            "score": scores,
            "ocr": ocr_sents,
            "gs": gs_sents,
            "language": languages,
        }
    )

    # Adding the final sequence may lead to duplicate rows. Remove those
    output.drop_duplicates(
        subset=["key", "start_token_id"], keep="first", inplace=True, ignore_index=True
    )

    return output

In [None]:
# Window length and sampling step, in tokens. With equal values, consecutive
# sampled windows do not overlap.
size, step = 35, 35

In [None]:
# The test split uses step == size, so consecutive test windows do not
# overlap — presumably so that each token is evaluated only once.
# NOTE(review): confirm this is intended.
test_step = size

logger.info(f"Generating sentences (size: {size}, step: {step}, test step: {test_step})")

train_data = generate_sentences(X_train, data, size=size, step=step)
val_data = generate_sentences(X_val, data, size=size, step=step)
test_data = generate_sentences(X_test, data_test, size=size, step=test_step)

num_train = train_data.shape[0]
num_val = val_data.shape[0]
num_test = test_data.shape[0]
logger.info(f"# samples train: {num_train}, val: {num_val}, test: {num_test}")

2023-12-15 12:05:56.275 | INFO     | __main__:<module>:1 - Generating sentences (size: 35, step: 35)


0it [00:00, ?it/s]

2023-12-15 12:05:56.378 | INFO     | __main__:_process_sequence:39 - Empty sample for text "FR/FR1/904.txt"
2023-12-15 12:05:56.379 | INFO     | __main__:_process_sequence:40 - ocr_str: Philippus naces lit regie dengnitatis, chemencam e ilionlibenter extenimus et seiu eorum liberali promonemus affectu, qui not serviciorm exnbicine grata pervenint et virtutum comis suffragis, digne sibi venditant primia mentorum. Nomnacaque facimus aniversi, tam presentibus
2023-12-15 12:05:56.379 | INFO     | __main__:_process_sequence:41 - gs_str: 
2023-12-15 12:05:56.380 | INFO     | __main__:_process_sequence:42 - start token: 0
2023-12-15 12:05:56.380 | INFO     | __main__:_process_sequence:39 - Empty sample for text "FR/FR1/904.txt"
2023-12-15 12:05:56.381 | INFO     | __main__:_process_sequence:40 - ocr_str: quad futuris, pro lict Jihes Guibti, clericus et Guillelmus ejus fratis, corhibitis cipulis de Conjignes, vides et solutis extum habuisse dicantur. Nos, attendentos quod, sicut relacio fide d

0it [00:00, ?it/s]

2023-12-15 12:06:06.682 | INFO     | __main__:_process_sequence:39 - Empty sample for text "FR/FR3/320.txt"
2023-12-15 12:06:06.688 | INFO     | __main__:_process_sequence:40 - ocr_str: w# * mm mm * * ** vw "*•* ** ama mm ma* mm mm* «m mm nm mm *• mm mm mm cms -mm mtn wm. oms w& *• »%, *o o* * ***
2023-12-15 12:06:06.695 | INFO     | __main__:_process_sequence:41 - gs_str: 
2023-12-15 12:06:06.699 | INFO     | __main__:_process_sequence:42 - start token: 245
2023-12-15 12:06:06.746 | INFO     | __main__:_process_sequence:39 - Empty sample for text "NL/NL1/51.txt"
2023-12-15 12:06:06.746 | INFO     | __main__:_process_sequence:40 - ocr_str: van waar afgevaardigd. JAREN van aftreding. BIJZONDERHEDEN. Abi.aing van Giessenburg. (Jhr. J. D. C.C. W. Bar. d’) Doorn Utrecht 1865 BeECK VOLLENHOVEN. (H. VAN) Amsterdam Noordholland 1865 Lid der Huishoudelijke Commissie. Beerenbroek. (L. F. H.)
2023-12-15 12:06:06.747 | INFO     | __main__:_process_sequence:41 - gs_str: 
2023-12-15 12:06:06.747 | 

0it [00:00, ?it/s]

2023-12-15 12:06:07.887 | INFO     | __main__:_process_sequence:39 - Empty sample for text "BG/BG1/0.txt"
2023-12-15 12:06:07.887 | INFO     | __main__:_process_sequence:40 - ocr_str: граната, казала: „Това е Божия воля нашия животъ е въ Божиите ржце, но азъ съмъ благодарна на Спасителя, че този ми сннъ, тъй сжщо както и другите ми четире синове, можа да бжде нолезенъ за
2023-12-15 12:06:07.888 | INFO     | __main__:_process_sequence:41 - gs_str: 
2023-12-15 12:06:07.889 | INFO     | __main__:_process_sequence:42 - start token: 35
2023-12-15 12:06:07.889 | INFO     | __main__:_process_sequence:39 - Empty sample for text "BG/BG1/0.txt"
2023-12-15 12:06:07.890 | INFO     | __main__:_process_sequence:40 - ocr_str: по- следнпятъ (четвъртини. деяь) делегатката на Соф. Д-ство Майка отсжтствува. За причина на това отсжтвие ний знаемъ, че Г-жа Каравелова предпо лагайки, че конгрессътъ ще да трае три дни, бе отрано решила да замине за
2023-12-15 12:06:07.890 | INFO     | __main__:_process_seque

In [None]:
# Inspect the generated training windows (pandas truncates the long text columns).
train_data

Unnamed: 0,key,start_token_id,score,ocr,gs,language
0,DE/DE4/204.txt,0,0.100000,annis ad nauigandum non—erant ufi. Has reficie...,annis ad nauigandum non erant uſi. Has reficie...,DE
1,DE/DE4/204.txt,35,0.100346,nomen in fcutis Non dubitanter Alexandrini cla...,nomen in ſcutisNon dubitanter Alexandrini claſ...,DE
2,DE/DE4/204.txt,70,0.151751,"uideretut,fimul ut contra fimis incédiis & rui...","uideretur, ſimul ut contraſimis incẽdiis & rui...",DE
3,DE/DE4/204.txt,105,0.137500,dextra ab infula:qug diuerfx nauig3tiones nunq...,dextra ab inſula: quę diuerſæ nauigationes nun...,DE
4,DE/DE4/204.txt,140,0.113122,"calamum obtuflorem, B 3 los qui domi quzratnum...","calamum obtuſiorem,B 3 losqui domi quærat num ...",DE
...,...,...,...,...,...,...
85762,DE/DE3/3649.txt,35,0.208661,ffarb jene 'Bemüßbungen nicht vorbergegangen w...,ſtarbjene Bemuͤhungen nicht vorhergegangen waͤ...,DE
85763,DE/DE3/3649.txt,70,0.203320,bier der größte Berlag von eloffifhen Hu— Öemi...,hier der groͤßte Verlag von claſſiſchen Au Gew...,DE
85764,DE/DE3/3649.txt,105,0.140625,nicht in Bayern Cotta's befeitigt woerden fönn...,nicht in Bayern Cotta'sbeſeitigt werden koͤnne...,DE
85765,DE/DE3/3649.txt,140,0.200820,"&Cprache und Siterati . den, daß, wenn uberbau...","Sprache und Literatur.den, daß, wenn uͤberhaup...",DE


In [None]:
# A fixed random_state makes the displayed rows the same on every run.
train_data.sample(5, random_state=42)

Unnamed: 0,key,start_token_id,score,ocr,gs,language
47227,FR/FR3/815.txt,0,0.045977,MERCI et à très BIENTOT Unité U - Balance: O -...,MERCI et à très BIENTOT Unité: 0 – Balance: O ...,FR
14173,DE/DE5/115.txt,105,0.314516,attendens 4d agat iuidef. Art.i:G met?fübiecif...,attendens qͥd agatĩuidet᷑. At.i:ſʒ metꝰ ſubiec...,DE
34358,DE/DE6/206.txt,140,0.35514,Böle fes wird refolisteren» ° Diafcorides fpri...,boͤſe feuch⸗vnd reſoluieren · ¶ Diaſcorides ſp...,DE
12720,BG/BG1/43.txt,315,0.098837,"поиитахъ ж азъ. — За да сн намЬрж имание, отго...","попитахъ ѭ азъ. За да си намѣрѭ имание, отго...",BG
21629,FR/FR1/334.txt,525,0.014563,et successeurs perpetuelment toute juridicion ...,et successeurs perpetuelment toute juridicion ...,FR


In [None]:
def add_len(text):
    """Return the number of characters in ``text``."""
    return len(text)


# Character lengths of the OCR input and of the gold standard for every window.
# ``Series.str.len()`` is the vectorized equivalent of ``.apply(add_len)``.
for df in (train_data, val_data, test_data):
    df["len_ocr"] = df["ocr"].str.len()
    df["len_gs"] = df["gs"].str.len()

In [None]:
# Summary statistics of the window lengths (in characters), with the columns
# grouped per split so each number can be attributed to train/val/test.
length_stats = pd.concat(
    {
        split_name: split_df[["len_ocr", "len_gs"]].describe()
        for split_name, split_df in (
            ("train", train_data),
            ("val", val_data),
            ("test", test_data),
        )
    },
    axis=1,
)
length_stats

ocr: count    85767.000000
mean       220.624389
std         30.413322
min         10.000000
25%        201.000000
50%        217.000000
75%        236.000000
max        464.000000
Name: len_ocr, dtype: float64
gs: count    85767.000000
mean       216.818100
std         48.505311
min          1.000000
25%        197.000000
50%        212.000000
75%        232.000000
max       5593.000000
Name: len_gs, dtype: float64
ocr: count    25258.000000
mean       219.577955
std         30.525005
min         23.000000
25%        200.000000
50%        216.000000
75%        234.000000
max        510.000000
Name: len_ocr, dtype: float64
gs: count    25258.000000
mean       216.189207
std         60.897175
min          1.000000
25%        196.000000
50%        211.000000
75%        231.000000
max       3998.000000
Name: len_gs, dtype: float64
ocr: count    9474.000000
mean      221.269263
std        30.127878
min        99.000000
25%       202.000000
50%       217.000000
75%       237.000000
max     

In [None]:
# logger.info(f"Filtering train and val based on maximum edit distance of {max_ed}")
# train_data = train_data[train_data.score < max_ed]
# val_data = val_data[val_data.score < max_ed]

In [None]:
# The alignment score is not used for training; drop it before converting the
# frames to Hugging Face datasets. The drop is in place, so errors="ignore"
# keeps the cell re-runnable (a second run would otherwise raise KeyError).
for df in (train_data, val_data, test_data):
    df.drop(columns=["score"], inplace=True, errors="ignore")

dataset = DatasetDict(
    {
        "train": Dataset.from_pandas(train_data),
        "val": Dataset.from_pandas(val_data),
        "test": Dataset.from_pandas(test_data),
    }
)

In [None]:
# Overview of the splits, their columns and their sizes.
dataset

DatasetDict({
    train: Dataset({
        features: ['key', 'start_token_id', 'ocr', 'gs', 'language', 'len_ocr', 'len_gs'],
        num_rows: 85767
    })
    val: Dataset({
        features: ['key', 'start_token_id', 'ocr', 'gs', 'language', 'len_ocr', 'len_gs'],
        num_rows: 9474
    })
    test: Dataset({
        features: ['key', 'start_token_id', 'ocr', 'gs', 'language', 'len_ocr', 'len_gs'],
        num_rows: 25258
    })
})

In [None]:
from transformers import AutoTokenizer

# Multilingual T5 checkpoint. The texts include e.g. German, French, Dutch and
# Bulgarian (see the `language` column), presumably why mT5 rather than T5 is used.
checkpoint = "google/mt5-small"
tokenizer = AutoTokenizer.from_pretrained(checkpoint)



In [None]:
def preprocess_function(examples, prefix="", max_length=500):
    """Tokenize a batch of (OCR, gold standard) pairs for seq2seq training.

    Args:
        examples: a batch from ``Dataset.map(..., batched=True)``; the
            ``ocr`` column is the model input and ``gs`` the target.
        prefix: optional task prefix prepended to every input (e.g.
            ``"correct: "``). The default adds no prefix.
        max_length: maximum number of tokens kept for inputs and labels;
            longer sequences are truncated.

    Returns:
        The tokenizer output for the inputs, with an added ``labels`` entry
        holding the token ids of the gold-standard texts.
    """
    inputs = [prefix + doc for doc in examples["ocr"]]
    model_inputs = tokenizer(inputs, max_length=max_length, truncation=True)

    labels = tokenizer(text_target=examples["gs"], max_length=max_length, truncation=True)

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
# Tokenize every split in batches; adds the tokenizer outputs and `labels` next to the existing columns.
tokenized_dataset = dataset.map(preprocess_function, batched=True)

Map:   0%|          | 0/85767 [00:00<?, ? examples/s]

Map:   0%|          | 0/9474 [00:00<?, ? examples/s]

Map:   0%|          | 0/25258 [00:00<?, ? examples/s]

In [None]:
from transformers import DataCollatorForSeq2Seq

# Pads inputs and labels per batch (labels are padded with the collator's
# default label_pad_token_id, -100, which the loss ignores).
# `model` is deliberately not passed. The collator only uses it to build
# decoder_input_ids via `model.prepare_decoder_input_ids_from_labels`, and the
# model object is created further down. Passing the checkpoint *name* (a
# string) has no such method and would silently do nothing. The seq2seq model
# derives its decoder inputs from `labels` itself.
data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer)

In [None]:
from evaluate import load
# Character error rate metric (also used as a no-correction baseline below).
cer = load("cer")

In [None]:
# Baseline: character error rate of the uncorrected OCR text against the gold
# standard on the validation split, i.e. the score of doing no correction.
cer_score = cer.compute(
    references=val_data["gs"].tolist(),
    predictions=val_data["ocr"].tolist(),
)
print(cer_score)

0.21237264855586288


In [None]:
# Word error rate metric.
wer = load("wer")

Downloading builder script:   0%|          | 0.00/4.49k [00:00<?, ?B/s]

In [None]:
# Baseline: word error rate of the uncorrected OCR text against the gold
# standard on the validation split, i.e. the score of doing no correction.
wer_score = wer.compute(
    references=val_data["gs"].tolist(),
    predictions=val_data["ocr"].tolist(),
)
print(wer_score)

0.635085134039459


In [None]:
from transformers import AutoModelForSeq2SeqLM, Seq2SeqTrainingArguments, Seq2SeqTrainer

# Pretrained seq2seq model for the same checkpoint as the tokenizer.
model = AutoModelForSeq2SeqLM.from_pretrained(checkpoint)

Downloading pytorch_model.bin:   0%|          | 0.00/1.20G [00:00<?, ?B/s]

Downloading generation_config.json:   0%|          | 0.00/147 [00:00<?, ?B/s]

In [None]:
training_args = Seq2SeqTrainingArguments(
    # NOTE(review): directory name is a leftover from the Hugging Face
    # summarization tutorial; renaming it changes where checkpoints are written.
    output_dir="my_awesome_billsum_model",
    evaluation_strategy="epoch",
    learning_rate=2e-5,
    per_device_train_batch_size=16,
    per_device_eval_batch_size=16,
    weight_decay=0.01,
    save_total_limit=3,
    num_train_epochs=4,
    predict_with_generate=True,
    # NOTE(review): (m)T5 checkpoints are widely reported to produce inf/NaN
    # losses in float16; if the loss becomes NaN, set fp16=False (or use bf16
    # on GPUs that support it).
    fp16=True,
    push_to_hub=False,
    # Boolean flag: restore the best checkpoint (by eval loss unless
    # metric_for_best_model is set) after training. Requires save_strategy to
    # match evaluation_strategy, which both are ("epoch").
    load_best_model_at_end=True,
    save_strategy="epoch",
)

trainer = Seq2SeqTrainer(
    model=model,
    args=training_args,
    train_dataset=tokenized_dataset["train"],
    # Model selection uses the validation split; the test split must stay
    # unseen until the final evaluation.
    eval_dataset=tokenized_dataset["val"],
    tokenizer=tokenizer,
    data_collator=data_collator,
    # compute_metrics=compute_metrics,
)

In [None]:
# Fine-tune; evaluation and checkpointing happen every epoch (see training_args).
trainer.train()

In [None]:
# First few training windows.
train_data.head()

Unnamed: 0,key,start_token_id,ocr,gs,language
0,DE/DE4/204.txt,0,annis ad nauigandum non—erant ufi. Has reficie...,annis ad nauigandum non erant uſi. Has reficie...,DE
1,DE/DE4/204.txt,30,"milites aded fatebantur, ut Cn.Pompeii nomen i...","milites adeò fatebantur, ut Cn. Pompeii nomen ...",DE
2,DE/DE4/204.txt,60,"uidebatur,perpaucos de fum.— fionis initium na...","uidebatur, perpaucos de ſum-ſionis initium nat...",DE
3,DE/DE4/204.txt,90,"in fuga {pem falu— Pöpeii quiequam profe&um,qu...","in fuga ſpem ſalu-Põpeii quicquam profectum, q...",DE
4,DE/DE4/204.txt,120,"Häc de— |== 297 Accipe,vide an placeat. Qnare ...","Hãc de-27Accipe, vide an placeat. Quare ſic eu...",DE


In [None]:
# Gold-standard text of the window with index label 0.
train_data.loc[0].gs

'annis ad nauigandum non erant uſi. Has reficiebant, illas Alexandriam re-ſtium tenebatur. Neque eum cõſilium ſuum fefellit, quin hoſtes eo prælioipſis: neq; illis imminẽtibus atq; inſequẽtibus ullus in naues receptus dare-ſe Torius ferebat: milites adeò fatebantur, ut Cn. Pompeii'

In [None]:
# Gold-standard text of the window with index label 1 (the next window of the same text in the output above).
train_data.loc[1].gs

'milites adeò fatebantur, ut Cn. Pompeii nomen in ſcutisNon dubitanter Alexandrini claſſem producunt, atque inſtruunt. In fronteauxilia, maioraq́; miſſurus exiſtimabatur. Quibus literis acceptis, inſolentipericlitarentur. Simul illud graue ac miſerum uidebatur, perpaucos de ſum-ſionis initium'

In [None]:
# Edit distance between the gold standards of windows 0 and 1, using edlib's "SHW" (prefix) mode, to inspect how they overlap.
edlib.align(train_data.loc[0].gs, train_data.loc[1].gs, mode = "SHW")

{'editDistance': 205,
 'alphabetLength': 41,
 'locations': [(None, 227)],
 'cigar': None}

In [None]:
# Characters 222-231 of window 0's gold standard.
train_data.loc[0].gs[222: 232]

'ſe Torius '

In [None]:
# getNiceAlignment(alignResult, query, target) needs the alignment path, so
# align with task="path" first (same "SHW" mode as the alignment above).
query = train_data.loc[0].gs
target = train_data.loc[1].gs
alignment = edlib.align(query, target, mode="SHW", task="path")
edlib.getNiceAlignment(alignment, query, target)

TypeError: getNiceAlignment() takes at least 3 positional arguments (2 given)