In [None]:
#|default_exp icdar_data

In [None]:
#| export
from dataclasses import dataclass
from pathlib import Path

from loguru import logger

import edlib
import pandas as pd
from tqdm import tqdm

In [None]:
#| hide
import os

from nbdev.showdoc import *


The [ICDAR 2019 Competition on Post-OCR Text Correction 
dataset](https://sites.google.com/view/icdar2019-postcorrectionocr/dataset) 
([zenodo record](https://zenodo.org/record/3515403#.YwULoi0RoWI))
contains text files in the following format:

```
[OCR_toInput] This is a cxample...
[OCR_aligned] This is a@ cxample...
[ GS_aligned] This is an example.##
01234567890123
```

The first line contains the ocr input text. The second line contains the aligned 
ocr and the third line contains the aligned gold standard.
`@` is the alignment character and `#` represents characters in the OCR that do 
not occur in the gold standard.

For working with this data, the first 14 characters have to be removed. We also 
remove leading and trailing whitespace.

In [None]:
#| export
def remove_label_and_nl(line: str):
    """Return the text of a data line without its label.

    Surrounding whitespace (including the trailing newline) is stripped
    first; after that the first 14 characters, i.e. a label such as
    ``[OCR_toInput] `` together with its separating space, are dropped.
    """
    label_width = 14
    trimmed = line.strip()
    return trimmed[label_width:]

In [None]:
#| hide
# Both the 14-character label and the trailing newline must be gone
assert remove_label_and_nl('[OCR_toInput] This is a cxample...\n') == 'This is a cxample...'

## Tokenization

Task 1 of the competition is about finding tokens with OCR mistakes. In this context, a token
refers to a string between two whitespace characters.

In [None]:
#| export
@dataclass
class AlignedToken:
    """Dataclass for storing aligned tokens

    An AlignedToken pairs a stretch of OCR text with the corresponding
    stretch of gold standard (GS) text.
    """
    ocr: str  # String in the OCR text (alignment characters removed, whitespace stripped)
    gs: str  # String in the gold standard (without the `@` and `#` characters)
    ocr_aligned: str  # String in the aligned OCR text (alignment characters are kept)
    gs_aligned: str  # String in the aligned GS text (alignment characters are kept)
    start: int  # The index of the first character in the OCR text
    len_ocr: int  # The length of the OCR string (measured before whitespace is stripped)

In [None]:
#| export
def tokenize_aligned(ocr_aligned: str, gs_aligned: str):
    """Split a pair of aligned OCR/GS strings into a list of AlignedTokens.

    Both strings are walked in lockstep. A position where the OCR and the GS
    both contain a space closes the current token. Tokens whose OCR part is
    empty after stripping whitespace are discarded.
    """

    def build_token(ocr_chars, gs_chars, ocr_chars_aligned, gs_chars_aligned,
                    first_index):
        """Create an AlignedToken from the collected characters (or None)."""
        raw_ocr = ''.join(ocr_chars)
        stripped_ocr = raw_ocr.strip()
        # Ignore 'tokens' without representation in the ocr text
        # (these tokens do not consist of characters)
        if stripped_ocr == '':
            return None
        return AlignedToken(stripped_ocr,
                            ''.join(gs_chars),
                            ''.join(ocr_chars_aligned),
                            ''.join(gs_chars_aligned),
                            first_index,
                            len(raw_ocr))

    found = []

    consumed = 0      # Number of OCR characters seen so far ('@' not counted)
    token_start = 0   # Value of `consumed` at the beginning of the current token

    ocr_part, gs_part = [], []
    ocr_part_aligned, gs_part_aligned = [], []

    for ocr_char, gs_char in zip(ocr_aligned, gs_aligned):
        # The # character in ocr is not an alignment character, only @ is
        is_alignment_char = ocr_char == '@'
        if not is_alignment_char:
            consumed += 1

        if ocr_char == ' ' and gs_char == ' ':
            # Token boundary: store what was collected and start afresh
            token = build_token(ocr_part, gs_part, ocr_part_aligned,
                                gs_part_aligned, token_start)
            if token is not None:
                found.append(token)

            token_start = consumed
            ocr_part, gs_part = [], []
            ocr_part_aligned, gs_part_aligned = [], []
            continue

        ocr_part_aligned.append(ocr_char)
        gs_part_aligned.append(gs_char)
        if not is_alignment_char:
            ocr_part.append(ocr_char)
        if gs_char not in ('@', '#'):
            gs_part.append(gs_char)

    # Final token (if there is one)
    token = build_token(ocr_part, gs_part, ocr_part_aligned, gs_part_aligned,
                        token_start)
    if token is not None:
        found.append(token)

    return found

In [None]:
# Tokenize the aligned OCR and GS strings of the example text
tokenize_aligned('This is a@ cxample...', 'This is an example.##')

[AlignedToken(ocr='This', gs='This', ocr_aligned='This', gs_aligned='This', start=0, len_ocr=4),
 AlignedToken(ocr='is', gs='is', ocr_aligned='is', gs_aligned='is', start=5, len_ocr=2),
 AlignedToken(ocr='a', gs='an', ocr_aligned='a@', gs_aligned='an', start=8, len_ocr=1),
 AlignedToken(ocr='cxample...', gs='example.', ocr_aligned='cxample...', gs_aligned='example.##', start=10, len_ocr=10)]

The OCR text of an AlignedToken may still consist of multiple tokens. This is the 
case when the OCR text contains one or more spaces. To make sure the (sub)tokenization
of a token is always the same, regardless of whether it has already been tokenized 
completely, another round of tokenization is added.

In [None]:
#| export
@dataclass
class InputToken:
    """Dataclass for the tokenization within AlignedTokens

    InputTokens are produced by `get_input_tokens`, which splits the OCR
    string of an AlignedToken on spaces.
    """
    ocr: str  # OCR string of the (sub)token
    gs: str  # Gold standard string; empty for subtokens with label 2
    start: int  # Start index of the (sub)token, derived from AlignedToken.start
    len_ocr: int  # Length of the `ocr` string
    label: int  # 0: OCR equals GS; 1: first subtoken of a token where OCR != GS; 2: any further subtoken

In [None]:
#| export
def get_input_tokens(aligned_token: AlignedToken):
    """Tokenize an AlignedToken into subtokens and assign task 1 labels

    Yields one InputToken if the OCR string equals the gold standard
    (label 0). Otherwise the OCR string is split on single spaces: the first
    part gets label 1 and carries the complete gold standard string, every
    further part gets label 2 and an empty gold standard string.

    The `start` of every yielded InputToken is the index of its first
    character in the OCR text.
    """
    # `tokenize_aligned` strips whitespace from `ocr`, but `start` still
    # refers to the first unstripped character. Locate `ocr` in the aligned
    # OCR string (without '@') to find out how many leading characters were
    # stripped; fall back to 0 if the two strings do not match up.
    raw_ocr = aligned_token.ocr_aligned.replace('@', '')
    offset = max(raw_ocr.find(aligned_token.ocr), 0)
    start = aligned_token.start + offset

    if aligned_token.ocr == aligned_token.gs:
        yield InputToken(aligned_token.ocr, aligned_token.gs,
                         start, len(aligned_token.ocr), 0)
        return

    new_start = start
    for i, part in enumerate(aligned_token.ocr.split(' ')):
        if i == 0:
            # The first subtoken represents the complete gold standard
            yield InputToken(part, aligned_token.gs, new_start, len(part), 1)
        else:
            yield InputToken(part, '', new_start, len(part), 2)
        # Skip the subtoken and the space that separated it from the next one
        new_start += len(part) + 1

In [None]:
# OCR and GS are identical, so a single InputToken with label 0 is produced
t = AlignedToken('Major', 'Major', 'Major', 'Major', 19, 5)
print(t)

for inp_tok in get_input_tokens(t):
    print(inp_tok)

AlignedToken(ocr='Major', gs='Major', ocr_aligned='Major', gs_aligned='Major', start=19, len_ocr=5)
InputToken(ocr='Major', gs='Major', start=19, len_ocr=5, label=0)


In [None]:
#| hide
# Collect what get_input_tokens yields for `t` and compare with the expected values
tokens = []
labels = []
gs = []

for inp_tok in get_input_tokens(t):
    tokens.append(inp_tok.ocr)
    labels.append(inp_tok.label)
    gs.append(inp_tok.gs)
    
assert tokens == ['Major']
assert labels == [0]
# Joining the GS of all subtokens must give back the GS of the AlignedToken
assert ''.join(gs) == t.gs

In [None]:
# The OCR string contains no space but differs from the GS: one InputToken with label 1
t = AlignedToken('INEVR', 'I NEVER', 'I@NEV@R', 'I NEVER', 0, 5)
print(t)

for inp_tok in get_input_tokens(t):
    print(inp_tok)

AlignedToken(ocr='INEVR', gs='I NEVER', ocr_aligned='I@NEV@R', gs_aligned='I NEVER', start=0, len_ocr=5)
InputToken(ocr='INEVR', gs='I NEVER', start=0, len_ocr=5, label=1)


In [None]:
#| hide
# Collect what get_input_tokens yields for `t` and compare with the expected values
tokens = []
labels = []
gs = []

for inp_tok in get_input_tokens(t):
    tokens.append(inp_tok.ocr)
    labels.append(inp_tok.label)
    gs.append(inp_tok.gs)
    
assert tokens == ['INEVR']
assert labels == [1]
# Joining the GS of all subtokens must give back the GS of the AlignedToken
assert ''.join(gs) == t.gs

In [None]:
# The OCR string contains a space: it is split into two InputTokens (labels 1 and 2)
t = AlignedToken('Long ow.', 'Longhow.', 'Long ow.', 'Longhow.', 24, 8)
print(t)

for inp_tok in get_input_tokens(t):
    print(inp_tok)

AlignedToken(ocr='Long ow.', gs='Longhow.', ocr_aligned='Long ow.', gs_aligned='Longhow.', start=24, len_ocr=8)
InputToken(ocr='Long', gs='Longhow.', start=24, len_ocr=4, label=1)
InputToken(ocr='ow.', gs='', start=29, len_ocr=3, label=2)


In [None]:
#| hide
# Collect what get_input_tokens yields for `t` and compare with the expected values
tokens = []
labels = []
gs = []

for inp_tok in get_input_tokens(t):
    tokens.append(inp_tok.ocr)
    labels.append(inp_tok.label)
    gs.append(inp_tok.gs)

assert tokens == ['Long', 'ow.']
assert labels == [1, 2]
# Only the first subtoken carries the GS, so joining gives back the GS of the AlignedToken
assert ''.join(gs) == t.gs

## Process a text file

Next, we need functions for processing a text in the ICDAR data format.

In [None]:
#| export
@dataclass
class Text:
    """Dataclass for storing a text in the ICDAR data format"""
    ocr_text: str  # The OCR input text (with `@` characters removed)
    tokens: list  # The AlignedTokens extracted from the aligned OCR and GS
    input_tokens: list  # The InputTokens generated from the AlignedTokens
    score: float  # Normalized edit distance between the cleaned OCR and GS texts

In [None]:
#| export
def clean(string: str):
    """Return `string` with every `@` and `#` character removed"""
    for marker in ('@', '#'):
        string = string.replace(marker, '')
    return string

In [None]:
#| export
def normalized_ed(ed: int,
                  ocr: str,
                  gs: str):
    """Divide the edit distance by the length of the longer string

    Returns 0.0 if both strings are empty.
    """
    longest = max(len(ocr), len(gs))
    if longest == 0:
        return 0.0
    return ed / longest

In [None]:
#| export
def process_text(in_file: Path) -> Text:
    """Extract AlignedTokens, InputTokens from a text file and calculate normalized editdistance

    The file must contain (at least) three lines: the OCR input, the aligned
    OCR and the aligned gold standard, each prefixed with its label.

    Raises an `AssertionError` if the OCR string of an AlignedToken cannot be
    found at the expected position in the OCR input text.
    """
    # Read with an explicit encoding, so the result does not depend on the
    # locale of the machine (the data is assumed to be UTF-8 encoded).
    with open(in_file, encoding='utf-8') as f:
        lines = f.readlines()

    # The # character in ocr input is not an alignment character, but the @
    # character is!
    ocr_input = remove_label_and_nl(lines[0]).replace('@', '')
    ocr_aligned = remove_label_and_nl(lines[1])
    gs_aligned = remove_label_and_nl(lines[2])

    tokens = tokenize_aligned(ocr_aligned, gs_aligned)

    # Check data: every token must be present in the OCR input at the position
    # calculated from the aligned OCR. An explicit check is used instead of an
    # `assert` statement, because the latter is skipped when running with -O.
    for token in tokens:
        input_token = ocr_input[token.start:token.start+token.len_ocr]
        if token.ocr != input_token.strip():
            logger.warning(f'OCR != aligned OCR: Text: {str(in_file)}; ocr: {repr(token.ocr)}; ocr_input: {repr(input_token)}')
            raise AssertionError(
                f'OCR input and aligned OCR do not match in {str(in_file)}')

    ocr = clean(ocr_aligned)
    gs = clean(gs_aligned)

    try:
        ed = edlib.align(gs, ocr)['editDistance']
        score = normalized_ed(ed, ocr, gs)
    except UnicodeEncodeError:
        logger.warning(f'UnicodeEncodeError for text {in_file}; setting score to 1')
        # Use the maximum score for texts that cannot be compared
        score = 1.0

    input_tokens = [inp_tok
                    for token in tokens
                    for inp_tok in get_input_tokens(token)]

    return Text(ocr_input, tokens, input_tokens, score)


Processing the example text:

In [None]:

# Process the example file in the `data` directory below the current working directory
in_file = Path(os.getcwd())/'data'/'example.txt'
text = process_text(in_file)
text

Text(ocr_text='This is a cxample...', tokens=[AlignedToken(ocr='This', gs='This', ocr_aligned='This', gs_aligned='This', start=0, len_ocr=4), AlignedToken(ocr='is', gs='is', ocr_aligned='is', gs_aligned='is', start=5, len_ocr=2), AlignedToken(ocr='a', gs='an', ocr_aligned='a@', gs_aligned='an', start=8, len_ocr=1), AlignedToken(ocr='cxample...', gs='example.', ocr_aligned='cxample...', gs_aligned='example.##', start=10, len_ocr=10)], input_tokens=[InputToken(ocr='This', gs='This', start=0, len_ocr=4, label=0), InputToken(ocr='is', gs='is', start=5, len_ocr=2, label=0), InputToken(ocr='a', gs='an', start=8, len_ocr=1, label=1), InputToken(ocr='cxample...', gs='example.', start=10, len_ocr=10, label=1)], score=0.2)

In [None]:
#| hide
# The example text consists of four tokens, none of which is split any further
assert len(text.tokens) == 4
assert len(text.input_tokens) == 4
assert text.tokens[2].ocr == 'a'
# Expected normalized edit distance between the cleaned OCR and GS of the example
assert text.score == 0.2

## Process the entire dataset

File structure of the ICDAR dataset
```
.
├── <data_dir>
│   ├── <language>
│   │   ├── <language (set)>1
│   │   ...
│   │   └── <language (set)>n
│   ...
...
```

In [None]:
#| export
def generate_data(in_dir: Path):
    """Process all texts in the dataset and return a dataframe with metadata

    Every entry of `in_dir` is treated as a language directory that is
    searched recursively for `*.txt` files. Returns a tuple `(data, md)`:
    `data` maps the file path relative to `in_dir` to the processed Text and
    `md` is a DataFrame with one row of metadata per file.
    """
    texts = {}
    columns = {'language': [],
               'file_name': [],
               'score': [],
               'num_tokens': [],
               'num_input_tokens': []}

    for lang_path in tqdm(in_dir.iterdir()):
        lang = lang_path.stem

        for txt_path in lang_path.rglob('*.txt'):
            # The path relative to the data directory identifies the text
            rel_name = str(txt_path.relative_to(in_dir))
            processed = process_text(txt_path)
            texts[rel_name] = processed

            columns['language'].append(lang)
            columns['file_name'].append(rel_name)
            columns['score'].append(processed.score)
            columns['num_tokens'].append(len(processed.tokens))
            columns['num_input_tokens'].append(len(processed.input_tokens))

    metadata = pd.DataFrame(columns)
    return texts, metadata


## Generate input 'sentences'

The following functions can be used to generate sequences of a certain length with possible overlap.


In [None]:
#| export
def window(iterable, size=2):
    """Slide a window of (at most) `size` elements over an iterable

    The first window is yielded even if the iterable has fewer than `size`
    elements (it is shorter in that case). After that the window is moved one
    element at a time until the iterable is exhausted.
    """
    stream = iter(iterable)
    # Fill the first window; stops early if the iterable runs out of elements
    current = [item for _, item in zip(range(size), stream)]
    yield current
    for item in stream:
        current = [*current[1:], item]
        yield current

In [None]:
#| export
def generate_sentences(df, data, size=15, step=10):
    """Generate sequences of a certain length and possible overlap

    For every row of `df` the input tokens of the text `data[row.file_name]`
    are cut into windows of `size` tokens; every `step`-th window becomes a
    sample. Returns a DataFrame with the columns `key`, `start_token_id`,
    `score`, `tokens`, `tags` and `language`.
    """
    columns = {'key': [],
               'start_token_id': [],
               'score': [],
               'tokens': [],
               'tags': [],
               'language': []}

    for _, row in tqdm(df.iterrows()):
        key = row.file_name
        input_tokens = data[key].input_tokens

        for first_token, chunk in enumerate(window(input_tokens, size=size)):
            if first_token % step != 0:
                continue

            ocr_parts = [tok.ocr for tok in chunk]
            tags = [tok.label for tok in chunk]
            # Subtokens without gold standard do not contribute to the GS text
            gs_parts = [tok.gs for tok in chunk if tok.gs != '']

            ocr_str = ' '.join(ocr_parts)
            gs_str = ' '.join(gs_parts)
            distance = edlib.align(ocr_str, gs_str)['editDistance']
            score = normalized_ed(distance, ocr_str, gs_str)

            if len(ocr_str) > 0:
                columns['key'].append(key)
                columns['start_token_id'].append(first_token)
                columns['score'].append(score)
                columns['tokens'].append(ocr_parts)
                columns['tags'].append(tags)
                # The first two characters of the key are used as language
                columns['language'].append(key[:2])
            else:
                logger.info(f'Empty sample for text "{key}"')
                logger.info(f'ocr_str: {ocr_str}')
                logger.info(f'start token: {first_token}')

    return pd.DataFrame(columns)

In [None]:
#| hide
import nbdev; nbdev.nbdev_export()