In [None]:
# Reload imported modules automatically before each cell is executed, so that
# edits to imported modules such as datautils take effect without a kernel restart.
%load_ext autoreload
%autoreload 2

In [None]:
%matplotlib inline

import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

from pathlib import Path

from tqdm.notebook import tqdm

In [None]:
import os
# Set before loguru is imported in the next cell.
# NOTE(review): loguru appears to read its LOGURU_* defaults from the environment
# at import time, so the ordering of these two cells matters — confirm.
os.environ['LOGURU_LEVEL'] = 'INFO'

In [None]:
import logging

from loguru import logger

class InterceptHandler(logging.Handler):
    """Standard-library logging handler that forwards every record to loguru."""

    def emit(self, record):
        """Re-emit ``record`` through loguru, attributed to the original call site.

        Args:
            record: The ``logging.LogRecord`` produced by the standard library.
        """
        # Local import so this cell does not depend on an import made elsewhere.
        import inspect

        # Get corresponding Loguru level if it exists
        try:
            level = logger.level(record.levelname).name
        except ValueError:
            # Unknown (custom) level name: fall back to the numeric level.
            level = record.levelno

        # Find caller from where originated the logged message.
        # Start from this very frame (depth 0) and walk outwards, skipping the
        # frames that belong to the logging module itself (and frozen importlib
        # frames). Starting from logging.currentframe() with a fixed depth is
        # not portable: what that function returns changed in Python 3.11.
        frame, depth = inspect.currentframe(), 0
        while frame is not None:
            filename = frame.f_code.co_filename
            is_logging = filename == logging.__file__
            is_frozen = 'importlib' in filename and '_bootstrap' in filename
            if depth > 0 and not (is_logging or is_frozen):
                break
            frame = frame.f_back
            depth += 1

        logger.opt(depth=depth, exception=record.exc_info).log(level, record.getMessage())

# Install the intercept handler on the root logger. level=0 (NOTSET) means no
# record is filtered out on the standard-library side before it reaches the handler.
logging.basicConfig(handlers=[InterceptHandler()], level=0)

In [None]:
from datautils import generate_data

# Load the training split (directory name: ..._training_18M_without_Finnish).
# NOTE(review): `in_dir` is rebound to the evaluation directory in the next cell,
# so after a full run it no longer points at the training data.
# generate_data returns a pair; the meaning of the second element (`md`) is not
# visible here — presumably metadata, confirm in datautils.
in_dir = Path('../../data/ICDAR2019_POCR_competition_dataset/ICDAR2019_POCR_competition_training_18M_without_Finnish')
data, md = generate_data(in_dir)

In [None]:
from datautils import generate_data

# Load the evaluation split (directory name: ..._evaluation_4M_without_Finnish).
# This rebinds `in_dir`; the check_tokens calls in the following cells rely on
# it pointing at the evaluation directory.
# NOTE(review): the second return value is called `md` for the training split and
# `X_test` here — confirm what generate_data actually returns.
in_dir = Path('../../data/ICDAR2019_POCR_competition_dataset/ICDAR2019_POCR_competition_evaluation_4M_without_Finnish')
data_test, X_test = generate_data(in_dir)

In [None]:
from datautils import remove_label_and_nl

def check_tokens(in_dir, key, data):
    """Check that the stored token offsets of one text match its raw OCR input.

    The first line of ``in_dir/key`` is read and passed through
    ``remove_label_and_nl``. For every token in ``data[key].input_tokens`` the
    slice ``[token.start, token.start + token.len_ocr)`` of that string must
    equal ``token.ocr``.

    Args:
        in_dir: ``pathlib.Path`` of the directory containing the dataset files.
        key: Path of the text file relative to ``in_dir``; also the key into ``data``.
        data: Mapping from key to a text object exposing ``input_tokens``.

    Raises:
        AssertionError: On the first token whose slice differs from ``token.ocr``
            (the mismatch is logged before raising).
    """
    in_file = in_dir/key

    # Only the first line (the OCR input) is needed for this check.
    # The encoding is explicit so the result does not depend on the platform
    # default — assumes the dataset files are UTF-8; TODO confirm.
    with open(in_file, encoding='utf-8') as f:
        ocr_input = remove_label_and_nl(f.readline())

    text = data[key]

    for token in text.input_tokens:
        inp = ocr_input[token.start:(token.start+token.len_ocr)]
        # Explicit raise instead of `assert`, so the check still runs under -O.
        if inp != token.ocr:
            logger.info(f'"{inp}" != "{token.ocr}" for token {token}')
            raise AssertionError(f'"{inp}" != "{token.ocr}"')

