In [14]:
# Standard library
import os
import sys
from glob import glob
from subprocess import check_call
from tempfile import NamedTemporaryFile as TempFile

# Third-party
import matplotlib.pyplot as plt
import numpy as np
import plotly.graph_objs as go
# `import plotly.graph_objs as go` only binds `go`; the call to
# `plotly.offline.init_notebook_mode` below needs the top-level `plotly`
# name, which this import binds (it would otherwise raise NameError on a
# freshly restarted kernel).
import plotly.offline

# Local: the ContraPro checkout must be on sys.path *before* importing
# its evaluate.py, so the order of the next two statements matters.
sys.path.insert(0, '/private/home/mattle/ContraPro')
from evaluate import count_errors, print_all_stats  # import the ContraPro evaluate.py file

# Set up plotly for notebook rendering.
plotly.offline.init_notebook_mode(connected=True)

In [15]:
def get_scores(base, ref='/private/home/mattle/ContraPro/contrapro.json', n_contexts=5):
    """Run ``count_errors`` on the eval logs in ``base``, one per context size.

    For every context size ``i`` in ``range(n_contexts)`` this reads
    ``<base>/doc-eval-logs_n_ctxt_<i>.log``, keeps the second tab-separated
    field of every line that starts with ``H-``, writes those fields (one per
    line) to a temporary file, and hands that file together with the
    reference file to ``count_errors(..., maximize=True, verbose=False)``.

    Args:
        base: Directory containing the ``doc-eval-logs_n_ctxt_<i>.log`` files.
        ref: Path of the reference file passed as the first argument of
            ``count_errors``. Defaults to the path that used to be hard-coded.
        n_contexts: Number of context sizes to evaluate, starting at 0.
            Defaults to 5 (context sizes 0-4), the previous fixed behavior.

    Returns:
        A list with one ``count_errors`` result per context size, in
        increasing order of context size.

    Raises:
        FileNotFoundError: If a log file or the reference file is missing.
            (The former ``cat | grep | cut`` shell pipeline only surfaced the
            exit status of ``cut``, so a missing log silently produced an
            empty scores file.)
    """
    all_res = []
    for i in range(n_contexts):
        logfile = os.path.join(base, f'doc-eval-logs_n_ctxt_{i}.log')
        with TempFile() as tf:
            # Pure-Python equivalent of `grep "^H-" | cut -f2`, done on bytes so
            # no decoding assumptions are made about the log, and without
            # interpolating paths into a shell command line.
            with open(logfile, 'rb') as log:
                for line in log:
                    if not line.startswith(b'H-'):
                        continue
                    fields = line.rstrip(b'\n').split(b'\t')
                    # Like `cut`, emit the whole line when it has no tab.
                    tf.write((fields[1] if len(fields) > 1 else fields[0]) + b'\n')
            # Make the buffered writes visible to the second handle opened below.
            tf.flush()
            # Context managers guarantee both handles are closed after scoring.
            with open(ref, 'r') as ref_file, open(tf.name, 'r') as scores_file:
                all_res.append(count_errors(ref_file, scores_file, maximize=True, verbose=False))
    return all_res

# Breakdown keys of a result dict that both plotting functions iterate over.
_BREAKDOWNS = ('by_category', 'by_intrasegmental', 'by_ante_distance')


def _accuracy_series(res, category):
    """Return ``(labels, accuracies)`` for one breakdown of one result.

    Args:
        res: A result mapping; ``res[category]`` maps each bucket key to a
            mapping with ``'correct'`` and ``'total'`` entries.
        category: Which breakdown of ``res`` to read.

    Returns:
        A tuple ``(labels, values)``: ``labels`` is the sorted bucket keys
        converted to ``str``; ``values`` is a numpy array holding
        ``correct / total`` for each bucket in the same order, with ``nan``
        for buckets whose ``total`` is 0 (instead of ZeroDivisionError).
    """
    buckets = res[category]
    # Sort non-string keys before string keys. For homogeneous keys this is the
    # same order as plain sorted(); it additionally avoids a TypeError should
    # the keys mix ints and strings (e.g. 0..3 and '>3' -- presumably the case
    # for the ante-distance breakdown; verify against evaluate.py).
    keys = sorted(buckets, key=lambda key: (isinstance(key, str), key))
    values = np.array([
        buckets[key]['correct'] / buckets[key]['total'] if buckets[key]['total'] else np.nan
        for key in keys
    ])
    return [str(key) for key in keys], values


