In [None]:
import os
from glob import glob
import pickle
import numpy as np
import pandas as pd
import re

from IPython.display import display

In [None]:
class Aggregator:
    """Collect pickled run flags and final logged metrics of every run below ``<path>/<task>``.

    Every directory under ``os.path.join(path, task)`` that contains a file matching
    ``flagname`` is treated as one run: its pickled flags are loaded and the metrics of
    the last ``'TEST: '`` line of the sibling ``logname`` file are parsed.

    Args:
        path: root log directory.
        task: sub-directory of ``path`` to search (may itself contain separators,
            e.g. ``'MNIST/rl'``).
        logname: file name of the text log stored next to each flag file.
        flagname: file name (glob pattern) of the pickled flag dictionary.
        verbose: if True, print the discovered flag files.
    """

    def __init__(self, path, task, logname="log.log", flagname='log_flags.pkl', verbose=False):
        # Directory names may contain glob metacharacters such as '[' -- escape them
        # so only ``flagname`` is interpreted as a pattern.
        from glob import escape as glob_escape

        self.path = path
        self.task = task

        # Look for the flag file in every directory of the task tree. Sorted so the
        # run order does not depend on the file system's listing order.
        self.flagfiles = sorted(
            match
            for dirpath, _subdirs, _files in os.walk(os.path.join(path, task))
            for match in glob(os.path.join(glob_escape(dirpath), flagname))
        )
        if verbose:
            print(self.flagfiles)
        # The log lives in the same directory as its flag file. Joining on the directory
        # (instead of str.replace) cannot corrupt parent directories whose names happen
        # to contain ``flagname``.
        self.logfiles = [os.path.join(os.path.dirname(f), logname) for f in self.flagfiles]

        # Raw flag name -> short column label used in the overview tables.
        self.rename = {'binarize_MNIST': 'binarized',
                       'normalise_fb': 'Nfb',
                       'num_glimpses': 'glimpses',
                       'num_classes_kn': 'KK',
                       'num_uk_test': 'UU',
                       'num_uk_test_used': 'UU used',
                       'num_uk_train': 'KU',
                       'scale_sizes': 'scales',
                       'size_z': 'z size',
                       'uk_cycling': 'cycl',
                       'z_B_center': 'z center (B)',
                       'z_dist': 'z dist'
                      }

    def _rename(self, columns):
        """Map raw column names to their short labels; unknown names pass through."""
        return [self.rename.get(c, c) for c in columns]

    @staticmethod
    def _to_number(value):
        """Return ``value`` as a float if it parses as one, otherwise unchanged.

        Numeric metrics must be floats, otherwise sorting compares them as strings
        (``'9.5' > '10.5'``).
        """
        try:
            return float(value)
        except (TypeError, ValueError):
            return value

    @staticmethod
    def _format_scales(scales):
        """Summarise a sequence of scale sizes as ``'<count>x<first>'`` (``[8, 16] -> '2x8'``).

        Values that cannot be summarised (e.g. NaN for runs without the flag, or an
        empty sequence) are returned unchanged.
        """
        try:
            return '{}x{}'.format(len(scales), scales[0])
        except (TypeError, IndexError, KeyError):
            return scales

    def _parse_results(self, file, keyword='TEST: '):
        """Parse the metrics of the last newline-terminated line containing ``keyword``.

        The text after ``keyword`` is expected to be ``name: value`` pairs separated by
        whitespace. Any ``'prefix/'`` is stripped from the names (``'test/acc'`` -> ``'acc'``).
        Values are converted to float where possible.

        Returns:
            dict mapping metric name to value. Empty if the file does not exist or
            contains no matching line.
        """
        results = {}
        if not os.path.isfile(file):
            return results
        with open(file, 'r') as f:
            text = f.read()
        # Requiring the trailing newline skips a final line that may still be half-written.
        matches = re.findall('(?<={}).*(?=\n)'.format(re.escape(keyword)), text)
        if matches:
            tokens = matches[-1].replace(':', '').split()
            # Pair up (name, value); a dangling odd token is ignored.
            for name, value in zip(tokens[0::2], tokens[1::2]):
                results[name.split('/')[-1]] = self._to_number(value)
        return results

    def get_overview(self, param_cols=None, metrics=None, groupby='glimpses', sortby=None,
                     incl_last_valid=False, ascending=False):
        """Build a table of runs (rows) x (``groupby`` value, metric) (columns).

        Args:
            param_cols: (renamed) flag columns used as the row index, or None.
            metrics: metric columns to keep, or None to keep every column.
                The list passed in is not modified.
            groupby: column whose values form the top level of the column MultiIndex.
            sortby: column key passed to ``DataFrame.sort_values``, e.g. ``(7, 'f1')``.
            incl_last_valid: also add the metrics of the last ``'VALID: '`` line,
                prefixed with ``'val_'``.
            ascending: sort direction for ``sortby`` (descending by default).

        Returns:
            The pivoted DataFrame, or None if no run has parsed results.
            As a side effect, ``self.available_columns`` is set to the sorted
            column names before pivoting.
        """
        rs = []
        for log, flag in zip(self.logfiles, self.flagfiles):
            # NOTE(review): pickle.load can execute arbitrary code -- only use on trusted logs.
            with open(flag, 'rb') as f:
                params = pickle.load(f)

            results = self._parse_results(log)
            if incl_last_valid:
                results_valid = self._parse_results(log, keyword='VALID: ')
                results.update({'val_' + k: v for k, v in results_valid.items()})

            if results:
                results.update(params)
                # Run name = directory holding the log (portable, unlike splitting on '\\').
                results['exp_name'] = os.path.basename(os.path.dirname(log))
                rs.append(results)

        if not rs:
            return None

        df = pd.DataFrame(rs)
        if {'pre_train_policy', 'pre_train_epochs'}.issubset(df.columns):
            df['pre-train'] = df['pre_train_policy'] + df['pre_train_epochs'].astype(str)
        df.columns = self._rename(df.columns)
        if 'scales' in df.columns:
            df['scales'] = df['scales'].apply(self._format_scales)
        self.available_columns = sorted(df.columns)

        if param_cols is not None:
            df = df.set_index(param_cols)

        if metrics is not None:
            # Build a new list so the caller's list is never extended in place.
            metrics = list(metrics)
            if incl_last_valid:
                metrics = metrics + ['val_' + m for m in metrics]
            df = df[metrics + [groupby]]

        # Columns become (groupby value, metric), grouped by the groupby value.
        df = df.pivot(columns=groupby).swaplevel(axis=1).sort_index(axis=1, level=0, sort_remaining=False)

        if sortby is not None:
            df = df.sort_values(sortby, ascending=ascending)

        return df

