In [1]:
from IPython.display import display, HTML

display(HTML(data="""
<style>
    div#notebook-container    { width: 95%; }
    div#menubar-container     { width: 65%; }
    div#maintoolbar-container { width: 99%; }
</style>
"""))

In [3]:
%load_ext autoreload
%reload_ext autoreload
%matplotlib widget
import ipywidgets as widgets
import os
os.chdir('../')
print(os.getcwd())
import numpy as np
import matplotlib.pyplot as plt
from matplotlib.colors import ListedColormap
import seaborn as sns
sns.set()
sns.set_style("darkgrid")
from utils import *
import scipy.misc
import glob
import operator
from utils.metric_tracking import MetricTracker
import os.path



/home/antreas/current_research_forge/pytorch-experiments-template


## Read and collect results from a target log directory

Checking the target directory, find all folders (each of which is an experiment), and create a metric tracking object for each of these, wrt train, valid, and test (if they exist)

In [4]:
from collections import defaultdict

def compute_accuracy(logits, targets):
    acc = (targets == logits.argmax(-1)).float().detach().cpu().numpy()
    return float(np.mean(acc)) * 100

target_directory = 'log/'
collected_results = defaultdict(dict)

metrics_to_track = {"cross_entropy": 
                    lambda x, y: torch.nn.CrossEntropyLoss()(x, y).item(),
                    "accuracy":
                    compute_accuracy}

from collections import defaultdict

for subdir, dir, files in os.walk(target_directory):
    for file in files:
        print(file)
        if file.endswith('.pt'):
            experiment_name = subdir.split('/')[1]
            filepath = os.path.join(subdir, file)
            set_name = file.replace('metrics_', '').replace('.pt', '')
            collected_results[experiment_name][f'{set_name}'] = MetricTracker(load=True, 
                                                           path=filepath,
                                                           tracker_name=set_name,
                                                           metrics_to_track=metrics_to_track)
           

metrics_training.pt
metrics_validation.pt
metrics_testing.pt
metrics.png
metrics.png
epoch_96_model_dev.ckpt
snapshot.tar.gz
latest_dev.ckpt


## Neaten things up

Often experiments will be named with more infomation than strictly necessary for plotting and this can be untidy. Here we make choices re: neatening nameing wrt experiment names, results to be plotted, etc. 

In [5]:
def setup_metrics_sets_and_colours(num_experiments):
    metrics_to_plot = {
        'Loss (Cross Entropy)':('cross_entropy_mean', [0, 3]),
        'Accuracy':('accuracy_mean', None),
    }

    linestyles = {
        'training':'-',
        'validation':'--',
        'testing':':',
    }

    alphas = {
        'training':1.0,
        'validation':0.7,
        'testing':0.5,
    }

    splits_to_plot = ['training', 'validation']

    plot_names = {}

    colour_type = 'muted'
    if colour_type.lower() == 'hls':
        colours = sns.hls_palette(num_experiments, l=0.4, s=1)
    elif colour_type.lower() == 'bright':
        colours = sns.color_palette('bright', num_experiments)
    elif colour_type.lower() == 'deep':
        colours = sns.color_palette('deep', num_experiments)
    elif colour_type.lower() == 'muted':
        colours = sns.color_palette('muted', num_experiments)
    elif colour_type.lower() == 'pastel':
        colours = sns.color_palette('pastel', num_experiments)
    elif colour_type.lower() == 'cubehelix':
        colours = sns.color_palette('cubehelix', num_experiments)
    elif colour_type.lower() == 'blues':
        colours = sns.color_palette('Blues', num_experiments)
    elif colour_type.lower() == 'reds':
        colours = sns.color_palette('Reds', num_experiments)
    else:
        colours = sns.hls_palette(num_experiments, l=0.4, s=1)
#     sns.palplot(colours)


    linewidth = 2.5
    columns = 1
    rows = int(np.ceil(len(metrics_to_plot) / columns))
    save_format = 'png'
    print('Going to plot {} rows of {} column(s)'.format(rows, columns))
    return plot_names, metrics_to_plot, splits_to_plot, linestyles, alphas, rows, colour_type, columns, linewidth, colours, save_format

## Do the actual plotting here

In [14]:
experiments_to_keep = []
include_keywords = ['cifar100']
exclude_keywords = []
filtered_results = dict()

for experiment_name, value in collected_results.items():
    if len(experiments_to_keep) != 0:
        if experiment_name in experiments_to_keep:
            filtered_results[experiment_name] = value
            continue
    
    if all([item in experiment_name for item in include_keywords]) and \
    not any([item in experiment_name for item in exclude_keywords]):
        filtered_results[experiment_name] = value
        
        
    
num_experiments = len(filtered_results)
print('{} experiments filtered in. Keeping:'.format(num_experiments))
plot_names, metrics_to_plot, splits_to_plot, linestyles, alphas, rows, colour_type, columns, linewidth, colours, save_format = setup_metrics_sets_and_colours(num_experiments)

for key, value in filtered_results.items():
    print(key)

figure = plt.figure(figsize=(8, 7 * rows))
for axi, (metric_name, (metric_identifier, y_limits)) in enumerate(metrics_to_plot.items()):
    ax = figure.add_subplot(rows, columns, axi+1)
    for expi, (experiment_name, metric_tracker) in enumerate(filtered_results.items()):
        if experiment_name in plot_names:
            experiment_name = plot_names[experiment_name]
            
        for which_split in splits_to_plot:
            if which_split in metric_tracker:
                per_epoch_results = metric_tracker[which_split].collect_per_epoch()
                epochs = per_epoch_results['epochs']
                if metric_identifier in per_epoch_results:
                    ax.plot(
                        epochs,                         
                        per_epoch_results[metric_identifier],
                        color=colours[expi], 
                        label='({}) {} ({:0.4f})'.format(which_split, experiment_name, per_epoch_results[metric_identifier][-1]),
                        alpha=alphas[which_split], 
                        linewidth=linewidth,
                        linestyle=linestyles[which_split],
                    )
    if y_limits is not None:
        ax.set_ylim(y_limits[0], y_limits[1])
    ax.legend(frameon=False, ncol=1)
    ax.set_xlabel('Epochs')
    ax.set_ylabel(metric_name)

figure.canvas.resizable = True
figure.canvas.capture_scroll = True
figure.canvas.toolbar_visible = True
figure.tight_layout()


1 experiments filtered in. Keeping:
Going to plot 2 rows of 1 column(s)
example_resnet_9_cifar100


FigureCanvasNbAgg()

In [15]:
figure.savefig(f'{target_directory}/results.{save_format}')