In [None]:
# |default_exp error_correction_t5

In [None]:
# | export
from typing import Dict

import pandas as pd

In [None]:
import os
from pathlib import Path

from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer

from ocrpostcorrection.error_correction import get_tokens_with_OCR_mistakes, get_context_for_dataset
from ocrpostcorrection.icdar_data import generate_data

In [None]:
# | hide

data_dir = Path.cwd() / "data" / "dataset_training_sample"

data, md = generate_data(data_dir)

val_files = ["fr/fr_sample/2.txt"]

# Context size passed to `get_context_for_dataset` (the outputs below suggest it is
# measured in characters around the mistake — TODO confirm).
CONTEXT_OFFSET = 20
DEDUP_COLUMNS = ["ocr", "gs", "dataset", "language", "context_before", "context_after"]

# The sample data is used as both the training and the test input here.
tdata = get_tokens_with_OCR_mistakes(data, data, val_files)
tdata = get_context_for_dataset(data, tdata, CONTEXT_OFFSET)
# Chained (not inplace) so the cell produces the same frame every time it is re-run.
tdata = tdata.drop_duplicates(subset=DEDUP_COLUMNS).reset_index(drop=True)
assert not tdata.duplicated(subset=DEDUP_COLUMNS).any()

print(tdata.shape)
tdata.head()

2it [00:00, 1048.31it/s]
100%|██████████| 4/4 [00:00<00:00, 865.97it/s]

(66, 15)





Unnamed: 0,ocr,gs,ocr_aligned,gs_aligned,start,len_ocr,key,language,subset,dataset,len_gs,diff,context_before,context_after,len_mistake_in_context
0,In,,In,##,0,2,en/eng_sample/1.txt,en,eng_sample,test,0,2,,"botany, a troe is a",22
1,troe,tree,troe,tree,13,4,en/eng_sample/1.txt,en,eng_sample,test,4,0,"In botany, a",is a peremial plant,37
2,peremial,perennial,perem@ial,perennial,23,8,en/eng_sample/1.txt,en,eng_sample,test,9,-1,"botany, a troe is a",plant with an eLngated,51
3,eLngated,elongated,eL@ngated,elongated,46,8,en/eng_sample/1.txt,en,eng_sample,test,9,-1,peremial plant with an,"stein, or trunk,",48
4,"stein,","stem,","stein,","stem@,",55,6,en/eng_sample/1.txt,en,eng_sample,test,5,1,plant with an eLngated,"or trunk, suppor ing",50


In [None]:
# One Hugging Face split per value of the `dataset` column, in train/val/test order.
dataset = DatasetDict(
    {
        split: Dataset.from_pandas(tdata.query(f'dataset == "{split}"'))
        for split in ("train", "val", "test")
    }
)
dataset["train"][1]

{'ocr': 'troe',
 'gs': 'tree',
 'ocr_aligned': 'troe',
 'gs_aligned': 'tree',
 'start': 13,
 'len_ocr': 4,
 'key': 'en/eng_sample/1.txt',
 'language': 'en',
 'subset': 'eng_sample',
 'dataset': 'train',
 'len_gs': 4,
 'diff': 0,
 'context_before': 'In botany, a ',
 'context_after': ' is a peremial plant',
 'len_mistake_in_context': 37,
 '__index_level_0__': 14}

In [None]:
# | export

def filter_max_len(example: Dict, max_len: int):
    """Tell whether an example's OCR and gold-standard lengths both fit in ``max_len``.

    Args:
        example: A single example holding integer ``"len_ocr"`` and ``"len_gs"`` fields.
        max_len: Inclusive upper bound applied to both lengths.

    Returns:
        ``True`` if ``len_ocr <= max_len`` and ``len_gs <= max_len``, else ``False``.
        ``len_gs`` is only read when ``len_ocr`` already fits.
    """
    return all(example[field] <= max_len for field in ("len_ocr", "len_gs"))

In [None]:
# | export