In [None]:
# Root directory of all experiment logs; each task below is a sub-directory of it.
PATH = 'logs'
print(os.listdir(PATH))

# MNIST

In [None]:
# Flag columns identifying a run (row index) and metrics shown per glimpse count.
params = ['planner', 'scales', 'pre-train',
          'z size', 'z dist', 'z center (B)', 'Nfb']
metrics = ['acc', 'f1', 'loss', 'T']

mnist = Aggregator(PATH, 'MNIST/rl', verbose=False)
# Rows ordered by descending F1 at 7 glimpses.
df = mnist.get_overview(param_cols=params, metrics=metrics, sortby=(7, 'f1'))
df

# MNIST_UK

In [None]:
# Same model flags as MNIST plus the known/unknown class split settings.
params = ['planner', 'scales', 'pre-train',
          'z size', 'z dist', 'z center (B)', 'Nfb',
          'KK', 'KU', 'UU', 'UU used', 'cycl']
metrics = ['f1', 'acc', 'acc_kn', 'acc_uk',
           'loss', 'T', 'pct_noDecision']

mnist_uk = Aggregator(PATH, 'MNIST_UK', verbose=False)
# Rows ordered by descending F1 at 7 glimpses.
df = mnist_uk.get_overview(param_cols=params, metrics=metrics, sortby=(7, 'f1'))
df

# MNIST_OMNI_notMNIST

In [None]:
# Model flags plus the unknown-class settings of the combined dataset task.
params = ['planner', 'scales', 'pre-train',
          'z size', 'z dist', 'z center (B)', 'Nfb',
          'uk_pct', 'KK', 'KU', 'UU', 'binarized']
metrics = ['f1', 'acc', 'acc_kn', 'acc_uk']

mnist_omni_notmnist = Aggregator(PATH, 'MNIST_OMNI_notMNIST/rl')
# incl_last_valid also shows the last validation metrics as 'val_*' columns.
df = mnist_omni_notmnist.get_overview(param_cols=params, metrics=metrics,
                                      sortby=(7, 'f1'), incl_last_valid=True)
df