In [None]:
#|default_exp utils

In [None]:
#| export
import re
from functools import partial
from collections import defaultdict, Counter

import numpy as np
import pandas as pd

from loguru import logger
from transformers import AutoTokenizer

In [None]:
#| hide
import os

from pathlib import Path

from datasets import Dataset

from ocrpostcorrection.icdar_data import InputToken, generate_data, generate_sentences, process_text

# Convert predictions into ICDAR output format

In [None]:
#| export
def predictions_to_labels(predictions):
    """Turn per-class scores into label ids.

    For every position the index of the largest value along axis 2 is
    returned, so the result has one dimension less than `predictions`.
    """
    label_ids = np.argmax(predictions, axis=2)
    return label_ids

In [None]:
# Example: scores with shape (batch, max sequence length, number of classes)
# b x max seq x # classes
predictions = np.zeros((16, 10, 3))

# Always predict 1 (class 1 gets the only non-zero score at every position)
predictions[:, :, 1] = 1

output = predictions_to_labels(predictions)

# The class axis is gone and every position carries label 1
assert np.array_equal(np.ones((16, 10)), output)

In [None]:
#| hide
# A single sample whose score matrix is the 5x5 identity: row i has its
# maximum in column i, so the expected labels are 0..4.
# shape: b x max seq x # classes
predictions = np.array([np.identity(5)])

result = predictions_to_labels(predictions)

assert np.array_equal(np.array([0, 1, 2, 3, 4]) , result[0])

In [None]:
#| export

def separate_subtoken_predictions(word_ids, preds):
    """Group sub-token predictions by the word they belong to.

    `word_ids` and `preds` are walked in parallel (stopping at the shorter
    one). Positions whose word id is `None` are skipped. The result is a
    plain dict mapping each word id to the list of its predictions, in the
    order in which they were encountered.
    """
    grouped = {}
    for word_id, prediction in zip(word_ids, preds):
        if word_id is None:
            continue
        grouped.setdefault(word_id, []).append(prediction)
    return grouped


In [None]:
#| hide

# Three words with three sub-tokens each; the two lists are aligned position
# by position (hence the extra spaces after `preds =`).
word_ids = [0, 0, 0, 1, 1, 1, 2, 2, 2]
preds =    [0, 0, 0, 0, 0, 1, 1, 2, 2]

token_preds = separate_subtoken_predictions(word_ids, preds)
print(token_preds)

# Each word id maps to the predictions of its own sub-tokens, in order
assert token_preds == {0: [0, 0, 0], 1: [0, 0, 1], 2: [1, 2, 2]}

{0: [0, 0, 0], 1: [0, 0, 1], 2: [1, 2, 2]}


In [None]:
#| export
def merge_subtoken_predictions(subtoken_predictions):
    """Collapse the sub-token predictions of every word into a single label.

    `subtoken_predictions` maps a word index to the list of labels predicted
    for its sub-tokens. One label per word is returned, in the iteration
    order of the mapping:

    * 0 if neither 1 nor 2 was predicted for the word,
    * 1 if 1 was predicted at least as often as 2 (ties go to 1),
    * 2 otherwise.
    """
    def _vote(preds):
        counts = Counter(preds)
        ones, twos = counts[1], counts[2]
        if ones == 0 and twos == 0:
            return 0
        return 1 if ones >= twos else 2

    return [_vote(preds) for preds in subtoken_predictions.values()]

In [None]:
#| hide
# One entry per word; the trailing comment is the expected merged label.
subtoken_predictions = {0: [0, 0, 0],  # 0 (no 1 or 2 predicted)
                        1: [1, 1, 0],  # 1
                        2: [1, 2],     # 1 (tie between 1 and 2 goes to 1)
                        3: [2, 2, 1],  # 2 (2 predicted more often than 1)
                        4: [0, 1, 2],  # 1 (tie again; zeros do not count)
                        5: [0, 1, 0]}  # 1 (a single 1 outweighs any number of 0s)

token_preds = merge_subtoken_predictions(subtoken_predictions)
print(token_preds)

assert [0, 1, 1, 2, 1, 1] == token_preds

