In [None]:
import pandas as pd
import pickle
import numpy as np
from scipy.stats import pearsonr
from tqdm.notebook import tqdm, tqdm_notebook
import json
import seaborn as sns
import os
import matplotlib.pyplot as plt
import itertools
sns.set()

### Definitions

In [None]:
def cache(path, f):
    """Memoize the JSON-serializable result of ``f()`` in a file at ``path``.

    If ``path`` does not exist, ``f()`` is called, its result is written to
    ``path`` as JSON and returned. Otherwise the stored JSON is loaded and
    returned without calling ``f``.

    Caveats:
      * A fresh computation returns ``f()``'s value unchanged, while a cache
        hit returns the JSON round-trip (e.g. tuples come back as lists).
      * The cache key is only ``path``; changing the inputs of ``f`` does not
        invalidate an existing file.

    Args:
        path: file path of the cache entry. Its parent directory must exist.
        f: zero-argument callable returning a JSON-serializable value.

    Returns:
        The freshly computed or previously cached value.
    """
    if os.path.exists(path):
        with open(path) as fh:
            return json.load(fh)

    result = f()
    # Write to a sibling temp file and atomically move it into place, so a
    # failure inside json.dump never leaves a truncated file at `path` that
    # would make every later json.load fail.
    tmp_path = os.fspath(path) + '.tmp'
    with open(tmp_path, 'wt') as fh:
        json.dump(result, fh)
    os.replace(tmp_path, path)
    return result

In [None]:
def generate(xs, df, true_means, trials=1000, group_name='system', score_name='score', ratio=-1, seed=None):
    """Bootstrap how often system rankings at a given sample size agree with a reference ranking.

    For each sample size ``x`` in ``xs`` (scaled to ``int(x * ratio)`` when
    ``ratio > 1``), run ``trials`` simulations. In each simulation, every group
    of ``df`` is resampled with replacement to ``x`` rows and the mean of
    ``score_name`` is taken. The trial's accuracy is the fraction of unordered
    group pairs ``(i, j)`` whose sign of mean difference equals the sign of
    ``true_means[i] - true_means[j]``.

    Args:
        xs: iterable of per-group sample sizes.
        df: DataFrame with at least the ``group_name`` and ``score_name`` columns.
            All groups should have a comparable number of rows.
        true_means: reference mean for each group, listed in the sorted key
            order used by ``df.groupby(group_name)``.
        trials: number of bootstrap simulations per sample size.
        group_name: column identifying a system.
        score_name: column holding the per-row score to average.
        ratio: when > 1, each ``x`` is scaled to ``int(x * ratio)``; otherwise ignored.
        seed: optional seed for the resampling. ``None`` (default) uses numpy's
            global random state, as before.

    Returns:
        One ``(mean, median, 5th percentile, 95th percentile)`` tuple of the
        per-trial accuracies for each entry of ``xs``.

    Raises:
        ValueError: if ``df`` has fewer than two groups, or ``true_means`` does
            not have exactly one entry per group.
    """
    rng = np.random.RandomState(seed) if seed is not None else None
    # Resample only the score column. Averaging the whole group frame would
    # also try to average non-numeric columns such as the group key.
    grouped = df.groupby(group_name)[score_name]
    n_groups = grouped.ngroups
    true_means = np.asarray(true_means, dtype=float)
    if n_groups < 2:
        raise ValueError('need at least two groups to compare, got %d' % n_groups)
    if len(true_means) != n_groups:
        raise ValueError('true_means has %d entries but df has %d groups'
                         % (len(true_means), n_groups))

    # All unordered pairs (i, j) with i < j. These are loop-invariant, so
    # they are built once together with the reference signs.
    left, right = np.array(list(itertools.combinations(range(n_groups), 2))).T
    true_signs = np.sign(true_means[left] - true_means[right])

    ys = []
    for x in tqdm(xs):
        n = int(x * ratio) if ratio > 1 else x
        accuracies = []
        for _ in range(trials):
            sample_means = grouped.apply(
                lambda s: s.sample(n=n, replace=True, random_state=rng).mean()
            ).to_numpy()
            # An empty sample gives a NaN mean, whose sign never compares equal,
            # so such pairs count as incorrect.
            matches = np.sign(sample_means[left] - sample_means[right]) == true_signs
            accuracies.append(matches.mean())
        # np.percentile expects q in [0, 100]: 5 and 95 are the 5th/95th
        # percentiles. The former 0.05/0.95 picked values near the minimum.
        ys.append((np.mean(accuracies), np.median(accuracies),
                   np.percentile(accuracies, 5), np.percentile(accuracies, 95)))

    return ys

# Analysis

### Human scores

