# my metric

In [1]:
import os
# os.environ['KMP_DUPLICATE_LIB_OK']='True'

In [2]:
import torch
import easyocr
import dataset as DS
import pandas as pd
from finetune import validation #as testing

import typing as tp
from utils import CTCLabelConverter

from nltk.metrics.distance import edit_distance
import numpy as np

In [3]:
def load_readers_ori(model_path):
    reader = easyocr.Reader(['ch_tra'])
    checkpoint = torch.load(model_path)
    reader.recognizer.load_state_dict(checkpoint)
    return reader

def get_preds(reader, image_folder):
    # t0 = time.time()
    preds = {}
    for filename in os.listdir(image_folder):
        if filename.endswith('.jpg'): #.jpg
            image_path = os.path.join(image_folder, filename)
            ocr_result = reader.readtext(image_path)
            ocr_text = ' '.join([result[1] for result in ocr_result])
            preds[filename] = ocr_text
#     print('time spent: ', time.time()-t0)
#     print(f'preds:')
#     print(preds)
    return preds

def calculate_scores(GTs, preds):
    n_correct = 0
    norm_EDs = []
    confidence_score_list = []

    for image_name, gt_label in GTs.items():
        pred_label = preds.get(image_name, '')  
        if pred_label == gt_label:
            n_correct += 1

        if len(gt_label) == 0 or len(pred_label) == 0:
            norm_ED = 0
        elif len(gt_label) > len(pred_label):
            norm_ED = 1 - edit_distance(pred_label, gt_label) / len(gt_label)
        else:
            norm_ED = 1 - edit_distance(pred_label, gt_label) / len(pred_label)

        norm_EDs.append(norm_ED)

    accuracy = n_correct / len(GTs)

    result = {
        "Accuracy": accuracy,
        "Norm_ED": np.asarray(norm_EDs).mean(),
    }
    return result

def load_readers(lang_list:tp.List[str], model_path:str=None):
    import easyocr
    def get_training_convertor(ref_converter:easyocr.utils.CTCLabelConverter):
        if isinstance(ref_converter, CTCLabelConverter):
            return ref_converter
        character = ''.join(ref_converter.character[1:])
        converter = CTCLabelConverter(character)
        converter.separator_list = ref_converter.separator_list
        converter.ignore_idx = ref_converter.ignore_idx
        converter.dict_list = ref_converter.dict_list
        converter.dict = ref_converter.dict
        return converter
    reader = easyocr.Reader(lang_list)
    if model_path:
        checkpoint = torch.load(model_path)
        reader.recognizer.load_state_dict(checkpoint)
    
    ref_converter = reader.converter
    training_converter = get_training_convertor(ref_converter)
    return reader.recognizer, training_converter, reader

In [4]:
# FT model
model_path_2 = "saved_models/epoch_6701.pth"
# reader2 = load_readers_ori(model_path_2)
model, converter, reader2 = load_readers(['ch_tra'], model_path_2)




In [5]:
# testing data
image_folder = "all_data/0_v3_test_2_1/0_v3_test_2_1/"
label_file = image_folder+"labels.csv"

# GT
df = pd.read_csv(label_file, encoding='big5')
df["SunnyGts"] = df.apply(lambda row: ' '.join(row['words'].split()), axis=1)
print(df)
GTs = {row["filename"]: row["SunnyGts"] for _, row in df.iterrows()}
print(GTs)

             filename      words   SunnyGts
0    1Tainan_1014.jpg         新化         新化
1    1Tainan_1015.jpg         阿蓮         阿蓮
2    1Tainan_1020.jpg         阿蓮         阿蓮
3    1Tainan_1038.jpg         新化         新化
4     1Tainan_104.jpg  臺灣智駕測試實驗室  臺灣智駕測試實驗室
..                ...        ...        ...
148   3Taoyuan_74.jpg       接機大廳       接機大廳
149   3Taoyuan_75.jpg  停車場分鐘免費停車  停車場分鐘免費停車
150   3Taoyuan_78.jpg       航郵中心       航郵中心
151   3Taoyuan_83.jpg       第一航廈       第一航廈
152   3Taoyuan_85.jpg       機場旅館       機場旅館

