In [None]:
# |default_exp error_correction_t5

In [None]:
# | export
from typing import Dict

In [None]:
import os
from pathlib import Path

from datasets import Dataset, DatasetDict
from transformers import AutoTokenizer

from ocrpostcorrection.error_correction import get_tokens_with_OCR_mistakes
from ocrpostcorrection.icdar_data import generate_data

In [None]:
# | hide

data_dir = Path(os.getcwd()) / "data" / "dataset_training_sample"

data, md = generate_data(data_dir)

# Files listed here are assigned to the "val" split by get_tokens_with_OCR_mistakes.
val_files = ['fr/fr_sample/2.txt']

# The same sample `data` is passed as both data arguments (presumably train and
# test input — TODO confirm against get_tokens_with_OCR_mistakes). Chaining
# instead of `inplace=True` keeps this cell idempotent and readable top-down.
tdata = (
    get_tokens_with_OCR_mistakes(data, data, val_files)
    .drop_duplicates(subset=["ocr", "gs", "dataset"])
    .reset_index(drop=True)
)

print(tdata.shape)
tdata.head()

2it [00:00, 1520.78it/s]

(61, 12)





Unnamed: 0,ocr,gs,ocr_aligned,gs_aligned,start,len_ocr,key,language,subset,dataset,len_gs,diff
0,In,,In,##,0,2,en/eng_sample/1.txt,en,eng_sample,test,0,2
1,troe,tree,troe,tree,13,4,en/eng_sample/1.txt,en,eng_sample,test,4,0
2,peremial,perennial,perem@ial,perennial,23,8,en/eng_sample/1.txt,en,eng_sample,test,9,-1
3,eLngated,elongated,eL@ngated,elongated,46,8,en/eng_sample/1.txt,en,eng_sample,test,9,-1
4,"stein,","stem,","stein,","stem@,",55,6,en/eng_sample/1.txt,en,eng_sample,test,5,1


In [None]:
# Build one Hugging Face Dataset per split, selected via the `dataset` column of `tdata`.
# Note: Dataset.from_pandas keeps the filtered pandas index as an extra
# `__index_level_0__` column (visible in the example output below).
dataset = DatasetDict(
    {
        split: Dataset.from_pandas(tdata.query(f'dataset == "{split}"'))
        for split in ("train", "val", "test")
    }
)
dataset['train'][1]

{'ocr': 'troe',
 'gs': 'tree',
 'ocr_aligned': 'troe',
 'gs_aligned': 'tree',
 'start': 13,
 'len_ocr': 4,
 'key': 'en/eng_sample/1.txt',
 'language': 'en',
 'subset': 'eng_sample',
 'dataset': 'train',
 'len_gs': 4,
 'diff': 0,
 '__index_level_0__': 31}

In [None]:
# | export

def filter_max_len(example: Dict, max_len: int) -> bool:
    """Check whether both the OCR and gold-standard strings fit within `max_len`.

    Presumably used as a predicate for `Dataset.filter` — TODO confirm at the
    call site.

    Args:
        example: A single record with integer `len_ocr` and `len_gs` fields.
        max_len: Inclusive upper bound applied to both lengths.

    Returns:
        True when `len_ocr <= max_len` and `len_gs <= max_len`, otherwise False.
    """
    return example["len_ocr"] <= max_len and example["len_gs"] <= max_len

In [None]:
model_name: str = "google/byt5-small"

# Load the tokenizer that matches the checkpoint named above.
tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_name)

In [None]:
# | export

def preprocess_function(examples: Dict, tokenizer, add_task_prefix: bool = False):
    """Tokenize a batch: OCR strings become model inputs, gold-standard strings become labels.

    Args:
        examples: A batch mapping column name -> list of values; must contain
            `ocr` and `gs`, plus `language` when `add_task_prefix` is True.
        tokenizer: A Hugging Face tokenizer (called on the inputs, and with
            `text_target=` on the labels).
        add_task_prefix: If True, prepend `"<language>: "` to every OCR string.

    Returns:
        The tokenizer's encoding of the (optionally prefixed) OCR strings, with
        an added `labels` entry holding the `input_ids` of the tokenized
        gold-standard strings.
    """
    # Named `ocr_texts` rather than `input` to avoid shadowing the builtin.
    ocr_texts = examples["ocr"]
    if add_task_prefix:
        # e.g. "troe" with language "en" -> "en: troe"
        ocr_texts = [
            f"{language}: {ocr_str}"
            for ocr_str, language in zip(examples["ocr"], examples["language"])
        ]

    model_inputs = tokenizer(ocr_texts)

    labels = tokenizer(text_target=examples["gs"])

    model_inputs["labels"] = labels["input_ids"]
    return model_inputs

In [None]:
# Tokenize every split; with batched=True, preprocess_function receives whole batches.
tokenized_dataset = dataset.map(
    function=preprocess_function,
    batched=True,
    fn_kwargs=dict(tokenizer=tokenizer),
)

Map:   0%|          | 0/21 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/30 [00:00<?, ? examples/s]

In [None]:
# Round-trip one training input back to text to inspect what the model sees.
train_example_ids = tokenized_dataset["train"][1]["input_ids"]
tokenizer.decode(train_example_ids)

'troe</s>'

In [None]:
# Same tokenization, but each OCR input is prefixed with its language code.
# Note: this overwrites the unprefixed `tokenized_dataset` from the earlier cell.
tokenized_dataset = dataset.map(
    function=preprocess_function,
    batched=True,
    fn_kwargs=dict(tokenizer=tokenizer, add_task_prefix=True),
)

Map:   0%|          | 0/21 [00:00<?, ? examples/s]

Map:   0%|          | 0/10 [00:00<?, ? examples/s]

Map:   0%|          | 0/30 [00:00<?, ? examples/s]

In [None]:
# The decoded input should now start with the language prefix.
prefixed_example_ids = tokenized_dataset["train"][1]["input_ids"]
tokenizer.decode(prefixed_example_ids)

'en: troe</s>'

In [None]:
# | hide
from nbdev import nbdev_export

# Export this notebook's `# | export` cells to its `default_exp` module.
nbdev_export()