[0, 1, 1, 2, 1, 1]


In [None]:
#| export
def gather_token_predictions(preds):
    """Collect the labels predicted for each token position.

    `preds` maps a start offset to a sequence of labels. The label at index
    `i` of such a sequence belongs to position `int(start) + i`. Sequences
    may overlap, in which case a position receives several labels; they are
    kept in the order in which they are encountered. Returns a plain dict
    mapping position to list of labels.
    """
    gathered = {}
    for start, window in preds.items():
        for offset, label in enumerate(window):
            position = int(start) + offset
            gathered.setdefault(position, []).append(label)
    return gathered

In [None]:
#| hide

# Three windows of five labels each, starting at positions 0, 1 and 2, so
# neighbouring windows overlap.
token_predictions = {0: [0, 0, 0, 0, 0],
                     1: [0, 0, 0, 0, 0],
                     2: [0, 0, 0, 0, 0]}
actual = gather_token_predictions(token_predictions)
# Positions at the edges are covered by fewer windows than those in the middle
expected = {0: [0], 1: [0, 0], 2: [0, 0, 0], 3: [0, 0, 0], 4: [0, 0, 0], 5: [0, 0], 6: [0]}

assert expected == actual


In [None]:
#| export
def labels2label_str(labels):
    """Reduce the candidate labels of every token to one character.

    Args:
        labels: either a sequence whose i-th item holds the labels gathered
            for token i, or a dict mapping token position to such a
            collection (the shape returned by `gather_token_predictions`).

    Returns:
        A string with one character per token position: '2' if 2 is among
        the token's labels, else '1' if 1 is among them, else '0'.

    For a dict, the character at index i always describes position i.
    Positions that are missing from the dict get '0' instead of raising a
    `KeyError`, so the string stays aligned with the token list it is later
    matched against.
    """
    if isinstance(labels, dict):
        # Keys are positions; cover 0..max(key) so string index == position.
        size = max(labels) + 1 if labels else 0
        per_token = [labels.get(position, ()) for position in range(size)]
    else:
        per_token = labels

    def _to_char(token_labels):
        # 2 takes precedence over 1, which takes precedence over 0.
        if 2 in token_labels:
            return '2'
        if 1 in token_labels:
            return '1'
        return '0'

    return ''.join(_to_char(token_labels) for token_labels in per_token)

In [None]:
#| hide
# One list of candidate labels per token. The last two tokens check the
# precedence: a 1 wins over 0s, and a 2 wins over both 0 and 1.
labels = [[0], [1], [2], [0, 0, 1], [0, 1, 2]]

label_str = labels2label_str(labels)

assert label_str == '01212'

In [None]:
#| export

def extract_icdar_output(label_str, input_tokens):
    """Turn a token label string into ICDAR-style output for one text.

    Args:
        label_str: string with one character per token ('0', '1' or '2').
            The character at index i belongs to `input_tokens[i]`.
        input_tokens: sequence of objects with a `start` attribute. Only
            `start` is read here.

    Returns:
        A dict with one entry per detected span. The key is
        `'{start}:{num_tokens}'`, where `start` is the `start` attribute of
        the first token of the span and `num_tokens` the number of tokens
        in the span. The value is an empty dict.

    A span is a `1` followed by any number of `2`s. A run of `2`s that does
    not follow a `1` (it follows a `0` or opens the string) is treated as if
    its first `2` were a `1`.
    """
    text_output = {}

    # Correct use of 2 (always following a 1)
    for match in re.finditer(r'12*', label_str):
        idx = input_tokens[match.start()].start
        num_tokens = len(match.group())
        text_output[f'{idx}:{num_tokens}'] = {}

    # Incorrect use of 2 (following a 0 or at the very beginning of the
    # string) -> interpret first 2 as 1. The zero-width prefix keeps the
    # preceding 0 out of the match, so the match is exactly the span.
    for match in re.finditer(r'(?:(?<=0)|^)2+', label_str):
        idx = input_tokens[match.start()].start
        num_tokens = len(match.group())
        text_output[f'{idx}:{num_tokens}'] = {}

    return text_output

