In [1]:
% cd ../../

/home/mayu-ot/Documents/iparaphrasing-release


In [2]:
import os

import numpy as np
import pandas as pd
import sklearn.metrics
from IPython.display import display_html

In [3]:
def split_data_word_count(df):
    """Partition the rows of ``df`` by whether their phrases contain a '+'.

    A row is "multi" when both ``phrase1`` and ``phrase2`` contain '+', and
    "single" when neither does. Rows where only one of the two phrases
    contains '+' are placed in neither list.
    (Presumably '+' joins the words of a multi-word phrase -- confirm against
    the code that writes the result CSVs.)

    Parameters
    ----------
    df : pandas.DataFrame
        Must have string columns ``phrase1`` and ``phrase2``.

    Returns
    -------
    (single_idx, mult_idx) : tuple of list of int
        Row *positions* (0-based), suitable for ``df.iloc[...]``.
    """
    mult_idx = []
    single_idx = []

    # Enumerate positions rather than using the index labels from iterrows():
    # the lists are consumed with .iloc, which is positional, so labels would
    # select the wrong rows whenever df does not have a default RangeIndex.
    for pos, (phr1, phr2) in enumerate(zip(df['phrase1'], df['phrase2'])):
        multi1 = '+' in phr1
        multi2 = '+' in phr2

        if multi1 and multi2:
            mult_idx.append(pos)
        elif not multi1 and not multi2:
            single_idx.append(pos)

    return single_idx, mult_idx

def get_best_thresh(df, index=None):
    """Return the score threshold that maximises F1 on the selected rows.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have a ``score`` column (higher = more likely positive) and a
        ``ytrue`` column with the binary ground-truth labels.
    index : sequence of int, optional
        Row positions passed to ``df.iloc``. ``None`` uses every row.

    Returns
    -------
    float
        The threshold from ``precision_recall_curve`` with the highest F1.
    """
    items = df if index is None else df.iloc[index]

    # Arguments are passed positionally on purpose: the keyword for the scores
    # was renamed from `probas_pred` to `y_score` in newer scikit-learn, so a
    # positional call is the only spelling that works on every version.
    prec, rec, thresh = sklearn.metrics.precision_recall_curve(
        items['ytrue'], items['score'])

    # The final (precision, recall) pair is the fixed (1, 0) end point and has
    # no corresponding threshold; drop it so the arrays line up with `thresh`.
    prec, rec = prec[:-1], rec[:-1]

    # F1 is defined as 0 wherever precision + recall is 0 (or undefined),
    # computed without triggering a divide-by-zero RuntimeWarning.
    denom = prec + rec
    f1 = np.divide(2 * prec * rec, denom,
                   out=np.zeros_like(denom), where=denom > 0)

    return thresh[f1.argmax()]

def compute_scores(df, thresh, index=None):
    """Compute precision, recall and F1 of thresholded scores.

    Parameters
    ----------
    df : pandas.DataFrame
        Must have a ``score`` column and a binary ``ytrue`` column.
    thresh : float
        Decision threshold; a row is predicted positive when
        ``score >= thresh``.
    index : sequence of int, optional
        Row positions passed to ``df.iloc``. ``None`` uses every row.

    Returns
    -------
    dict
        ``{'precision': ..., 'recall': ..., 'f1': ...}`` as fractions in [0, 1].
    """
    items = df if index is None else df.iloc[index]
    ytrue = items['ytrue']

    # `>=` (not `>`) matches how precision_recall_curve defines the operating
    # point of a threshold, so a threshold tuned with get_best_thresh is
    # applied here under the same decision rule it was selected with.
    ypred = items['score'] >= thresh

    return {
        'precision': sklearn.metrics.precision_score(y_true=ytrue, y_pred=ypred),
        'recall': sklearn.metrics.recall_score(y_true=ytrue, y_pred=ypred),
        'f1': sklearn.metrics.f1_score(y_true=ytrue, y_pred=ypred),
    }
    

def get_results(base_dir):
    """Evaluate one model's saved predictions on three subsets of the test set.

    Reads ``res_val.csv`` and ``res_test.csv`` from ``base_dir``. For each
    subset (single, multi, all rows) the decision threshold is tuned on that
    subset of the validation file and then applied to the same subset of the
    test file.

    Parameters
    ----------
    base_dir : str or os.PathLike
        Directory holding the two CSV files; a trailing slash is optional.

    Returns
    -------
    (single_res, mult_res, all_res) : tuple of dict
        Each dict has the keys ``'precision'``, ``'recall'`` and ``'f1'``.
    """
    # os.path.join instead of string concatenation so that base_dir works
    # with or without a trailing separator.
    val_df = pd.read_csv(os.path.join(base_dir, 'res_val.csv'))
    test_df = pd.read_csv(os.path.join(base_dir, 'res_test.csv'))

    val_single_idx, val_mult_idx = split_data_word_count(val_df)
    test_single_idx, test_mult_idx = split_data_word_count(test_df)

    def tune_and_score(val_idx=None, test_idx=None):
        # Pick the threshold on validation rows, report metrics on test rows.
        thresh = get_best_thresh(val_df, index=val_idx)
        return compute_scores(test_df, thresh, index=test_idx)

    single_res = tune_and_score(val_single_idx, test_single_idx)
    mult_res = tune_and_score(val_mult_idx, test_mult_idx)
    all_res = tune_and_score()

    return single_res, mult_res, all_res

