In [106]:
from src.inference import evaluate_model
import pandas as pd
import numpy as np 
from mltu.utils.text_utils import ctc_decoder, get_cer, get_wer 


In [53]:

def interpret(model_path):
    cer_val, wer_val = evaluate_model(
    model_path=model_path, 
    datapath= model_path + '/val.csv' 
    )
    cer_test, wer_test = evaluate_model(
    model_path=model_path,
    datapath='data/testset.npy'
    )
    print("\t\tCER\t\tWER")
    #round up to 4 decimal places 
    cer_val = round(cer_val, 4)
    wer_val = round(wer_val, 4)
    cer_test = round(cer_test, 4)
    wer_test = round(wer_test, 4)
    
    print(f"Validation\t{cer_val}\t\t{wer_val}")
    print(f"Test\t\t{cer_test}\t\t{wer_test}")
    
    return (cer_val, wer_val, cer_test, wer_test)

### HTR NET  without the custom dataset

In [54]:
htr_net_no_cus = interpret('results/htr_net_no_custom_new_split')


100%|██████████| 900/900 [00:10<00:00, 87.59it/s]
100%|██████████| 247/247 [00:02<00:00, 99.47it/s] 

		CER		WER
Validation	0.152		0.29
Test		0.3601		0.6073





### HTR NET with the custom dataset

In [58]:
htr_net_cus = interpret('results/htr_full_data')

100%|██████████| 1128/1128 [00:12<00:00, 90.37it/s]
100%|██████████| 247/247 [00:02<00:00, 98.91it/s]

		CER		WER
Validation	0.091		0.2004
Test		0.2021		0.3846





### CNN-BILSTM from mltu WITHOUT custom dataset

In [96]:
cnn_bilstm_no_cust = interpret('results/cnn_bilstm_no_custom')


100%|██████████| 900/900 [00:05<00:00, 159.30it/s]
100%|██████████| 247/247 [00:01<00:00, 186.07it/s]

		CER		WER
Validation	0.1581		0.3033
Test		0.3559		0.5506





### CNN-BILSTM from mltu WITH custom dataset

In [64]:
cnn_bilstm_cust = interpret('results/cnn_bilstm_mltu_all_data')

100%|██████████| 1128/1128 [00:06<00:00, 169.75it/s]
100%|██████████| 247/247 [00:01<00:00, 184.24it/s]

		CER		WER
Validation	0.0948		0.195
Test		0.1761		0.3279





### Summary

In [189]:
print("\t\t\t---------------- CER ----------------")
print("Custom dataset \t\t WITHOUT \t\t WITH")
print("\t\t\t Val \t Test \t\t Val \t Test") 
print(f"CNN-BiLSTM-MLTU  \t{cnn_bilstm_no_cust[0]} \t {cnn_bilstm_no_cust[2]} \t {cnn_bilstm_cust[0]}  {cnn_bilstm_cust[2]}")
print(f"HTR-Net \t\t {htr_net_no_cus[0]} \t {htr_net_no_cus[2]} \t {htr_net_cus[0]} \t {htr_net_cus[2]}")
print("\t\t\t---------------- WER ----------------")
print("Custom dataset \t\t WITHOUT \t\t WITH")
print("\t\t\t Val \t Test \t\t Val \t Test")
print(f"CNN-BiLSTM-MLTU  \t{cnn_bilstm_no_cust[1]} \t {cnn_bilstm_no_cust[3]} \t {cnn_bilstm_cust[1]}\t {cnn_bilstm_cust[3]}")
print(f"HTR-Net \t\t{htr_net_no_cus[1]} \t {htr_net_no_cus[3]} \t {htr_net_cus[1]}  {htr_net_cus[3]}")


			---------------- CER ----------------
Custom dataset 		 WITHOUT 		 WITH
			 Val 	 Test 		 Val 	 Test
CNN-BiLSTM-MLTU  	0.1581 	 0.3559 	 0.0948  0.1761
HTR-Net 		 0.152 	 0.3601 	 0.091 	 0.2021
			---------------- WER ----------------
Custom dataset 		 WITHOUT 		 WITH
			 Val 	 Test 		 Val 	 Test
CNN-BiLSTM-MLTU  	0.3033 	 0.5506 	 0.195	 0.3279
HTR-Net 		0.29 	 0.6073 	 0.2004  0.3846


In [199]:
provided_preds_df = pd.read_csv('data/raw/chess_reader_data/prediciton.csv')
# open test set 
def get_perf_of_third_party(third_party_col): 
    test_set = np.load('data/testset.npy', allow_pickle=True) 
    # preds = provided_preds_df[third_party_col].values
    
    avg_cer = 0 
   
    avg_wer = 0 
    for (img_file_name,label) in test_set : 
        image_name = img_file_name.split('/')[-1] 
        # remove .png 
        image_id = image_name.split('.')[0] 
        # find rows that have image_name as their first column
        rows = provided_preds_df.loc[provided_preds_df["id"] == int(image_id)] 
        pred = str(rows[third_party_col].values[0]) 
        pred = pred if pred != 'nan' else '' 
        avg_cer += get_cer(label, pred)
        avg_wer += get_wer(pred, label) 
    avg_cer= round(avg_cer / len(test_set), 4)
    avg_wer = round(avg_wer / len(test_set), 4)
    return avg_cer,avg_wer


In [200]:
cer_google, wer_google = get_perf_of_third_party('gl')
cer_azure, wer_azure = get_perf_of_third_party('az') 
cer_abby, wer_azure = get_perf_of_third_party('ab')



In [213]:
print("---------------- Third Party OCR on Test Set ----------------")
print("\t\t  CER \t\t WER") 
print(f"Google\t\t  {cer_google} \t {wer_google}")
print(f"Azure\t\t  {cer_azure} \t {wer_azure}")
print(f"Abbyy\t\t  {cer_abby} \t {wer_azure}")
print("----------------- Our Models on Test Set ----------------")
print("\t\t  CER \t\t WER") 
print(f"CNN-BiLSTM-MLTU\t  {cnn_bilstm_cust[2]} \t {cnn_bilstm_cust[3]}")
print(f"HTR-Net\t\t  {htr_net_cus[2]} \t {htr_net_cus[3]}")


---------------- Third Party OCR on Test Set ----------------
		  CER 		 WER
Google		  0.2541 	 0.4413
Azure		  0.2877 	 0.3765
Abbyy		  0.2548 	 0.3765
----------------- Our Models on Test Set ----------------
		  CER 		 WER
CNN-BiLSTM-MLTU	  0.1761 	 0.3279
HTR-Net		  0.2021 	 0.3846
