In [None]:
import matplotlib.pyplot as plt
from matplotlib.pyplot import cm
import os
import pandas as pd
import json

In [None]:
# Directory holding one result file per (experiment, meta-test dataset) pair.
DIR_NAME = 'metatest_outputs'
# List the directory once and keep the names in a stable, sorted order.
files = os.listdir(DIR_NAME)
files.sort()

In [None]:
#files

In [None]:
# Group the result files by experiment number.
# A file name is parsed as "<exp_nr>_<testset>.<ext>": the text before the first
# underscore is the experiment number, and the second underscore-separated
# segment, cut at its first dot, is the meta-test dataset name.
# `experiments` maps exp_nr -> list of (testset, filename) tuples.
experiments = {}
for f in files:
    # Skip hidden entries (e.g. .DS_Store, .ipynb_checkpoints) and names that do
    # not follow the scheme; they would raise IndexError below or break the
    # JSON loading in later cells.
    if f.startswith('.') or '_' not in f:
        continue
    name_parts = f.split('_')
    exp_nr, testset = name_parts[0], name_parts[1].split('.')[0]
    experiments.setdefault(exp_nr, []).append((testset, f))

In [None]:
#experiments

In [None]:
# Maps an experiment id to a descriptor string with two groups of names:
#     "<name>_<name>_...-<name>_<name>_..."
# The descriptor is later split on '-' and each half on '_'; the first half is
# stored under "adapters" and the second under "metatrain".
meta_train_filenames = {
    '1': 'boolq_cb_csqa_imdb_mrpc-mnli_qqp_sst_wgrande',
    '3': 'cqa_csqa_hswag_siqa_wgrande-argument_imdb_mnli_mrpc_scitail',
    '4': 'cb_mnli_rte_scitail_sick-argument_imdb_mrpc_qqp',
    '5': 'boolq_mnli_qqp_sst_wgrande-argument_imdb_mrpc_scitail',
    '6': 'argument_boolq_cb_cqa_csqa_hswag_imdb_mnli_mrpc_qqp_rte_scitail_sick_siqa_sst_wgrande-argument_imdb_mrpc_scitail',
    '7': 'argument_boolq_cb_cqa_csqa_hswag_imdb_mnli_mrpc_qqp_rte_scitail_sick_siqa_sst_wgrande-argument_imdb_mnli_mrpc_scitail',
    '8': 'argument_boolq_cb_cqa_csqa_hswag_imdb_mnli_mrpc_qqp_rte_scitail_sick_siqa_sst_wgrande-argument_imdb_mrpc_qqp',
    '9': 'argument_boolq_cb_cqa_csqa_hswag_imdb_mnli_mrpc_qqp_rte_scitail_sick_siqa_sst_wgrande-mnli_qqp_sst_wgrande',
    'A': 'mnli_scitail-rte_sick',
    'B': 'mnli_sick-cb_rte',
    'C': 'cb_rte-mnli_scitail'
}

In [None]:
# Turn every descriptor string into a structured setup: the part before '-'
# lists the adapters, the part after it lists the meta-training tasks.
exp_setup = {
    exp_key: {
        "adapters": descriptor.split('-')[0].split('_'),
        "metatrain": descriptor.split('-')[1].split('_'),
    }
    for exp_key, descriptor in meta_train_filenames.items()
}

In [None]:
# Persist the experiment setup as JSON.
# Create the output directory first so a fresh checkout (without results/)
# does not fail with FileNotFoundError here or in the later export cells.
os.makedirs("results", exist_ok=True)
with open("results/exp_setup.json", "w") as f:
    json.dump(exp_setup, f)

In [None]:
# Flatten every result file into one record per K value.
# Each file is a JSON object mapping K to a pair whose first element is stored
# as "avg" and whose second element is stored as "std".
records = []
for exp_nr, exp_list in experiments.items():
    for dataset, filename in exp_list:
        with open(os.path.join(DIR_NAME, filename), "r") as f:
            scores_by_k = json.load(f)
        records.extend(
            {
                "exp_nr": exp_nr,
                # First character identifies the experiment ...
                "exp_id": exp_nr[0],
                # ... the second one (if any) the variant; a bare id means 'a'.
                "exp_variant": exp_nr[1] if len(exp_nr) != 1 else 'a',
                "metatest": dataset,
                "K": shots,
                "avg": score[0],
                "std": score[1],
            }
            for shots, score in scores_by_k.items()
        )

In [None]:
# Collect all records into a single long-format table and display it.
results = pd.DataFrame(records)
results

In [None]:
# Export the long-format table; the row index carries no information, so omit it.
results_csv_path = 'results/metatest_results.csv'
results.to_csv(results_csv_path, index=False)

In [None]:
# Nested lookup used by the text tables:
#     (exp_id, dataset) -> {variant: parsed JSON content of the result file}
results_dict = {}
for exp_nr, exp_list in experiments.items():
    exp_id = exp_nr[0]
    # A bare experiment id (single character) counts as variant 'a'.
    exp_variant = exp_nr[1] if len(exp_nr) != 1 else 'a'
    for dataset, filename in exp_list:
        with open(os.path.join(DIR_NAME, filename), "r") as f:
            r = json.load(f)
        results_dict.setdefault((exp_id, dataset), {})[exp_variant] = r

