In [None]:
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path

from tqdm.notebook import tqdm

In [None]:
import os
os.environ['LOGURU_LEVEL'] = 'INFO'

In [None]:
import logging

from loguru import logger

class InterceptHandler(logging.Handler):
    """Forward records emitted through the stdlib ``logging`` module to loguru.

    Installed with ``logging.basicConfig(handlers=[InterceptHandler()], ...)``
    so that libraries which log via ``logging`` end up in loguru's sinks.
    """

    def emit(self, record):
        # Use loguru's level of the same name when it exists; otherwise fall
        # back to the numeric stdlib level.
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            level = record.levelno

        # Walk up past the frames that belong to the logging module itself so
        # loguru attributes the message to the original caller. Stop when the
        # stack is exhausted (frame is None) instead of raising AttributeError.
        frame, depth = logging.currentframe(), 2
        while frame is not None and frame.f_code.co_filename == logging.__file__:
            frame = frame.f_back
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())

logging.basicConfig(handlers=[InterceptHandler()], level=0)

In [None]:
in_dir = Path('../../data/ICDAR2019_POCR_competition_dataset/ICDAR2019_POCR_competition_training_18M_without_Finnish')
in_dir.is_dir()

In [None]:
def remove_label_and_nl(line, label_len=14):
    """Remove the fixed-width label prefix and the line terminator from a line.

    Only the line terminator (``\\n`` / ``\\r\\n``) is removed, not other
    trailing whitespace. Trailing spaces can be part of an aligned OCR/GS
    line, and stripping them from only one of the two lines would make them
    differ in length.

    Parameters
    ----------
    line : str
        A raw line read from a dataset file.
    label_len : int
        Width of the label prefix. Assumes a 14-character label such as
        ``'[OCR_aligned] '`` -- TODO confirm for every line type.

    Returns
    -------
    str
        The line content after the label.
    """
    return line.rstrip('\r\n')[label_len:]

In [None]:
from dataclasses import dataclass, field
import edlib

def normalized_ed(ed, ocr, gs):
    """Scale an edit distance by the length of the longer of two sequences.

    Returns ``ed / max(len(ocr), len(gs))``. When both sequences are empty,
    returns ``0.0`` so that no division by zero happens.
    """
    longest = max(len(ocr), len(gs))
    if not longest:
        return 0.0
    return ed / longest


@dataclass
class Token:
    """One token cut out of an aligned OCR / gold-standard (GS) line pair.

    Instances are created by ``tokenize_aligned``. The alignment symbols '@'
    and '#' are kept in the ``*_aligned`` fields and removed from ``ocr``/``gs``.
    """
    ocr: str  # OCR text of the token, alignment symbols removed, whitespace-stripped
    gs: str  # GS text of the token, alignment symbols removed (not stripped)
    ocr_aligned: str  # aligned OCR characters of the token, incl. '@'/'#'
    gs_aligned: str  # aligned GS characters of the token, incl. '@'/'#'
    start: int  # offset of the token in the OCR text with '@'/'#' removed
    len_ocr: int  # number of OCR characters before whitespace stripping
   

def _token_from_chars(ocr_chars, gs_chars, ocr_chars_aligned, gs_chars_aligned, start):
    """Build a Token from collected characters, or return None if it has no OCR text."""
    ocr_raw = ''.join(ocr_chars)
    ocr = ocr_raw.strip()
    # Ignore 'tokens' without representation in the OCR text (they consist
    # only of alignment symbols / whitespace on the OCR side).
    if ocr == '':
        return None
    return Token(ocr,
                 ''.join(gs_chars),
                 ''.join(ocr_chars_aligned),
                 ''.join(gs_chars_aligned),
                 start,
                 len(ocr_raw))


