In [None]:
import json
import os
from os.path import realpath, dirname, join
import sys

import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import seaborn as sns

# Base directory; the evaluation CSV paths listed further down are joined onto it.
ROOT = './data/'
# Figures are written to <ROOT>/plots.
OUTPUT_DIR = join(ROOT, 'plots')
# Create the output directory up front; exist_ok=True keeps re-running this cell harmless.
os.makedirs(OUTPUT_DIR, exist_ok=True)

In [None]:
# Global seaborn appearance (grid style and default colour palette).
sns.set_style("whitegrid")
sns.set_palette("PuBuGn_d")

In [None]:
def extract_values(model_name, row):
    """Flatten one model's benchmark results into long-format rows.

    Parameters
    ----------
    model_name : str
        Label written to the ``Model`` column of every produced row.
    row : pandas.Series or mapping
        Maps a benchmark name to its result. A string entry is parsed as a
        pair written like ``"(0.71, 0.69)"``, representing
        (spearman, pearson); only the first (Spearman) value is kept and the
        benchmark name gets a ``"\\nspearman"`` suffix. Any other entry is
        stored unchanged as the score.

    Returns
    -------
    pandas.DataFrame
        Columns ``Model``, ``Benchmark`` and ``Score``, one row per entry of
        ``row`` (empty, but with those columns, when ``row`` is empty).

    Raises
    ------
    json.JSONDecodeError
        If a string entry cannot be parsed as a pair. The offending entry is
        printed before the error is re-raised.
    """
    # Collect plain records and build the frame once: `DataFrame.append` no
    # longer exists in pandas >= 2.0 and was quadratic when used in a loop.
    records = []
    for benchmark, entry in row.items():
        if isinstance(entry, str):
            try:
                # Tuples are stored in Python syntax; turn "(a, b)" into the
                # JSON list "[a, b]" so it can be parsed safely.
                pair = json.loads(entry.replace('(', '[').replace(')', ']'))
            except json.JSONDecodeError:
                print('For model {}, could not extract pair from entry: {}'.format(
                    model_name, entry))
                raise
            # The Pearson value (pair[1]) is intentionally not reported.
            benchmark = benchmark + '\nspearman'
            score = float(pair[0])
        else:
            score = entry
        records.append({'Model': model_name, 'Benchmark': benchmark, 'Score': score})
    return pd.DataFrame(records, columns=['Model', 'Benchmark', 'Score'])

def extract_model_name(row):
    """Return the model name stored in ``row['outputmodelname']``.

    The field may hold a list written in Python syntax (for example
    ``"['cbow-784-10p']"``), in which case its first element is returned.
    Anything else is returned unchanged: a plain name, a value that parses
    to something other than a non-empty list, or a non-string field.

    Parameters
    ----------
    row : pandas.Series or mapping
        Must provide the key ``'outputmodelname'``.

    Returns
    -------
    The first element of the parsed list, or the raw field value.
    """
    raw_name = row['outputmodelname']
    if not isinstance(raw_name, str):
        # Nothing to parse (e.g. a numeric or missing value read from the CSV).
        return raw_name
    try:
        parsed = json.loads(raw_name.replace("'", '"'))
    except ValueError as e:
        # Not a list literal: report the parse error and use the name as-is.
        print(e)
        return raw_name
    # Only unwrap genuine non-empty lists; indexing any other parsed value
    # would raise (numbers, empty lists) or return a single character (strings).
    if isinstance(parsed, list) and parsed:
        return parsed[0]
    return raw_name

def _load_evaluation_results(eval_filenames):
    """Read every evaluation CSV and return all scores as one long-format frame."""
    frames = []
    seen_in = {}  # model name -> file it was most recently read from
    for fname in eval_filenames:
        df = pd.read_csv(fname, sep=';', header=0)
        # There's one model per row
        for _, row in df.iterrows():
            # `downstream_tasks` names the columns holding benchmark results,
            # written as a Python-style list; swap quotes so JSON can parse it.
            evals = json.loads(row['downstream_tasks'].replace("'", '"'))
            model_name = extract_model_name(row)
            if model_name in seen_in:
                print('Model name "{}" seen in two files: \n- {}\n- {}'.format(
                    model_name, fname, seen_in[model_name]
                ), file=sys.stderr)
            seen_in[model_name] = fname
            frames.append(extract_values(model_name, row[evals]))
    if not frames:
        raise ValueError('No evaluation results to plot: no files or no rows were given')
    # Concatenate once; `DataFrame.append` no longer exists in pandas >= 2.0.
    return pd.concat(frames, ignore_index=True)


