In [None]:
%load_ext autoreload
%reload_ext autoreload
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
sns.set()
sns.set_style("darkgrid")
from utils import *
import scipy.misc
import glob
import operator
from utils.metric_tracking import MetricTracker
import os.path


## Read and collect results from a target log directory

Scan the target directory for sub-folders (each one is an experiment). For each experiment, load a metric tracker for every split whose log file exists: training, validation and testing.

In [None]:
target_directory = 'log'
# Data splits whose metric logs are looked for in each experiment's summary_logs folder.
metric_splits = ('training', 'validation', 'testing')


def experiment_name_from_folder(folder):
    """Return the experiment name (last path component) of a folder path from glob.

    glob('<dir>/*/') yields paths with a trailing separator, and on Windows that
    separator is a backslash. normpath strips the trailing separator so basename
    returns the folder's own name on any platform.
    """
    return os.path.basename(os.path.normpath(folder))


# The trailing '' in the pattern makes glob match directories only. Sorting gives a
# stable experiment order (and hence stable plot colours); raw glob order is
# filesystem-dependent.
folders = sorted(glob.glob(os.path.join(target_directory, '*', '')))
collected_results = {}
for folder in folders:
    experiment = experiment_name_from_folder(folder)
    # Every experiment folder gets an entry, even if none of its split logs exist.
    collected_results[experiment] = {}
    for split in metric_splits:
        metrics_path = os.path.join(target_directory, experiment, 'summary_logs',
                                    'metrics_{}.pt'.format(split))
        if os.path.isfile(metrics_path):
            collected_results[experiment][split] = MetricTracker(
                load=True, path=metrics_path, tracker_name=split)


## Neaten things up

Experiment names often carry more information than plotting needs, which makes legends untidy. The cells below choose:
- which experiments to keep;
- how to name them in the plots;
- which metrics and data splits to plot;
- how the plots are styled.

In [None]:
experiments_to_keep = []
must_have_keywords = []

# Selection precedence: an explicit list of experiment names wins over keyword
# matching; with neither configured, every collected experiment is kept.
if experiments_to_keep:
    filtered_results = {
        name: collected_results[name]
        for name in experiments_to_keep
        if name in collected_results
    }
elif must_have_keywords:
    filtered_results = {
        name: trackers
        for name, trackers in collected_results.items()
        if all(keyword in name for keyword in must_have_keywords)
    }
else:
    filtered_results = collected_results

# A filter that matched nothing (while there was something to match) falls back to
# keeping everything, so the plotting cells still have data to work with.
if collected_results and not filtered_results:
    print('No keys from filter found, keeping all')
    filtered_results = collected_results

num_experiments = len(filtered_results)
print('{} experiments filtered in. Keeping:'.format(num_experiments))
for kept_name in filtered_results:
    print('\t{}'.format(kept_name))


In [None]:
# Figure label -> (key in each tracker's collect_per_epoch() output,
#                  y-axis limits [low, high], or None for matplotlib's automatic limits).
metrics_to_plot = {
    'Loss (Cross Entropy)':('cross_entropy_mean', [0, 1]),
    'Accuracy':('accuracy_mean', None),
}

# Every split of one experiment shares that experiment's colour; the splits are told
# apart by line style and transparency.
linestyles = {
    'training':'-',
    'validation':'--',
    'testing':':',
}

alphas = {
    'training':1.0,
    'validation':0.7,
    'testing':0.5,
}

splits_to_plot = ['training', 'validation', 'testing']

# Optional mapping from raw experiment (folder) name to a tidier legend name.
plot_names = {}

colour_type = 'muted'


def choose_colours(colour_type, n_colours):
    """Return a palette of ``n_colours`` colours for the named palette family.

    Recognised names (case-insensitive): 'hls', 'bright', 'deep', 'muted',
    'pastel', 'cubehelix', 'blues', 'reds'. Any other name uses a dark HLS palette.

    seaborn cycles a qualitative palette ('bright', 'deep', 'muted', 'pastel') when
    more colours are requested than it defines, which would give two experiments
    the same colour. In that case this falls back to the HLS palette, which can
    generate any number of distinct hues.
    """
    seaborn_palette_names = {
        'bright': 'bright',
        'deep': 'deep',
        'muted': 'muted',
        'pastel': 'pastel',
        'cubehelix': 'cubehelix',
        'blues': 'Blues',
        'reds': 'Reds',
    }
    qualitative = {'bright', 'deep', 'muted', 'pastel'}
    key = colour_type.lower()
    # Called without n_colors, color_palette returns the palette's full defined size.
    if key in qualitative and n_colours > len(sns.color_palette(key)):
        print("'{}' palette has too few colours for {} experiments; using 'hls' instead"
              .format(colour_type, n_colours))
        key = 'hls'
    if key in seaborn_palette_names:
        return sns.color_palette(seaborn_palette_names[key], n_colours)
    return sns.hls_palette(n_colours, l=0.4, s=1)


colours = choose_colours(colour_type, num_experiments)
print('Colours chosen:')
# palplot has nothing to draw for an empty palette (no experiments found).
if len(colours):
    sns.palplot(colours)


linewidth = 2.5
columns = 1
rows = int(np.ceil(len(metrics_to_plot)/columns))
save_format = 'png'
print('Going to plot {} rows of {} column(s) and save as a {} in: {}/results.{}'.format(rows, columns, save_format, target_directory, save_format))

## Do the actual plotting here

In [None]:
# Collect each experiment/split's per-epoch results once; they are reused by every
# metric subplot below rather than being recomputed for each metric.
# NOTE(review): assumes collect_per_epoch() has no side effects — confirm in MetricTracker.
per_epoch_by_experiment = {
    experiment_key: {
        which_split: trackers[which_split].collect_per_epoch()
        for which_split in splits_to_plot
        if which_split in trackers
    }
    for experiment_key, trackers in filtered_results.items()
}

figure = plt.figure(figsize=(14, 7 * rows))
for axi, (metric_name, (metric_identifier, y_limits)) in enumerate(metrics_to_plot.items()):
    ax = figure.add_subplot(rows, columns, axi + 1)
    # expi indexes `colours`, so an experiment keeps the same colour on every subplot.
    for expi, (experiment_key, split_results) in enumerate(per_epoch_by_experiment.items()):
        display_name = plot_names.get(experiment_key, experiment_key)
        for which_split, per_epoch_results in split_results.items():
            if metric_identifier not in per_epoch_results:
                continue
            values = per_epoch_results[metric_identifier]
            # An empty series has no final value to report in the legend label.
            if len(values) == 0:
                continue
            ax.plot(
                per_epoch_results['epochs'],
                values,
                color=colours[expi],
                # The legend shows the metric's value at the last recorded epoch.
                label='({}) {} ({:0.4f})'.format(which_split, display_name, values[-1]),
                alpha=alphas[which_split],
                linewidth=linewidth,
                linestyle=linestyles[which_split],
            )
    if y_limits is not None:
        ax.set_ylim(*y_limits)
    # Only add a legend when something was plotted; otherwise matplotlib warns.
    if ax.get_legend_handles_labels()[0]:
        ax.legend(frameon=False, ncol=1)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(metric_name)

figure.tight_layout()
figure.savefig('{}/results.{}'.format(target_directory, save_format))