# Spot-check a single evaluation text; raises AssertionError at the first
# token whose offsets do not match the raw OCR input.
check_tokens(in_dir, 'DE/DE6/60.txt', data_test)
#check_tokens(in_dir, 'SL/SL1/40.txt', data_test)
#check_tokens(in_dir, 'SL/SL1/17.txt', data_test)

In [None]:
# Run the offset check on every evaluation text and count how many of them fail.
num_errors = 0

for key in data_test.keys():
    try:
        check_tokens(in_dir, key, data_test)
    except AssertionError:
        # check_tokens stops at the first bad token, so this counts texts, not tokens.
        logger.info(f'Error in {key}')
        num_errors = num_errors + 1

In [None]:
# Number of evaluation texts for which check_tokens raised an AssertionError.
num_errors

In [None]:
# Fraction of evaluation texts with at least one mismatching token.
num_errors/len(data_test)

In [None]:
import json

# Load the model predictions. As consumed further down, `result` maps a text key
# to a dict of {start index (as a string) -> list of per-token labels}.
with open('condensed_predictions_task1.json', 'r') as f:
    result = json.load(f)

In [None]:
# Inspect the loaded predictions.
# NOTE(review): this renders the complete dict in the cell output.
result

In [None]:
import re

def extract_icdar_output(label_str, input_tokens):
    """Turn a per-token label string into ``{"<start>:<num_tokens>": {}}`` entries.

    ``label_str`` holds one character per input token: ``'0'`` (no error),
    ``'1'`` (first token of an error) or ``'2'`` (continuation of an error).
    Every error span becomes one key made of the ``start`` attribute of its
    first token and the number of tokens in the span; each value is an empty dict.

    A span is either
      * a ``1`` followed by any number of ``2``s (the well-formed case), or
      * a run of ``2``s that does not follow a ``1``/``2`` — i.e. it follows a
        ``0`` or sits at the very start of the string. The first ``2`` of such
        a run is interpreted as a ``1``.

    Args:
        label_str: String over the alphabet ``0``/``1``/``2``.
        input_tokens: Indexable sequence of tokens with a ``start`` attribute;
            must be at least as long as ``label_str``.

    Returns:
        dict mapping ``"<start>:<num_tokens>"`` to ``{}``, in order of position.
    """
    text_output = {}

    # finditer scans left to right without overlaps, so `[12]2*` yields exactly
    # the spans described above: a `2` can only begin a match when the previous
    # character was not consumed by a span (a `0`, or nothing at all).
    for match in re.finditer(r'[12]2*', label_str):
        num_tokens = len(match.group())
        idx = input_tokens[match.start()].start
        text_output[f'{idx}:{num_tokens}'] = {}

    return text_output

#label_str = '12200010011120020222'
#output = extract_icdar_output(label_str, data['DE/DE3/1988.txt'].input_tokens)
#output

In [None]:
from collections import defaultdict

# Convert the loaded predictions into the per-text output written to JSON below.
output = {}

for key, preds in result.items():
    # Keep the try block limited to the lookup itself, so that an unrelated
    # KeyError further down is not misreported as missing data.
    try:
        text = data_test[key]
    except KeyError:
        logger.warning(f'No data found for text {key}')
        continue

    # Collect every label predicted for each token index. A token may receive
    # several labels (presumably because prediction windows overlap — confirm).
    labels = defaultdict(list)
    for start, lbls in preds.items():
        offset = int(start)  # JSON object keys are strings
        for i, label in enumerate(lbls):
            labels[offset + i].append(label)

    # One character per input token; the highest label wins (2 over 1 over 0).
    # .get() avoids adding empty entries to the defaultdict on lookup.
    label_chars = []
    for i, _ in enumerate(text.input_tokens):
        token_labels = labels.get(i, ())
        if 2 in token_labels:
            label_chars.append('2')
        elif 1 in token_labels:
            label_chars.append('1')
        else:
            label_chars.append('0')
    label_str = ''.join(label_chars)

    output[key] = extract_icdar_output(label_str, text.input_tokens)

In [None]:
# Inspect the converted output (text key -> {"<start>:<num_tokens>": {}}).
output

In [None]:
import json

# Write the converted output to the JSON file passed to the evaluation script below.
with open('results_task1_new.json', 'w') as f:
    json.dump(output, f)

In [None]:
# Run the evaluation script on the evaluation directory and the JSON written above.
# NOTE(review): the arguments appear to be <dataset dir> <results json> <output csv>
# — confirm against the script's usage text.
!python evalTool_ICDAR2017.py ../../data/ICDAR2019_POCR_competition_dataset/ICDAR2019_POCR_competition_evaluation_4M_without_Finnish results_task1_new.json results_task1.csv