In [45]:
import pandas
import os 

In [47]:
import xlwt 
from xlwt import Workbook 

In [48]:
from tqdm.autonotebook import tqdm

In [49]:
def dump_xls(dic, output_filename, sheet="Sheet 1"):
    """Write a dictionary of examples to a single-sheet ``.xls`` workbook.

    Every entry becomes one row: the key is written to column 0 and the
    elements of its value are written to columns 1, 2, ... in order.

    Args:
        dic: Mapping from a row label to a sequence of cell values. The
            callers in this notebook pass 3-tuples of
            ``(correct output, model-1 prediction, model-2 prediction)``.
        output_filename: Path handed to ``Workbook.save``.
        sheet: Name of the worksheet that is created.
    """
    # NOTE(review): the legacy .xls format caps a sheet at 65536 rows --
    # confirm the dictionaries dumped here stay below that.
    wb = Workbook()
    sheet1 = wb.add_sheet(sheet)
    for row, (key, values) in enumerate(dic.items()):
        sheet1.write(row, 0, key)
        # Write every element instead of a hard-coded first three, so values
        # of any length are dumped completely.
        for col, value in enumerate(values, start=1):
            sheet1.write(row, col, value)
    wb.save(output_filename)

In [51]:
def get_frame_dict(root, n_inter):
    """Load one dumped prediction CSV as a column-name -> list-of-values dict.

    Args:
        root: Directory containing the dump files.
        n_inter: Value interpolated into the file name
            ``Dumped_csv_<n_inter>.csv``.

    Returns:
        The CSV contents as produced by ``DataFrame.to_dict(orient="list")``.
    """
    csv_path = os.path.join(root, "Dumped_csv_" + str(n_inter) + ".csv")
    return pandas.read_csv(csv_path).to_dict(orient="list")

def example_output_map(frame_dict):
    """Index one model's results by input sentence.

    Args:
        frame_dict: Column dict with the keys ``"Input sentence"``,
            ``"Correct output"`` and ``"predicted_output"``.

    Returns:
        Dict mapping each input sentence to ``(correct output, predicted
        output)``. When a sentence occurs more than once, the last
        occurrence wins.
    """
    sentences = frame_dict["Input sentence"]
    return {
        sentence: (
            frame_dict["Correct output"][idx],
            frame_dict["predicted_output"][idx],
        )
        for idx, sentence in enumerate(sentences)
    }

def get_all_stats(comm):
    """Split the shared examples by which of the two models got them right.

    Args:
        comm: Mapping from an input sentence to a two-element sequence
            ``[(correct, model-1 prediction), (correct, model-2 prediction)]``.

    Returns:
        A tuple ``(both_wrong, model1_corr, model2_corr)`` of dicts, each
        mapping a sentence to ``(corr, pred_m1, pred_m2)``:

        * ``both_wrong``  -- neither prediction equals the correct output,
        * ``model1_corr`` -- only model 1 is correct,
        * ``model2_corr`` -- only model 2 is correct.

        Examples that both models get right are not returned.
    """
    both_wrong = {}
    model1_corr = {}
    model2_corr = {}
    # The correct output is taken from the second (model-2) entry; the first
    # entry's label is ignored, exactly as the original double binding of
    # ``corr`` did.
    for key, ((_, pred_m1), (corr, pred_m2)) in comm.items():
        m1_right = pred_m1 == corr
        m2_right = pred_m2 == corr
        if not m1_right and not m2_right:
            # Both models are wrong even when their wrong answers differ;
            # requiring pred_m1 == pred_m2 here dropped those examples.
            both_wrong[key] = (corr, pred_m1, pred_m2)
        elif m1_right and not m2_right:
            model1_corr[key] = (corr, pred_m1, pred_m2)
        elif m2_right and not m1_right:
            model2_corr[key] = (corr, pred_m1, pred_m2)

    return both_wrong, model1_corr, model2_corr

def get_common_examples(frame_dict, mapping):
    """Pair up the results of two models on the sentences they share.

    Args:
        frame_dict: Column dict of the second model, with the keys
            ``"Input sentence"``, ``"Correct output"`` and
            ``"predicted_output"``.
        mapping: Dict from input sentence to the first model's
            ``(correct output, predicted output)`` pair.

    Returns:
        Dict mapping every sentence of ``frame_dict`` that is also a key of
        ``mapping`` to ``[mapping[sentence], (correct output, predicted
        output)]``, the second pair coming from ``frame_dict``.
    """
    shared = {}
    for idx, sentence in enumerate(frame_dict["Input sentence"]):
        if sentence not in mapping:
            continue
        first_pair = mapping[sentence]
        second_pair = (
            frame_dict["Correct output"][idx],
            frame_dict["predicted_output"][idx],
        )
        shared[sentence] = [first_pair, second_pair]
    return shared

In [62]:
# Directories holding each model's "Dumped_csv_<n>.csv" prediction dumps.
# Plain directory names are used because os.path.join supplies the separator;
# a hard-coded trailing backslash only yields valid paths on Windows.
root1 = "decay"
root2 = "eirnn_seq"
root3 = "gru"
# Model labels, each paired with the root directory noted beside it.
model1 = "decay"  # root1
model2 = "eirnn"  # root2
model3 = "gru"  # root3