In [4]:
# Display name of each evaluated model -> name of its result directory
# (looked up under models/ when the table is built).
model_dir = {
    'SNN (WEA)': 'avr-None',
    'SNN+image (WEA)': 'avr-vgg',
    'SNN (FV)': 'fv+pca-None',
    'SNN+image (FV)': 'fv+pca-vgg',
    'SNN (FV+CCA)': 'fv+cca-None',
    'SNN+image (FV+CCA)': 'fv+cca-vgg',
}

# Ensemble variants are left disabled: no result directory is filled in.
# 'Ensemble (WEA)': '',
# 'Ensemble (FV)': '',
# 'Ensemble (FV+CCA)': '',

# HTML fragments for the results table. `head` opens the table and emits two
# header rows: the metric names (each spanning three columns) and, below them,
# the evaluation subset of every column. The subset order (All, Single, Multi)
# must match the order in which the row values are filled in further down.
head = '''
<style type="text/css">
.tg  {border-collapse:collapse;border-spacing:0;border:none;border-color:#ccc;}
.tg td{font-family:Arial, sans-serif;font-size:14px;padding:10px 5px;border-style:solid;border-width:0px;overflow:hidden;word-break:normal;border-color:#ccc;color:#333;background-color:#fff;}
.tg th{font-family:Arial, sans-serif;font-size:14px;font-weight:normal;padding:10px 5px;border-style:solid;border-width:0px;overflow:hidden;word-break:normal;border-color:#ccc;color:#333;background-color:#f0f0f0;}
.tg .tg-0ord{text-align:right}
.tg .tg-s6z2{text-align:center}
.tg .tg-34fq{font-weight:bold;text-align:right}
</style>
<table class="tg">
  <tr>
    <th class="tg-0ord"></th>
    <th class="tg-s6z2" colspan="3">Precision</th>
    <th class="tg-s6z2" colspan="3">Recall<br></th>
    <th class="tg-s6z2" colspan="3">F1</th>
  </tr>
  <tr>
    <th class="tg-0ord"></th>
    <th class="tg-s6z2">All</th>
    <th class="tg-s6z2">Single</th>
    <th class="tg-s6z2">Multi</th>
    <th class="tg-s6z2">All</th>
    <th class="tg-s6z2">Single</th>
    <th class="tg-s6z2">Multi</th>
    <th class="tg-s6z2">All</th>
    <th class="tg-s6z2">Single</th>
    <th class="tg-s6z2">Multi</th>
  </tr>
'''

# Closes the table opened by `head`.
bottom='''
</table>
'''

# One table row: the model name followed by nine numbers formatted to two
# decimals (3 metrics x 3 subsets).
row = '''
  <tr>
    <td class="tg-34fq">%s</td>
    <td class="tg-s6z2">%.2f<br></td>
    <td class="tg-031e">%.2f</td>
    <td class="tg-031e">%.2f</td>
    <td class="tg-s6z2">%.2f</td>
    <td class="tg-031e">%.2f</td>
    <td class="tg-031e">%.2f</td>
    <td class="tg-s6z2">%.2f</td>
    <td class="tg-031e">%.2f</td>
    <td class="tg-031e">%.2f</td>
  </tr>
'''

# Build the table body: one formatted row per model, in model_dir order.
html_rows = []
for method, base_dir in model_dir.items():
    single_res, mult_res, all_res = get_results('models/%s/' % base_dir)

    # Row values: the model name, then for each metric the scores on the
    # all / single / multi subsets, converted from fractions to percentages.
    res_vals = [method]
    for met in ('precision', 'recall', 'f1'):
        res_vals.extend(res[met] * 100 for res in (all_res, single_res, mult_res))

    html_rows.append(row % tuple(res_vals))

# Header + rows + closing tag, rendered as raw HTML in the notebook.
output_html = head + ''.join(html_rows) + bottom

display_html(output_html, raw=True)

Unnamed: 0,Precision,Precision.1,Precision.2,Recall,Recall.1,Recall.2,F1,F1.1,F1.2
SNN (WEA),77.86,83.66,74.5,84.58,75.16,88.96,81.08,79.18,81.09
SNN+image (WEA),79.47,81.01,77.26,84.56,79.35,87.06,81.94,80.17,81.86
SNN (FV),64.21,45.92,66.4,65.93,50.89,76.51,65.06,48.28,71.1
SNN+image (FV),63.49,52.62,66.86,68.2,55.62,78.01,65.76,54.08,72.01
SNN (FV+CCA),83.11,85.19,77.44,82.13,79.3,87.69,82.62,82.14,82.25
SNN+image (FV+CCA),82.51,84.52,80.28,84.19,81.85,86.82,83.34,83.16,83.43