In [None]:
# Distinct meta-test datasets and experiment ids, in order of first appearance.
datasets = pd.unique(results['metatest'])
exp_ids = pd.unique(results['exp_id'])

In [None]:
# Fixed-width text table grouped by experiment: one row per (experiment, dataset),
# one column group per hyperparameter variant, one cell per K.
heavy_rule = '=' * 113
print("RESULTS BY EXPERIMENT AND DATASET")
print(heavy_rule)
print(" "*14, "{:33}{:33}{:33}".format("Hyperparam set A", "Hyperparam set B", "Hyperparam set C"))
print("Exp Dataset   ","{:10}{:10}{:10}   ".format('K=2', 'K=4', 'K=8') * 3)
print(heavy_rule)

for exp_id in exp_ids:
    id_printed = False
    for dataset in datasets:
        by_variant = results_dict.get((exp_id, dataset))
        if by_variant is None:
            continue
        # Show the experiment id only on the first row of its group.
        row = "{:4}{:8} | ".format('' if id_printed else exp_id, dataset)
        id_printed = True
        for variant in ('a', 'b', 'c'):
            scores = by_variant.get(variant)
            if scores is None:
                # Variant was not run: blank out its three K cells.
                row += " " * 31 + "| "
                continue
            for shots in ('2', '4', '8'):
                avg, std = scores.get(shots, (-1, -1))
                row += "{:1.2f}±{:1.2f} ".format(avg, std) if avg > 0 else "    -     "
            row += " | "
        print(row)
    print('-' * 113)

In [None]:
# Fixed-width text table grouped by dataset: one row per (dataset, experiment),
# one column group per hyperparameter variant, one cell per K.
heavy_rule = '=' * 113
print("RESULTS BY DATASET AND EXPERIMENT")
print(heavy_rule)
print(" "*14, "{:33}{:33}{:33}".format("Hyperparam set A", "Hyperparam set B", "Hyperparam set C"))
print("Dataset  Exp  ","{:10}{:10}{:10}   ".format('K=2', 'K=4', 'K=8') * 3)
print(heavy_rule)

for dataset in datasets:
    id_printed = False
    for exp_id in exp_ids:
        by_variant = results_dict.get((exp_id, dataset))
        if by_variant is None:
            continue
        # Show the dataset name only on the first row of its group.
        row = "{:9}{:3} | ".format('' if id_printed else dataset, exp_id)
        id_printed = True
        for variant in ('a', 'b', 'c'):
            scores = by_variant.get(variant)
            if scores is None:
                # Variant was not run: blank out its three K cells.
                row += " " * 31 + "| "
                continue
            for shots in ('2', '4', '8'):
                avg, std = scores.get(shots, (-1, -1))
                row += "{:1.2f}±{:1.2f} ".format(avg, std) if avg > 0 else "    -     "
            row += " | "
        print(row)
    print('-' * 113)

In [None]:
# One figure per meta-test dataset: three side-by-side panels (hyperparameter
# variants a/b/c), each plotting mean accuracy against K for every experiment.
marker=['o', 's', '*', '+', 'D', 'v', '<', '^', '>', '.', 'p', 'P']
# Cycle through the marker list so that having more experiments than markers
# cannot cause a KeyError when a marker is looked up.
marker_map = {exp_id: marker[idx % len(marker)] for idx, exp_id in enumerate(exp_ids)}

for dataset in datasets:
    exp_results = results.loc[results['metatest'] == dataset]
    fig, ax = plt.subplots(1, 3, sharex=True, sharey=True, figsize=(12, 4))
    fig.suptitle('K-shot performance on task: ' + dataset, fontsize=14, fontweight='bold')

    for i, v in enumerate(['a', 'b', 'c']):
        for exp_id in exp_ids:
            task_df = exp_results.loc[(exp_results['exp_variant'] == v) & (exp_results['exp_id'] == exp_id)]
            # K values come from JSON object keys and are therefore strings,
            # so the x axis is categorical.
            ax[i].plot(
                task_df['K'].values,
                task_df['avg'].values,
                marker=marker_map[exp_id],
                markeredgecolor='k',
                markersize=8,
                label='exp ' + exp_id,
            )
        # Panel decorations do not depend on the experiment; set them once.
        ax[i].set_title(v.upper(), y=1.0, pad=-14)
        ax[i].set_xlabel('K')

    # Axes are shared, so setting ticks/limits on the current (last) panel
    # applies to all three.
    plt.xticks(['2', '4', '8'])
    ax[0].set_ylabel('Mean accuracy')
    plt.ylim(0.30, 0.90)
    # Single legend to the right of the last panel, using the 'exp <id>' labels.
    ax[2].legend(loc='center left', bbox_to_anchor=(1, 0.5))
    plt.savefig('results/' + dataset + '.png', facecolor='white')
    plt.show()
    # Release the figure so looping over many datasets does not accumulate them.
    plt.close(fig)
    