In [69]:
def compare_models(model1, model2, root1, root2, dump_root, n_inter_range=range(8)):
    """Compare two models' dumped predictions and export their disagreements.

    For every value ``i`` in ``n_inter_range`` the file ``Dumped_csv_<i>.csv``
    is read from both roots, the sentences present in both are classified with
    ``get_all_stats``, the size of each category is printed, and each category
    is saved as ``<model1>_<model2>_<i>_<category>.xls`` inside ``dump_root``.

    Args:
        model1: Label of the first model, used in printed lines and file names.
        model2: Label of the second model, used in printed lines and file names.
        root1: Directory with the first model's CSV dumps.
        root2: Directory with the second model's CSV dumps.
        dump_root: Directory that receives the ``.xls`` workbooks; it is
            created when it does not exist yet.
        n_inter_range: Iterable of ``n_inter`` values to process. Defaults to
            ``range(8)``, the values the original loop covered.
    """
    # Make sure the target directory exists so saving works on a fresh
    # checkout. An empty dump_root means "current directory" and needs no
    # creation (os.makedirs("") would raise).
    if dump_root:
        os.makedirs(dump_root, exist_ok=True)

    for i in tqdm(n_inter_range):
        mapping = example_output_map(get_frame_dict(root1, n_inter=i))
        comm = get_common_examples(get_frame_dict(root2, n_inter=i), mapping)
        both_wrong, model1_corr, model2_corr = get_all_stats(comm)

        # (file/print stem, examples) for each category, in reporting order.
        results = [
            ("{}_{}_{}_{}".format(model1, model2, i, label), examples)
            for label, examples in (
                ("both_wrong", both_wrong),
                ("model1_corr", model1_corr),
                ("model2_corr", model2_corr),
            )
        ]

        # Report all counts first, then write the workbooks.
        for stem, examples in results:
            print("{} {}".format(stem, len(examples)))
        for stem, examples in results:
            dump_xls(examples, os.path.join(dump_root, stem + ".xls"))

In [70]:
# Compare decay against eirnn; workbooks are written under "decay_vs_eirnn".
compare_models(
    model1="decay",
    model2="eirnn",
    root1=root1,
    root2=root2,
    dump_root="decay_vs_eirnn",
)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

decay_eirnn_0_both_wrong 3504
decay_eirnn_0_model1_corr 6271
decay_eirnn_0_model2_corr 5333
decay_eirnn_1_both_wrong 2186
decay_eirnn_1_model1_corr 3624
decay_eirnn_1_model2_corr 1982
decay_eirnn_2_both_wrong 1495
decay_eirnn_2_model1_corr 2283
decay_eirnn_2_model2_corr 1162
decay_eirnn_3_both_wrong 822
decay_eirnn_3_model1_corr 1201
decay_eirnn_3_model2_corr 548
decay_eirnn_4_both_wrong 501
decay_eirnn_4_model1_corr 562
decay_eirnn_4_model2_corr 260
decay_eirnn_5_both_wrong 273
decay_eirnn_5_model1_corr 302
decay_eirnn_5_model2_corr 115
decay_eirnn_6_both_wrong 141
decay_eirnn_6_model1_corr 119
decay_eirnn_6_model2_corr 60
decay_eirnn_7_both_wrong 101
decay_eirnn_7_model1_corr 72
decay_eirnn_7_model2_corr 34



In [71]:
# Compare decay against gru; workbooks are written under "decay_vs_gru".
compare_models(
    model1="decay",
    model2="gru",
    root1=root1,
    root2=root3,
    dump_root="decay_vs_gru",
)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

decay_gru_0_both_wrong 3153
decay_gru_0_model1_corr 5445
decay_gru_0_model2_corr 5684
decay_gru_1_both_wrong 1550
decay_gru_1_model1_corr 1965
decay_gru_1_model2_corr 2618
decay_gru_2_both_wrong 1031
decay_gru_2_model1_corr 1152
decay_gru_2_model2_corr 1626
decay_gru_3_both_wrong 531
decay_gru_3_model1_corr 550
decay_gru_3_model2_corr 839
decay_gru_4_both_wrong 321
decay_gru_4_model1_corr 268
decay_gru_4_model2_corr 440
decay_gru_5_both_wrong 166
decay_gru_5_model1_corr 135
decay_gru_5_model2_corr 222
decay_gru_6_both_wrong 95
decay_gru_6_model1_corr 60
decay_gru_6_model2_corr 106
decay_gru_7_both_wrong 62
decay_gru_7_model1_corr 36
decay_gru_7_model2_corr 73



In [72]:
# Compare eirnn against gru; workbooks are written under "eirnn_vs_gru".
compare_models(
    model1="eirnn",
    model2="gru",
    root1=root2,
    root2=root3,
    dump_root="eirnn_vs_gru",
)

HBox(children=(IntProgress(value=0, max=8), HTML(value='')))

eirnn_gru_0_both_wrong 2977
eirnn_gru_0_model1_corr 5621
eirnn_gru_0_model2_corr 6798
eirnn_gru_1_both_wrong 1756
eirnn_gru_1_model1_corr 1759
eirnn_gru_1_model2_corr 4054
eirnn_gru_2_both_wrong 1160
eirnn_gru_2_model1_corr 1023
eirnn_gru_2_model2_corr 2618
eirnn_gru_3_both_wrong 637
eirnn_gru_3_model1_corr 444
eirnn_gru_3_model2_corr 1386
eirnn_gru_4_both_wrong 365
eirnn_gru_4_model1_corr 224
eirnn_gru_4_model2_corr 698
eirnn_gru_5_both_wrong 205
eirnn_gru_5_model1_corr 96
eirnn_gru_5_model2_corr 370
eirnn_gru_6_both_wrong 117
eirnn_gru_6_model1_corr 38
eirnn_gru_6_model2_corr 143
eirnn_gru_7_both_wrong 80
eirnn_gru_7_model1_corr 18
eirnn_gru_7_model2_corr 93