def plot_evaluation_results(eval_filenames, baselines=None):
    """Plot benchmark scores of all models and save them as ``comparison.pdf``.

    Two stacked bar charts are drawn: the upper one for plain benchmarks, the
    lower one for benchmarks whose name contains "spearman" or "pearson".
    The figure is written to ``OUTPUT_DIR``.

    Parameters
    ----------
    eval_filenames : iterable of str
        Semicolon-separated CSV files with one model per row. Each row needs
        an ``outputmodelname`` column and a ``downstream_tasks`` column
        listing the result columns to read.
    baselines : list of str, optional
        Joined with ``|`` into a regular expression and matched against the
        start of each model name. For every benchmark, a horizontal line
        marks the best score among the matching models.

    Raises
    ------
    ValueError
        If no rows could be read from ``eval_filenames``.
    """
    extracted = _load_evaluation_results(eval_filenames)

    fig, ax = plt.subplots(2, 1, figsize=(16, 2*7))
    is_coeff = extracted['Benchmark'].str.contains("spearman|pearson")
    for i, mask in enumerate([~is_coeff, is_coeff]):
        selected = extracted[mask]
        if selected.empty:
            # Nothing to draw for this kind of benchmark; hide the panel
            # instead of failing further down (division by zero bar count).
            ax[i].set_visible(False)
            continue
        benchmarks = sorted(selected['Benchmark'].unique())

        sns.barplot(x='Benchmark', y='Score', hue='Model', data=selected, order=benchmarks, ax=ax[i])
        ax[i].set_title('Evaluation results on several variants')
        # Restyle the existing tick labels rather than re-setting them, which
        # avoids set_xticklabels being called without fixed tick positions.
        plt.setp(ax[i].get_xticklabels(), rotation=45, horizontalalignment='right')

        if baselines is not None:
            baseline_mask = selected['Model'].str.match('|'.join(baselines))
            baseline_results = selected[baseline_mask]
            # Draw the high-line for each metric (whichever baseline is best for that metric)
            x_min, x_max = ax[i].get_xlim()
            length = (x_max - x_min) / len(benchmarks)
            for j, bench in enumerate(benchmarks):
                max_val = baseline_results[baseline_results['Benchmark'] == bench]['Score'].max()
                ax[i].plot([length*(j-0.5), length*(j+0.5)], [max_val, max_val], '-k', alpha=0.35)

    plt.tight_layout()
    plt.savefig(join(OUTPUT_DIR, 'comparison.pdf'), bbox_inches='tight', dpi=128)

In [None]:
# Experiment directories to compare; each one contributes
# './evaluation-<name>/evaluation.csv' (relative to ROOT).
experiment_names = [
    'cbow-784-10p',
    'cmow-784-10p',
    'hybrid-800-10p',

    'cnmow1-784-10p',
    'cnmow1b-784-10p',
    'cnmow2-784-10p',
    'cnmow2b-784-10p',
    'cnmow5-784-10p',
    'cnmow5-hybrid-800-10p',
    #'cnmow6-784-10p',
    'cnmow6-hybrid-800-10p',
    'cnmow8-784-10p',
    'cnmow9-784-10p',
]
sources = ['./evaluation-{}/evaluation.csv'.format(name) for name in experiment_names]

# Model-name patterns whose best score is drawn as a reference line per benchmark.
baselines = ['cbow-784-10p', 'cmow-784-10p', 'hybrid-800-10p']

eval_paths = [join(ROOT, source) for source in sources]
plot_evaluation_results(eval_paths, baselines=baselines)