def mk_plots(all_res):
    """Draw one matplotlib figure per breakdown, with one line per result.

    Args:
        all_res: Mapping from a display name to a result mapping (see
            ``_accuracy_series`` for the expected layout). The names are used
            as legend entries.
    """
    for category in _BREAKDOWNS:
        fig, ax = plt.subplots()
        for name, res in all_res.items():
            labels, values = _accuracy_series(res, category)
            ax.plot(labels, values, label=name)
        ax.legend()
        ax.set_title(category)
        # Axis labels mirror the ones set in mk_plotly.
        ax.set_xlabel(category)
        ax.set_ylabel('Accuracy')
        plt.show()


def mk_plotly(all_res):
    """Draw one plotly figure per breakdown, with one line per result.

    Args:
        all_res: Mapping from a display name to a result mapping (see
            ``_accuracy_series`` for the expected layout). The names are used
            as trace names.
    """
    for category in _BREAKDOWNS:
        fig = go.Figure()
        for name, res in all_res.items():
            labels, values = _accuracy_series(res, category)
            fig.add_trace(go.Scatter(
                x=labels,
                y=values,
                mode='lines',
                name=name,
                # namelength=-1: show the full trace name in the hover label.
                hoverlabel=dict(namelength=-1),
            ))
        fig.update_layout(
            title=category,
            xaxis_title=category,
            yaxis_title='Accuracy'
        )
        fig.show()

In [16]:
# Per-context-size scores for the fine-tuned model and the original WMT
# ensemble; list index i corresponds to the `..._n_ctxt_<i>.log` file.
finetuned = get_scores('/checkpoint/mattle/2019-08-07/finetune_wmt.lr_0.0007.en-de.ngpu4')
orig_wmt_ensemble = get_scores('/checkpoint/mattle/wmt-19-checkpoints/')

# Flatten both runs into one mapping keyed '<model>_ctxt_<i>', fine-tuned
# entries first, so the plotting helpers can draw one line per entry.
all_res = {
    f'{prefix}_ctxt_{n_ctxt}': scores
    for prefix, per_context in (
        ('finetuned', finetuned),
        ('orig_wmt', orig_wmt_ensemble),
    )
    for n_ctxt, scores in enumerate(per_context)
}

In [17]:
# Print the full statistics for the fine-tuned model at context length 4
# (index 4 of the list returned by get_scores, i.e. the `..._n_ctxt_4.log` run).
# In the output below, the total line and the three "statistics by ..." tables
# use the format `label : correct total_samples accuracy`.
print_all_stats(finetuned[4])

total : 10043 12000 0.8369166666666666

statistics by error category
it:er : 3082 4000 0.7705
it:es : 3665 4000 0.91625
it:sie : 3296 4000 0.824

statistics by intrasegmental
False : 7969 9600 0.8301041666666666 
True : 2074 2400 0.8641666666666666 

statistics by ante distance
0 : 2074 2400 0.8641666666666666 
1 : 5849 7075 0.8267137809187279 
2 : 1253 1510 0.8298013245033112 
3 : 479 573 0.8359511343804538 
>3 : 388 442 0.8778280542986425 

ante distance per pronoun pairs
ante distance 0 :
it:er 736 
it:es 872 
it:sie 792 
ante distance 1 :
it:er 2577 
it:es 1892 
it:sie 2606 
ante distance 2 :
it:er 459 
it:es 631 
it:sie 420 
ante distance 3 :
it:er 167 
it:es 274 
it:sie 132 
ante distance >3 :
it:er 61 
it:es 331 
it:sie 50 


In [None]:
# Plot accuracy by each breakdown (category, intrasegmental, ante distance):
# one plotly figure per breakdown, one line per entry of all_res
# (both models, context sizes 0-4).
mk_plotly(all_res)