In [1]:
import glob
import keras_ocr
import matplotlib.pyplot as plt

from functools import reduce
# popeval is imported via a relative path; adjust the path to match your environment.
from popeval.popEval import process, make_pair, _divide

In [2]:
# Build the keras-ocr pipeline used for all predictions below. On construction
# it looks for the pretrained weight files craft_mlt_25k.h5 and crnn_kurapan.h5
# in the user's .keras-ocr directory (see the cell output).
pipeline = keras_ocr.pipeline.Pipeline()

Looking for C:\Users\kwansu\.keras-ocr\craft_mlt_25k.h5
Looking for C:\Users\kwansu\.keras-ocr\crnn_kurapan.h5


In [10]:
# Collect the input images in sorted order so the position of each prediction
# is deterministic; later cells number their output files by that position.
image_files = sorted(glob.glob('data/img_*.png'))

# Decode every image exactly once.
images = [keras_ocr.tools.read(path) for path in image_files]

# recognize() accepts already-decoded arrays as well as file paths, so the
# arrays loaded above are reused instead of reading each file a second time.
# Each call receives a one-element batch, so prediction_groups[k] is a
# one-element list holding the predictions for image_files[k].
prediction_groups = [pipeline.recognize([image]) for image in images]

In [11]:
def create_converted_format_file(load_path, save_path):
    '''Convert a hand-written coordinate file into the evaluation text format.

    Every non-blank input line describes one detected region and its word.
    Fields are separated by tab characters only, because the word itself may
    contain spaces:

    * rectangle:  x, y, w, h, word
    * trapezoid:  '@', x1, y1, ..., x4, y4, word   (line starts with '@')

    Every output line holds eight space-separated corner coordinates followed
    by '##::' and the word. The word is lower-cased because keras-ocr only
    produces lower-case text; special characters and spaces inside the word
    are left untouched (they could be normalised here too if needed).

    Blank lines are skipped, and every output record is terminated by exactly
    one newline even when the last input line has none.

    Args:
        load_path: path of the tab-separated coordinate file to read.
        save_path: path of the converted file to write (overwritten).

    Raises:
        ValueError: if a rectangle line has fewer than four coordinate fields
            or a coordinate that is not an integer.
    '''

    with open(load_path, "r") as f:
        lines = f.readlines()

    with open(save_path, "w") as f:
        for line in lines:
            # A stray empty line (e.g. at the end of the file) is not a record.
            if not line.strip():
                continue

            # Drop only the line terminator; spaces inside the word are data.
            words = line.rstrip('\n').split('\t')
            word = words[-1].lower()

            if words[0] == '@':     # trapezoid: corners are already explicit
                # NOTE(review): this branch emits a space before '##::' while
                # the rectangle branch does not; kept as-is, confirm which
                # form the evaluator expects.
                converted = ' '.join(words[1:9]) + f' ##::{word}'
            else:                   # rectangle: expand (x, y, w, h) to 4 corners
                x, y, w, h = map(int, words[:4])
                converted = f"{x} {y} {x+w} {y} {x+w} {y+h} {x} {y+h}##::{word}"

            f.write(converted + '\n')

In [12]:
# One (hand-made annotation file, converted ground-truth file) pair per image,
# numbered from 001 in the same order as image_files.
conversion_jobs = [
    (f'data/coord_{file_no:03d}.txt', f'data/GT_{file_no:03d}.txt')
    for file_no in range(1, len(image_files) + 1)
]

for coord_path, gt_path in conversion_jobs:
    create_converted_format_file(coord_path, gt_path)

In [13]:
# Write the keras-ocr predictions to one text file per image, numbered from
# 001 in the order of prediction_groups. Each line is the four box corners
# (truncated to int, each followed by a space) and then '##::' plus the word.
for file_no, pred in enumerate(prediction_groups, start=1):
    save_path = f'data/Pred_{file_no:03d}.txt'
    with open(save_path, "w") as f:
        for img_info in pred:
            # Each prediction is a (word, box) pair; box holds the corner points.
            for word, box in img_info:
                # The corner loop uses its own names so it never rebinds the
                # file index of the outer loop.
                for corner_x, corner_y in box[:4]:
                    f.write(f'{int(corner_x)} {int(corner_y)} ')
                f.write(f'##::{word}\n')

In [14]:
def evaluate(gt_files, pred_files, dontcare_text='###'):
    '''Score prediction files against ground-truth files and print a per-file table.

    ``process`` is called once for every (ground truth, prediction) pair, in
    order and in a single process (the upstream version used multiprocessing;
    refer to it if this becomes too slow). Pairing is positional and stops at
    the shorter of the two lists.

    A pair whose processing raises is reported and left out of both the
    per-file table and the totals, so one bad file does not abort the run.

    Args:
        gt_files: ground-truth file paths.
        pred_files: prediction file paths, aligned with ``gt_files``.
        dontcare_text: marker passed through to ``process`` unchanged.

    Returns:
        Tuple ``(precision_for_char, recall_for_char, perf)`` computed from the
        character counts summed over all processed files (not the mean of the
        per-file values); ``perf`` is the harmonic mean of the first two.
    '''
    precision_list = []
    recall_list = []

    total_removed_gt_char_count = 0
    total_pred_char_count = 0
    total_gt_chars_count = 0

    for gt_file, pred_file in zip(gt_files, pred_files):
        # process() must run inside the try block: with a lazy map() in the
        # for-statement, an exception raised by process() would surface outside
        # the handler and never be caught.
        try:
            (precision, recall, removed_gt_char_count,
             pred_char_count, gt_char_count) = process(gt_file, pred_file, dontcare_text)
        except Exception as e:
            print(f'skipped ({gt_file}, {pred_file}): {e}')
            continue

        total_removed_gt_char_count += removed_gt_char_count
        total_pred_char_count += pred_char_count
        total_gt_chars_count += gt_char_count
        precision_list.append(precision)
        recall_list.append(recall)

    # NOTE(review): _divide presumably guards against a zero denominator;
    # confirm in popeval.
    precision_for_char = _divide(
        float(total_removed_gt_char_count), float(total_pred_char_count))
    recall_for_char = _divide(
        float(total_removed_gt_char_count), float(total_gt_chars_count))
    perf = _divide(2*(precision_for_char*recall_for_char),
                   (precision_for_char + recall_for_char))

    print(' num | precision |  recall  |')
    for i, (precision, recall) in enumerate(zip(precision_list, recall_list), start=1):
        print(f' {i:03d} | {precision:f}  | {recall:f} |')
    print("======================")
    return precision_for_char, recall_for_char, perf


In [15]:
# Gather the converted ground-truth files and the prediction files in sorted
# order so that they pair up positionally.
GT_files = sorted(glob.glob('data/GT_*.txt'))
D_files = sorted(glob.glob('data/Pred_*.txt'))

# When the two sets differ in size, let make_pair pick the files to evaluate.
gt_count, pred_count = len(GT_files), len(D_files)
if gt_count != pred_count:
    print("Caution: GT_files' len(%d) and D_files' len(%d) are different."%(gt_count, pred_count))
    GT_files, D_files = make_pair(GT_files, D_files)
    print("We will evaluate on %d files"%(len(GT_files)))

# Report precision, recall and their harmonic mean as percentages.
pr, re, pref = evaluate(GT_files, D_files)
print("precision, recall, H:")
print("%0.1f, %0.1f, %0.1f" % tuple(100.*score for score in (pr, re, pref)))

 num | precision |  recall  |
 001 | 0.989130  | 0.968085 |
 002 | 0.957447  | 0.900000 |
precision, recall, H:
97.8, 94.4, 96.1