def tokenize_aligned(ocr_aligned, gs_aligned):
    """Split an aligned OCR / gold-standard (GS) line pair into Token objects.

    A token boundary is a position where *both* aligned strings hold a space.
    A space on only one side stays inside the token, so a token's OCR or GS
    text may contain spaces. The characters '@' and '#' are alignment
    symbols. They are kept in a token's ``ocr_aligned``/``gs_aligned``, but
    are removed from its ``ocr``/``gs`` and are not counted as OCR characters.

    Parameters
    ----------
    ocr_aligned, gs_aligned : str
        Aligned OCR and GS lines (label prefix removed). Iteration stops at
        the end of the shorter string.

    Returns
    -------
    list of Token
        Tokens in text order. ``Token.start`` is the token's offset in the OCR
        text with alignment symbols removed. Tokens whose OCR text is empty
        or whitespace-only are dropped.
    """
    gap_chars = ('@', '#')

    tokens = []
    ocr_cursor = 0  # OCR characters (excluding alignment symbols) consumed so far
    start = 0       # OCR offset at which the current token begins

    ocr_chars, gs_chars = [], []
    ocr_chars_aligned, gs_chars_aligned = [], []

    for ocr_char, gs_char in zip(ocr_aligned, gs_aligned):
        if ocr_char not in gap_chars:
            ocr_cursor += 1

        if ocr_char == ' ' and gs_char == ' ':
            token = _token_from_chars(ocr_chars, gs_chars,
                                      ocr_chars_aligned, gs_chars_aligned, start)
            if token is not None:
                tokens.append(token)
            # The next token starts right after this shared space (which has
            # already been counted by ocr_cursor).
            start = ocr_cursor
            ocr_chars, gs_chars = [], []
            ocr_chars_aligned, gs_chars_aligned = [], []
        else:
            ocr_chars_aligned.append(ocr_char)
            gs_chars_aligned.append(gs_char)
            if ocr_char not in gap_chars:
                ocr_chars.append(ocr_char)
            if gs_char not in gap_chars:
                gs_chars.append(gs_char)

    # Flush the final token: the text need not end with a shared space.
    token = _token_from_chars(ocr_chars, gs_chars,
                              ocr_chars_aligned, gs_chars_aligned, start)
    if token is not None:
        tokens.append(token)

    return tokens

In [None]:
def window(iterable, size=2):
    """Yield a sliding window of consecutive items over ``iterable``.

    The first window holds up to ``size`` items. It is always yielded, even
    when the iterable is shorter than ``size``, in which case it is shorter
    and possibly empty. Each remaining item then shifts the window one
    position to the right. Every yielded window is a new list.
    """
    items = iter(iterable)
    exhausted = object()

    current = []
    for _ in range(size):
        item = next(items, exhausted)
        if item is exhausted:
            break
        current.append(item)
    yield current

    for item in items:
        current = current[1:]
        current.append(item)
        yield current

In [None]:
import nltk.data
import edlib

@dataclass
class Text:
    """Tokenized content of one dataset file plus its overall OCR quality."""
    tokens: list  # Token objects as returned by tokenize_aligned
    score: float  # normalized edit distance between the whole OCR and GS text


def clean(string):
    """Return ``string`` without the alignment symbols '@' and '#'."""
    for marker in ('@', '#'):
        string = string.replace(marker, '')
    return string


def process_text(in_file, size=15, overlap=10):
    """Read one dataset file and tokenize its aligned OCR / gold-standard lines.

    The first three lines of the file are used: the OCR input, the aligned
    OCR text and the aligned gold standard. Each line has a label prefix
    that ``remove_label_and_nl`` cuts off.

    Parameters
    ----------
    in_file : str or pathlib.Path
        Path of the text file. Assumed to be UTF-8 encoded.
    size, overlap : int
        Not used; kept so existing calls keep working.

    Returns
    -------
    Text
        The file's tokens and the normalized edit distance between the whole
        (alignment-symbol-free) OCR text and gold standard.

    Raises
    ------
    ValueError
        If the file has fewer than three lines.
    AssertionError
        If a token's offsets do not point back to the same text in the OCR
        input, i.e. the tokenizer and the input line disagree.
    """
    # Explicit encoding: the platform default is not UTF-8 everywhere.
    with open(in_file, encoding='utf-8') as f:
        lines = f.readlines()
    if len(lines) < 3:
        raise ValueError(f'{in_file}: expected at least 3 lines, found {len(lines)}')

    ocr_input = clean(remove_label_and_nl(lines[0]))
    ocr_aligned = remove_label_and_nl(lines[1])
    gs_aligned = remove_label_and_nl(lines[2])

    tokens = tokenize_aligned(ocr_aligned, gs_aligned)

    # Sanity check: each token's (start, len_ocr) must select the same text
    # from the OCR input. Raised explicitly so it also runs under `python -O`.
    for token in tokens:
        input_token = ocr_input[token.start:token.start + token.len_ocr]
        if token.ocr != input_token.strip():
            raise AssertionError(
                f'Text: {str(in_file)}; ocr: {repr(token.ocr)}; ocr_input: {repr(input_token)}')

    ocr = clean(ocr_aligned)
    gs = clean(gs_aligned)
    ed = edlib.align(gs, ocr)['editDistance']
    score = normalized_ed(ed, ocr, gs)

    return Text(tokens, score)