In [None]:
# NOTE: pickle.load can execute arbitrary code; only load this file from a trusted source.
WMT19_PICKLE = '../wmt16-19-metrics-shared-task/wmt_metadata/pickles/wmt19_sys_metadata.pkl'
with open(WMT19_PICKLE, 'rb') as fh:
    wmt19 = pickle.load(fh)

# Rows per language pair (presumably a DataFrame with lp/type/system/score columns -- verify).
wmt19.lp.value_counts()

In [None]:
# Variance estimates keyed by dataset name (e.g. 'wmt19'); presumably into-English pairs only.
with open('./data/vars/pooled_vars_toen.json') as fh:
    pooled_vars_toen = json.load(fh)
with open('./data/vars/total_vars_toen.json') as fh:
    total_vars_toen = json.load(fh)

In [None]:
pooled_var, total_var = pooled_vars_toen['wmt19'], total_vars_toen['wmt19']
# Variance remaining after subtracting the pooled component.
true_var = total_var - pooled_var
# A non-positive remainder would give a division by zero or a negative ratio,
# which in turn yields negative sample sizes downstream.
if true_var <= 0:
    raise ValueError('total variance (%r) must exceed pooled variance (%r)' % (total_var, pooled_var))
ratio = total_var / true_var

# Simulating with `ratio` times more data per system gives the "theoretical" curve.
print(ratio)

### Metric scores

In [None]:
metric_scores = pd.read_csv('./indexes.tsv', sep='\t', index_col=[0])


def read_scores(path):
    """Return the floats in ``path``, one value per line."""
    with open(path) as fh:
        return [float(line.strip()) for line in fh]


def read_bert_scores(path):
    """Return the third tab-separated field of each line of ``path`` as floats, skipping the header line."""
    with open(path) as fh:
        next(fh, None)  # skip the header line
        return [float(line.split('\t')[2]) for line in fh]


# Each score file needs one value per row of indexes.tsv; pandas raises on a length mismatch.
metric_scores['sentbleu'] = read_scores('./scores/sentbleu_scores')
metric_scores['bleurt'] = read_scores('./scores/bleurt-base-128_scores')
metric_scores['bert_score'] = read_bert_scores('./scores/score_bert-score')

### WMT2019 (*-en)

In [None]:
TRIALS = 1000  # bootstrap simulations per sample size
# NOTE: cache files are keyed only by language pair and curve name, not by
# TRIALS, `ratio`, or the code of `generate`. Delete cache/pairwise/ after
# changing any of them, or stale results will be reused.
all_ys = {}  # lp -> {curve name -> per-sample-size (mean, median, p5, p95)}

os.makedirs('figs/pairwise', exist_ok=True)

for lp in wmt19.lp.unique():
    # Only into-English pairs (language pair code ending in 'en').
    if not lp.endswith('en'):
        continue

    print(lp)
    df_lp = wmt19[(wmt19.lp == lp) & (wmt19.type.isin(['SYSTEM', 'REPEAT']))]
    # Mean human score per system, in sorted system order (the order groupby yields).
    true = df_lp.groupby('system')['score'].mean().tolist()

    # 10 evenly spaced sample sizes from (smallest per-system count / ratio)
    # down toward 0, with 0 itself excluded.
    min_count = df_lp.groupby('system')['score'].count().min()
    xs = np.linspace(min_count / ratio, 0, 10, endpoint=False).astype(int)

    cache_dir = 'cache/pairwise/%s' % lp
    os.makedirs(cache_dir, exist_ok=True)

    human_df = df_lp[['system', 'score']]
    ys = {}
    ys['human'] = cache('%s/y_human.json' % cache_dir,
                        lambda: generate(xs, human_df, true, trials=TRIALS))
    ys['theoretical'] = cache('%s/y_theoretical.json' % cache_dir,
                              lambda: generate(xs, human_df, true, trials=TRIALS, ratio=ratio))

    # Metric curves are scored against the human ranking `true`; this assumes
    # metric_scores_lp contains the same set of systems as df_lp -- TODO confirm.
    metric_scores_lp = metric_scores[metric_scores.lp == lp]
    for metric in ['bleurt', 'sentbleu', 'bert_score']:
        ys[metric] = cache(
            '%s/y_%s.json' % (cache_dir, metric),
            lambda m=metric: generate(xs, metric_scores_lp[['system', m]], true,
                                      score_name=m, trials=TRIALS),
        )
    all_ys[lp] = ys

    fig, ax = plt.subplots()
    for name, stats in ys.items():
        # stats[i][0] is the mean accuracy at sample size xs[i]. For the
        # "theoretical" curve, int(xs[i] * ratio) samples were actually drawn.
        sns.lineplot(x=xs, y=[s[0] for s in stats], label=name, ax=ax)
    ax.set(title=lp, xlabel='samples per system', ylabel='mean pairwise ranking accuracy')
    fig.savefig('figs/pairwise/%s.png' % lp)