In [None]:
#| hide
# A single token labelled 1 gives one span of length 1 at the token's start
label_str = '1'
input_tokens = [InputToken(ocr='bal', gs='bla', start=0, len_ocr=3, label=1)]
output = extract_icdar_output(label_str, input_tokens)
assert output == {'0:1': {}}, output

In [None]:
#| hide
# Only the second token is labelled 1; the key uses that token's `start` (4)
label_str = '01'
input_tokens = [InputToken(ocr='one', gs='one', start=0, len_ocr=3, label=0),
                InputToken(ocr='tow', gs='two', start=4, len_ocr=3, label=1)]
output = extract_icdar_output(label_str, input_tokens)
assert output == {'4:1': {}}, output

In [None]:
#| hide
# A 1 followed by a 2 is one span covering two tokens. Note that the `label`
# fields of the tokens differ from `label_str`; the call passes the labels
# separately through `label_str`.
label_str = '12'
input_tokens = [InputToken(ocr='one', gs='one', start=0, len_ocr=3, label=0),
                InputToken(ocr='tow', gs='two', start=4, len_ocr=3, label=1)]
output = extract_icdar_output(label_str, input_tokens)
assert output == {'0:2': {}}, output

In [None]:
#| hide
# Two adjacent spans: the first 1 stands alone, the second 1 is extended by
# the following 2.
label_str = '112'
input_tokens = [InputToken(ocr='one', gs='one', start=0, len_ocr=3, label=0),
                InputToken(ocr='one', gs='one', start=4, len_ocr=3, label=0),
                InputToken(ocr='tow', gs='two', start=8, len_ocr=3, label=1)]
output = extract_icdar_output(label_str, input_tokens)
assert output == {'0:1': {}, '4:2': {}}, output

In [None]:
#| hide
# A 2 that follows a 0 is interpreted as a 1: a span of length 1 starting at
# the second token.
label_str = '02'
input_tokens = [InputToken(ocr='one', gs='one', start=0, len_ocr=3, label=0),
                InputToken(ocr='tow', gs='two', start=4, len_ocr=3, label=1)]
output = extract_icdar_output(label_str, input_tokens)
assert output == {'4:1': {}}, output

In [None]:
#| export

def predictions2icdar_output(samples, predictions, tokenizer, data_test):
    """Convert predictions into icdar output format

    Args:
        samples: collection that supports `samples["tokens"]` and whose
            items provide the keys `'key'` (id of the text the sample was
            taken from) and `'start_token_id'` (position of the sample's
            first token in that text).
        predictions: one sequence of sub-token labels per sample, in the
            same order as `samples` (presumably the output of
            `predictions_to_labels` -- verify against the caller).
        tokenizer: callable invoked as
            `tokenizer(tokens, truncation=True, is_split_into_words=True)`;
            its result must offer `word_ids(batch_index=i)`.
        data_test: mapping from text key to an object with an
            `input_tokens` attribute.

    Returns:
        A dict mapping each text key to the result of
        `extract_icdar_output` for that text. Keys that are missing from
        `data_test` are skipped with a warning.
    """
    tokenized_samples = tokenizer(samples["tokens"], truncation=True, is_split_into_words=True)

    # text key -> {start token id of the sample -> token-level labels}
    converted = defaultdict(dict)

    for i, (sample, preds) in enumerate(zip(samples, predictions)):
        word_ids = tokenized_samples.word_ids(batch_index=i)  # Map tokens to their respective word.
        subtoken_preds = separate_subtoken_predictions(word_ids, preds)
        new_tags = merge_subtoken_predictions(subtoken_preds)
        converted[sample['key']][sample['start_token_id']] = new_tags

    output = {}
    for key, preds in converted.items():
        # Combine the (possibly overlapping) samples of one text into a
        # single label string with one character per token.
        labels = gather_token_predictions(preds)
        label_str = labels2label_str(labels)

        # Only the lookup is guarded, so that a KeyError raised elsewhere is
        # not misreported as missing data.
        try:
            text = data_test[key]
        except KeyError:
            logger.warning(f'No data found for text {key}')
            continue

        output[key] = extract_icdar_output(label_str, text.input_tokens)

    return output