text = process_text(in_dir/'NL'/'NL1'/'17.txt')

In [None]:
text.tokens[35]

In [None]:
text.score

In [None]:
%%time
# Tokenize every text file of every language; keep the Text objects in `data`
# and one metadata row per file in `md`.
in_dir = Path('../../data/ICDAR2019_POCR_competition_dataset/ICDAR2019_POCR_competition_training_18M_without_Finnish')

data = {}      # file key (path relative to in_dir) -> Text
records = []   # one metadata dict per file

# iterdir()/rglob() yield entries in arbitrary, file-system dependent order;
# sorting makes the row order of `md` reproducible across machines.
language_dirs = sorted(p for p in in_dir.iterdir() if p.is_dir())
for language_dir in tqdm(language_dirs):
    language = language_dir.stem

    for text_file in sorted(language_dir.rglob('*.txt')):
        key = str(text_file.relative_to(in_dir))
        processed = process_text(text_file)
        data[key] = processed
        records.append({'language': language,
                        'file_name': key,
                        'score': processed.score,
                        'num_tokens': len(processed.tokens)})

md = pd.DataFrame(records, columns=['language', 'file_name', 'score', 'num_tokens'])

In [None]:
md

In [None]:
md.num_tokens.describe()

In [None]:
md.num_tokens.hist(bins=2000, figsize=(10,5))

In [None]:
md.score.describe()

In [None]:
md.score.hist(bins=50, figsize=(10,5))

In [None]:
md.query('score <= 0.3').num_tokens.describe()

In [None]:
from sklearn.model_selection import train_test_split

# Hold out 10% of the files for validation, stratified by language so every
# language keeps the same share. A fixed random_state makes the split
# reproducible; unpacking only two outputs also avoids clobbering IPython's `_`.
X_train, X_val = train_test_split(md, test_size=0.1, shuffle=True,
                                  stratify=md['language'], random_state=42)

In [None]:
out_dir = Path('results')

X_train.to_csv(out_dir/'train.csv')
X_val.to_csv(out_dir/'val.csv')

In [None]:
out_dir = Path('results')

X_train = pd.read_csv(out_dir/'train.csv')
X_val = pd.read_csv(out_dir/'val.csv')

In [None]:
# Generate 'sentences' for train and val sets

def _sentence_record(key, start, window_tokens):
    """Build one output row for a window of Token objects beginning at token ``start``."""
    ocr = [t.ocr for t in window_tokens]
    ocr_str = ' '.join(ocr)
    gs_str = ' '.join(t.gs for t in window_tokens)
    ed = edlib.align(ocr_str, gs_str)['editDistance']
    return {
        'key': key,
        'start_token_id': start,
        'score': normalized_ed(ed, ocr_str, gs_str),
        'tokens': ocr,
        # 0: OCR token equals the gold standard, 1: it differs (OCR error).
        'tags': [0 if t.ocr == t.gs else 1 for t in window_tokens],
        # assumes keys start with the 2-letter language directory (e.g. 'NL/...')
        'language': key[:2],
    }