def filter_len_ocr_mistake_in_context(
    data: pd.DataFrame, context_offset: int, max_len_factor: int = 10
) -> pd.DataFrame:
    """Keep rows whose mistake-in-context length is at most ``context_offset * max_len_factor``.

    Args:
        data: Frame with a ``len_mistake_in_context`` column (the name produced by
            ``get_context_for_dataset`` in this notebook). The older
            ``len_ocr_mistake_in_context`` name is accepted as a fallback.
        context_offset: Context size the frame was built with.
        max_len_factor: Multiplier applied to ``context_offset`` to get the
            inclusive length limit (default 10, as before).

    Returns:
        A filtered copy of ``data``; the original index is preserved.

    Raises:
        KeyError: If neither length column is present.
    """
    # Frames built by get_context_for_dataset name this column
    # `len_mistake_in_context`; querying `len_ocr_mistake_in_context` on them fails.
    for column in ("len_mistake_in_context", "len_ocr_mistake_in_context"):
        if column in data.columns:
            break
    else:
        raise KeyError(
            "expected a 'len_mistake_in_context' (or 'len_ocr_mistake_in_context') column"
        )

    max_len = context_offset * max_len_factor
    return data[data[column] <= max_len].copy()

In [None]:
# Pretrained checkpoint whose tokenizer is used for all preprocessing below.
model_name = "google/byt5-small"
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

In [None]:
# | export

def preprocess_function(examples: Dict, tokenizer, add_task_prefix: bool = False, context_marker: str = ""):
    """Tokenize OCR inputs and gold-standard targets for a seq2seq model.

    Works with ``Dataset.map`` both when ``batched=True`` (every field is a list)
    and when ``batched=False`` (every field is a single string).

    Args:
        examples: Mapping with ``"ocr"`` and ``"gs"``. ``"context_before"`` and
            ``"context_after"`` are read when ``context_marker`` is set;
            ``"language"`` is read when ``add_task_prefix`` is true.
        tokenizer: Tokenizer called as ``tokenizer(texts)`` for the inputs and as
            ``tokenizer(text_target=...)`` for the labels.
        add_task_prefix: Prefix each input with ``"<language>: "``.
        context_marker: If non-empty, place the OCR string inside its context,
            wrapped in ``<marker>...</marker>`` tags.

    Returns:
        The tokenizer output for the inputs, with ``"labels"`` set to the
        ``input_ids`` of the tokenized gold-standard strings.
    """
    # A single (non-batched) example holds plain strings; wrap them so zip() pairs
    # whole strings rather than characters, then unwrap before tokenizing.
    single = isinstance(examples["ocr"], str)

    def as_list(field):
        return [examples[field]] if single else list(examples[field])

    if context_marker:
        texts = [
            f"{before}<{context_marker}>{ocr_str}</{context_marker}>{after}"
            for before, ocr_str, after in zip(
                as_list("context_before"), as_list("ocr"), as_list("context_after")
            )
        ]
    else:
        texts = as_list("ocr")

    if add_task_prefix:
        texts = [f"{language}: {text}" for text, language in zip(texts, as_list("language"))]

    model_inputs = tokenizer(texts[0] if single else texts)

    labels = tokenizer(text_target=examples["gs"])

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
# Plain OCR strings as model input, no prefix and no context.
tokenized_dataset = dataset.map(
    preprocess_function,
    batched=True,
    fn_kwargs=dict(tokenizer=tokenizer),
)

Map:   0%|          | 0/23 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/33 [00:00<?, ? examples/s]

In [None]:
# Inspect the model input of the second training example.
train_input_ids = tokenized_dataset["train"][1]["input_ids"]
tokenizer.decode(train_input_ids)

'troe</s>'

In [None]:
# Same as above, with each input prefixed by its language code.
tokenized_dataset = dataset.map(
    preprocess_function,
    batched=True,
    fn_kwargs=dict(tokenizer=tokenizer, add_task_prefix=True),
)

Map:   0%|          | 0/23 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/33 [00:00<?, ? examples/s]

In [None]:
# Inspect the prefixed model input of the second training example.
train_input_ids = tokenized_dataset["train"][1]["input_ids"]
tokenizer.decode(train_input_ids)

'en: troe</s>'

In [None]:
# Language prefix plus the surrounding context, with the OCR string wrapped in <mistake> tags.
tokenized_dataset = dataset.map(
    preprocess_function,
    batched=True,
    fn_kwargs=dict(tokenizer=tokenizer, add_task_prefix=True, context_marker="mistake"),
)

Map:   0%|          | 0/23 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/33 [00:00<?, ? examples/s]

In [None]:
# Inspect the prefixed, context-marked model input of the second training example.
train_input_ids = tokenized_dataset["train"][1]["input_ids"]
tokenizer.decode(train_input_ids)

'en: In botany, a <mistake>troe</mistake> is a peremial plant</s>'

In [None]:
# | hide
from nbdev import nbdev_export

# Regenerate the library module from this notebook's `# | export` cells.
nbdev_export()