In [None]:
#| hide

# End-to-end check of predictions2icdar_output on the sample data set.

# Create tokenizer
# NOTE(review): from_pretrained presumably needs the model files in a local
# cache or network access -- confirm for CI.
bert_base_model_name = 'bert-base-multilingual-cased'
tokenizer = AutoTokenizer.from_pretrained(bert_base_model_name)

# Create data
data_dir = Path(os.getcwd())/'data'/'dataset_training_sample'
data, md = generate_data(data_dir)
sentence_df = generate_sentences(md, data, size=2, step=1)
dataset = Dataset.from_pandas(sentence_df)

# Create predictions

# b x max seq x # classes
predictions = np.zeros((len(dataset), 10, 3))

# Always predict 1
predictions[:, :, 1] = 1
# From here on `predictions` holds label ids (b x max seq), all equal to 1
predictions = predictions_to_labels(predictions)

# Generate icdar output (task 1)
actual = predictions2icdar_output(dataset, predictions, tokenizer, data)

# Expected output has an entry of length 1 for every input token
expected = defaultdict(dict)
for key, text in data.items():
    for token in text.input_tokens:
        expected[key][f'{token.start}:1'] = {}
        
assert expected == actual

2it [00:00, 420.69it/s]
4it [00:00, 540.76it/s]


In [None]:
#| export

def create_perfect_icdar_output(data):
    """Build the ICDAR output that follows from the stored token labels.

    `data` maps a key to an object with `input_tokens`. For every entry the
    `label` attributes of its tokens are concatenated into a label string,
    which is passed to `extract_icdar_output` together with the tokens. The
    result maps each key to that function's output.
    """
    return {
        key: extract_icdar_output(
            ''.join(str(token.label) for token in text_obj.input_tokens),
            text_obj.input_tokens,
        )
        for key, text_obj in data.items()
    }

In [None]:
#| hide
# Load the example text that ships with the repository (path is relative to
# the current working directory).
in_file = Path(os.getcwd())/'data'/'example.txt'
text = process_text(in_file)

test_input = {'key': text}

actual = create_perfect_icdar_output(test_input)

# Indices (the first number) refer to the ocr input text
assert actual == {'key': {'8:1': {}, '10:1': {}}}

# Summarize icdar results

In [None]:
#| export

def aggregate_results(csv_file):
    """Average the task 1 scores of an evaluation csv per language.

    Args:
        csv_file: path or file-like object of a `;`-separated csv with a
            `File` column and the columns `T1_Precision`, `T1_Recall` and
            `T1_Fmesure`.

    Returns:
        A DataFrame indexed by `language` (the first two characters of
        `File`) holding the mean of the three score columns.
    """
    data = pd.read_csv(csv_file, sep=';')
    data['language'] = data.File.apply(lambda x: x[:2])
    # Second path component of `File`; not part of the returned frame.
    data['subset'] = data.File.apply(lambda x: x.split('/')[1])

    # Select the score columns *before* aggregating: taking the mean of the
    # whole frame would include the string columns, which raises a TypeError
    # in pandas >= 2.0.
    metrics = ['T1_Precision', 'T1_Recall', 'T1_Fmesure']
    return data.groupby('language')[metrics].mean()

# Development

In [None]:
#| export
def reduce_dataset(dataset, n=5):
    """Return dataset with the first n samples for each split

    Args:
        dataset: mapping from split name to an object that supports `len()`
            and `select(indices)`.
        n: maximum number of samples to keep per split.

    Returns:
        The same `dataset` object; its splits are replaced in place. A split
        with fewer than `n` samples is kept in full instead of failing on
        out-of-range indices.
    """
    # Snapshot the keys because the mapping is written to inside the loop.
    for split in list(dataset.keys()):
        samples = dataset[split]
        keep = min(n, len(samples))
        dataset[split] = samples.select(range(keep))
    return dataset

In [None]:
#| hide
# Run the nbdev export step (presumably writes the cells marked `#| export`
# to the module named by `#|default_exp` -- see the nbdev documentation).
import nbdev; nbdev.nbdev_export()