def generate_sentences(df, data, size=15, overlap=10):
    """Cut every text listed in ``df`` into fixed-size windows of tokens.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have a ``file_name`` column whose values are keys of ``data``.
    data : dict
        Maps file keys to Text objects (see ``process_text``).
    size : int
        Number of tokens per window.
    overlap : int
        Step, in tokens, between the starts of consecutive selected windows.
        Despite the name, consecutive windows share ``size - overlap`` tokens.

    Returns
    -------
    pandas.DataFrame
        One row per window with columns ``key``, ``start_token_id``,
        ``score``, ``tokens``, ``tags`` and ``language``.

    Notes
    -----
    Every ``overlap``-th window is taken. The last window of a text is taken
    as well, because otherwise the tokens after the last stride position
    would never appear in any sample. Texts without tokens yield no rows.
    """
    columns = ['key', 'start_token_id', 'score', 'tokens', 'tags', 'language']
    records = []

    for _, row in tqdm(df.iterrows(), total=len(df)):
        key = row.file_name
        tokens = data[key].tokens

        last_i, last_win = None, None
        for i, win in enumerate(window(tokens, size=size)):
            if i % overlap == 0 and win:
                records.append(_sentence_record(key, i, win))
            last_i, last_win = i, win

        # `window` yields one window per start position, so with a stride of
        # `overlap` the tail of the text can be missed; add the final window
        # if it was not already taken.
        if last_i is not None and last_i % overlap != 0 and last_win:
            records.append(_sentence_record(key, last_i, last_win))

    return pd.DataFrame(records, columns=columns)

train_data = generate_sentences(X_train, data)
val_data = generate_sentences(X_val, data)

In [None]:
train_data.head()

In [None]:
train_data.to_json(out_dir/'icdar_train.jsonl', orient='records', lines=True)
val_data.to_json(out_dir/'icdar_val.jsonl', orient='records', lines=True)

In [None]:
from datasets import load_dataset

out_dir = Path('results')

data_files = {'train': str(out_dir/'icdar_train.jsonl'),
              'val': str(out_dir/'icdar_val.jsonl')}

icdar_dataset = load_dataset("json", data_files=data_files)

In [None]:
icdar_dataset

In [None]:
train_data.score.describe()

In [None]:
train_data.score.hist(bins=50, figsize=(10,5))

In [None]:
val_data.score.describe()

In [None]:
val_data.score.hist(bins=50, figsize=(10,5))

In [None]:
icdar_dataset = icdar_dataset.filter(lambda sample: sample['score'] <= 0.3)

In [None]:
icdar_dataset

In [None]:
icdar_dataset.save_to_disk('icdar-0.3')

In [None]:
from datasets import load_from_disk

icdar_dataset = load_from_disk('icdar-0.3')

In [None]:
icdar_dataset

In [None]:
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained('bert-base-multilingual-cased')

In [None]:
# Source: https://huggingface.co/docs/transformers/custom_datasets#token-classification-with-wnut-emerging-entities
def tokenize_and_align_labels(examples, label_all_tokens=True):
    """Tokenize pre-split words and expand word-level tags to sub-word labels.

    Meant for ``Dataset.map(..., batched=True)``. Uses the global ``tokenizer``.

    Parameters
    ----------
    examples : dict
        Batch with ``tokens`` (list of word lists) and ``tags`` (list of
        per-word 0/1 tag lists).
    label_all_tokens : bool
        If True (default, the behaviour this notebook relies on), every
        sub-word token gets the tag of its word. If False, only the first
        sub-word of each word is labelled and the rest get -100, which the
        loss ignores.

    Returns
    -------
    BatchEncoding
        The tokenizer output with an added ``labels`` entry. Positions whose
        ``word_ids()`` entry is None (special tokens) are labelled -100.
    """
    tokenized_inputs = tokenizer(examples["tokens"], truncation=True, is_split_into_words=True)

    labels = []
    for i, word_tags in enumerate(examples["tags"]):
        word_ids = tokenized_inputs.word_ids(batch_index=i)  # sub-word -> word index
        label_ids = []
        previous_word_idx = None
        for word_idx in word_ids:
            if word_idx is None:
                label_ids.append(-100)
            elif label_all_tokens or word_idx != previous_word_idx:
                label_ids.append(word_tags[word_idx])
            else:
                # Continuation sub-word when only the first sub-word is labelled.
                label_ids.append(-100)
            previous_word_idx = word_idx

        labels.append(label_ids)

    tokenized_inputs["labels"] = labels

    return tokenized_inputs