[153 rows x 3 columns]
{'1Tainan_1014.jpg': '新化', '1Tainan_1015.jpg': '阿蓮', '1Tainan_1020.jpg': '阿蓮', '1Tainan_1038.jpg': '新化', '1Tainan_104.jpg': '臺灣智駕測試實驗室', '1Tainan_1041.jpg': '台鐵沙崙站', '1Tainan_1045.jpg': '高鐵台南站', '1Tainan_1046.jpg': '台鐵沙崙站', '1Tainan_1048.jpg': '新化', '1Tainan_1055.jpg': '高鐵台南站', '1Tainan_1064.jpg': '阿蓮', '1Tainan_112.jpg': '大臺南會展中心', '1Tainan_124.jpg': '高鐵台南站', '1Tainan_136.jpg': '臺灣智駕測試實驗室', '1Tainan_145.jpg': '資安暨智慧科技研發大樓', '1Tainan_147.jpg': 

In [6]:
# pred
preds2 = get_preds(reader2, image_folder)
print(preds2)
# preds3 = get_preds(reader3, image_folder)

{'1Tainan_654.jpg': '汽車停車場', '1Tainan_871.jpg': '', '3Taoyuan_61.jpg': '', '1Tainan_781.jpg': '台鐵沙崙站', '1Tainan_672.jpg': '湖內 ', '1Tainan_328.jpg': '臨時接送區', '3Taoyuan_151.jpg': '院', '1Tainan_145.jpg': '臺', '1Tainan_18.jpg': '', '1Tainan_297.jpg': '', '1Tainan_216.jpg': '', '1Tainan_646.jpg': '沙崙站  站站', '1Tainan_147.jpg': '試', '1Tainan_378.jpg': '資智', '1Tainan_250.jpg': '快速公路', '1Tainan_589.jpg': '歸仁市區', '1Tainan_901.jpg': '', '1Tainan_837.jpg': '', '1Tainan_710.jpg': ' 大 ', '1Tainan_1020.jpg': '心', '1Tainan_689.jpg': '大南竹展中心', '3Taoyuan_157.jpg': ' 第二航廈', '1Tainan_616.jpg': '仁 德', '1Tainan_986.jpg': ' 歸仁市區', '1Tainan_821.jpg': '台鐵沙崙站', '3Taoyuan_74.jpg': '技技大', '3Taoyuan_75.jpg': '', '1Tainan_525.jpg': '臨停接送區', '1Tainan_1046.jpg': '台局', '1Tainan_279.jpg': '', '3Taoyuan_118.jpg': '航 駕 局', '1Tainan_636.jpg': '快速公路', '1Tainan_391.jpg': '中研究院南院', '1Tainan_510.jpg': '', '1Tainan_956.jpg': '', '2Penghu_27.jpg': '同和路', '1Tainan_667.jpg': '快速公路', '1Tainan_192.jpg': '高發二路', '1Tainan_277.jpg': '

In [7]:
# score
scores2 = calculate_scores(GTs, preds2)
scores2

{'Accuracy': 0.1895424836601307, 'Norm_ED': 0.3930445915740033}

# original metric from model

In [8]:
# model_path_2 = r'C:\Users\MAGIC\Ching\survey\plan\OCR\5\EasyOCR-master\finetune\saved_models\Freeze_FeatureExtraction\20231125_165812\epoch_6701.pth'
testing_set_roots = [image_folder]
character = ''.join(converter.character[1:])

DEVICE= torch.device('cuda' if torch.cuda.is_available() else 'cpu')
criterion = torch.nn.CTCLoss(zero_infinity=True).to(DEVICE)

test_loader = DS.load_dataset(*testing_set_roots, character=character)

# FT model
# model, converter = load_readers(["ch_tra"], model_path_2)
test_result2 = validation(model, criterion, converter, test_loader)
print(test_result2)

validation phase:   0%|          | 0/5 [00:00<?, ?it/s]

validation phase: 100%|██████████| 5/5 [00:03<00:00,  1.40it/s]

{'CTCLoss': 2.4746318, 'Accuracy': 0.30718954248366015, 'Norm_ED': 0.5840935027209537}





In [None]:
test_result3 = validation(model, criterion, reader2.converter, test_loader)
print(test_result3)

In [9]:

# original model
model0, converter0, reader0 = load_readers(["ch_tra"])
test_result0 = validation(model0, criterion, converter0, test_loader)
test_result0

validation phase: 100%|██████████| 5/5 [00:03<00:00,  1.53it/s]


{'CTCLoss': 9.622093,
 'Accuracy': 0.026143790849673203,
 'Norm_ED': 0.18773820587546078}