# Sumarize Runs

This is really just used for aggregating the results of different runs (stored in `./summaries` dir).
At this point you should have done the runs you want (hint hint: maybe that means running `./run.sh` for you).

In [68]:
import json
from pathlib import Path
from pprint import pprint

import numpy as np
from tabulate import tabulate

In [69]:
# Lets see what folders exist
dumpdir = Path('summaries')
datasets = set(str(foldername.name.split('_')[0]) for foldername in dumpdir.glob('*'))
pprint(datasets)

{'bbc-news', 'sst2', 'SentEval-CR', 'imdb', '.nomedia', 'enron'}


In [70]:
# We now have to select all instances (per case) across all variations
# So we have one list where all the case-n runs are

In [104]:
def pull_case_runs(casenum):
    runs = {}
    for filename in dumpdir.glob(f'*/case_{casenum}.json'):
        with filename.open("r") as f:
            runs.setdefault('_'.join(str(filename.parent.name).split('_')[:-1]),[]).append(json.load(f))
    return runs
    
runs_0 = pull_case_runs(0)
runs_1 = pull_case_runs(1)
runs_2 = pull_case_runs(2)
runs_3 = pull_case_runs(3)
runs_4 = pull_case_runs(4)

In [108]:
def filter_runs(runs, num_sents:int):
    filtered = {}
    for datasetname, dataset_runs in runs.items():
        for run in dataset_runs:
            if not "config" in run or run['config']['num_sents'] == num_sents:
                filtered.setdefault(datasetname, []).append(run)
                
    assert max(len(v) for v in filtered.values()) == 1
    return {k:v[0] for k, v in filtered.items()}

def summarize(runs):
    # assume data is filtered
    # output format = {dataset: {metricname: (mean, std)}}
    output = {}
    for datasetname, datasetrun in runs.items():
        output[datasetname] = {}
        for metricname, metricvalues in datasetrun.items():
            if metricname == 'config':
                continue
            mean, std = np.mean(metricvalues), np.std(metricvalues)
            output[datasetname][metricname] = (mean, std)
            
    return output

def gettable(runs, rownames = None):
    # give summarized filtered stuff
    # get all datasets
    
    if not rownames:
        rownames = list(range(len(runs)))
        
    datasets = set()
    for datasetruns in runs:
        print(datasetruns.keys())
        datasets = datasets.union(set(datasetruns.keys()))
    datasets = list(datasets)
    print(datasets)
    
    # Each dataset is a column (header is dataset)
    header = datasets[:]
    
    # Each cell in the colum should be 0.34 +- 0.11 or whatever
    # Each row is a element of the runs list
    table = []
    for rowname, run in zip(rownames, runs):
        row = [rowname]
        for dataset in datasets:
            
            try:
                data = run[dataset]['accuracy']
                row.append(f"{data[0]:.3f} ± {data[1]:.3f}")
            except KeyError:
                row.append('N/A')
        table.append(row)
                
    return table, header
            


In [110]:
s0 = summarize(filter_runs(runs_0, 100))
s1 = summarize(filter_runs(runs_1, 100))
s2 = summarize(filter_runs(runs_2, 100))
s3 = summarize(filter_runs(runs_3, 10))
s4 = summarize(filter_runs(runs_4, 100))

table, header = gettable((s0, s2, s1, s3, s4),
                         rownames=['SetFit FT', 'No Contrastive SetFit FT', 'Regular FT', 'LLM Prompting','Constrastive AL' ])

dict_keys(['imdb', 'SentEval-CR'])
dict_keys(['imdb', 'SentEval-CR'])
dict_keys(['imdb', 'SentEval-CR'])
dict_keys(['enron_spam', 'sst2', 'imdb', 'SentEval-CR', 'bbc-news'])
dict_keys(['imdb', 'bbc-news', 'sst2'])
['bbc-news', 'sst2', 'SentEval-CR', 'imdb', 'enron_spam']


In [111]:
print(tabulate(table, header, tablefmt='github'))

|                          | bbc-news      | sst2          | SentEval-CR   | imdb          | enron_spam    |
|--------------------------|---------------|---------------|---------------|---------------|---------------|
| SetFit FT                | N/A           | N/A           | 0.882 ± 0.029 | 0.924 ± 0.026 | N/A           |
| No Contrastive SetFit FT | N/A           | N/A           | 0.886 ± 0.005 | 0.902 ± 0.019 | N/A           |
| Regular FT               | N/A           | N/A           | 0.582 ± 0.054 | 0.836 ± 0.166 | N/A           |
| LLM Prompting            | 0.950 ± 0.000 | 0.930 ± 0.000 | 0.900 ± 0.000 | 0.930 ± 0.000 | 0.820 ± 0.000 |
| Constrastive AL          | 0.974 ± 0.000 | 0.925 ± 0.000 | N/A           | 0.926 ± 0.000 | N/A           |