In [None]:
tokenized_icdar = icdar_dataset.map(tokenize_and_align_labels, batched=True)

In [None]:
from transformers import DataCollatorForTokenClassification

data_collator = DataCollatorForTokenClassification(tokenizer)

In [None]:
train_samples = tokenized_icdar['train'].shuffle().select(range(5))
val_samples = tokenized_icdar['val'].shuffle().select(range(5))

In [None]:
from transformers import AutoModelForTokenClassification, Trainer, TrainingArguments

training_args = TrainingArguments(
    output_dir='./results-0.3',          # output directory
    evaluation_strategy="epoch",
    num_train_epochs=3,
)

model = AutoModelForTokenClassification.from_pretrained('bert-base-multilingual-cased', num_labels=2)

trainer = Trainer(
    model=model,                         # the instantiated 🤗 Transformers model to be trained
    args=training_args,                  # training arguments, defined above
    train_dataset=train_samples,         # training dataset
    eval_dataset=val_samples,            # evaluation dataset
    data_collator=data_collator,
    tokenizer=tokenizer,
)

trainer.train()

In [None]:
trainer.save_model()

In [None]:
samples = tokenized_icdar['val'].shuffle().select(range(5))

In [None]:
pred = trainer.predict(samples)

In [None]:
pred.predictions

In [None]:
from collections import defaultdict

def convert_predictions(samples, pred, verbose=False):
    """Turn sub-word-level model predictions back into word-level tags.

    Uses the global ``tokenizer`` to recompute the sub-word -> word mapping.

    Parameters
    ----------
    samples : datasets.Dataset
        Rows with ``tokens``, ``tags``, ``key`` and ``start_token_id``. These
        must be the same rows, in the same order, that were passed to
        ``trainer.predict``.
    pred : PredictionOutput
        Result of ``trainer.predict(samples)``. ``pred.predictions`` is
        assumed to hold logits of shape (n_samples, seq_len, num_labels).
    verbose : bool
        Print per-sample debugging information.

    Returns
    -------
    collections.defaultdict
        ``converted[key][start_token_id]`` is the list of predicted word tags:
        1 if any sub-word of the word was predicted as 1, else 0.
    """
    tokenized_samples = tokenizer(samples["tokens"], truncation=True, is_split_into_words=True)

    # Highest-scoring label id for every sub-word position.
    predicted_ids = np.argmax(pred.predictions, axis=2)

    converted = defaultdict(dict)

    for i, (sample, token_preds) in enumerate(zip(samples, predicted_ids)):
        word_ids = tokenized_samples.word_ids(batch_index=i)

        # Group sub-word predictions by word. zip stops at len(word_ids), so
        # any extra (padding) positions in the prediction row are ignored.
        per_word = defaultdict(list)
        for word_idx, p_label in zip(word_ids, token_preds):
            if word_idx is not None:
                per_word[word_idx].append(p_label)

        # A word is flagged as erroneous if any of its sub-words is.
        new_tags = [1 if 1 in word_preds else 0 for word_preds in per_word.values()]

        if verbose:
            print(sample['key'], sample['start_token_id'])
            print('pred', len(new_tags), new_tags)
            print('tags', len(sample['tags']), sample['tags'])

        converted[sample['key']][sample['start_token_id']] = new_tags

    return converted


result = convert_predictions(samples, pred)

In [None]:
result

In [None]:
# Merge the (overlapping) window predictions of every text and collect the
# tokens flagged as erroneous under '<start offset>:<number of words>' keys.
output = {}

for key, window_preds in result.items():
    text = data[key]

    # Token index -> tags predicted by every window that covered it.
    token_votes = defaultdict(list)
    for start, lbls in window_preds.items():
        for offset, label in enumerate(lbls):
            token_votes[start + offset].append(label)

    # A fresh dict per text: one dict shared across texts would accumulate
    # the flagged tokens of all texts under every key.
    text_output = {}
    for i, token in enumerate(text.tokens):
        if 1 in token_votes.get(i, ()):
            # A token's OCR text can contain inner spaces, i.e. span several words.
            num_tokens = 1 + token.ocr.count(' ')
            text_output[f'{token.start}:{num_tokens}'] = {}

    output[key] = text_output


In [None]:
output