In [1]:
import os
import re
import json
import glob
import matplotlib as mpl
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

from IPython.display import display

In [2]:
mpl.rcParams['font.family'] = 'DejaVu Sans'
pd.set_option('display.max_rows', 100)
pd.set_option('display.max_columns', 500)
pd.set_option('precision', 3)

In [3]:
CHECKPOINT_ROOTS = [
    #'../checkpoints/classification_scratch/resnet.18.original',
    '../checkpoints/classification_pirl/resnet.18.original',
]

assert all([os.path.isdir(ckpt_root) for ckpt_root in CHECKPOINT_ROOTS])

In [4]:
def get_configurations(configs, history):
    
    def get_pretext(configs):
        pretext = configs.get('pretext')
        return pretext if pretext is not None else 'scratch'
    
    d = {
        'data_index': configs.get('data_index'),
        'input_size': configs.get('input_size'),
        'backbone_type': configs.get('backbone_type'),
        'backbone_config': configs.get('backbone_config'),
        'in_channels': configs.get('in_channels'),
            
        'labeled': configs.get('labeled'),
        'smoothing': configs.get('smoothing'),
        'dropout': configs.get('dropout'),
            
        'learning_rate': configs.get('learning_rate'),
        
        'pretext': get_pretext(configs),
        'freeze': len(configs.get('freeze')),
        'blockwise_learning_rates': configs.get('blockwise_learning_rates'),
        'best_epoch': history['epoch'],
            
        'train_loss': history['loss']['train'],
        'valid_loss': history['loss']['valid'],
        'test_loss': history['loss']['test'],
            
        'train_accuracy': history['accuracy']['train'],
        'valid_accuracy': history['accuracy']['valid'],
        'test_accuracy': history['accuracy']['test'],
            
        'train_f1': history['f1']['train'],
        'valid_f1': history['f1']['valid'],
        'test_f1': history['f1']['test'],
    }
    
    d.update({'noise': configs.get('noise', '-')})
    d.update({'rotate': configs.get('rotate', '-')})
    
    return d

In [5]:
data = {}

for ckpt_root in CHECKPOINT_ROOTS:
    
    # Find configuration files recursively
    config_files = glob.glob(os.path.join(ckpt_root, '**/*/configs.json'), recursive=True)
    config_files = [os.path.normpath(p) for p in config_files]
    
    for config_file in config_files:
        
        try:
            # Open configuration file
            with open(config_file, 'r') as fp:
                configs = json.load(fp)
            # Open history file
            ckpt_dir = os.path.dirname(config_file)
            history_file = os.path.join(ckpt_dir, 'best_history.json')
            with open(history_file, 'r') as fp:
                history = json.load(fp)
                
        except FileNotFoundError:
            continue
        
        data[ckpt_dir] = get_configurations(configs, history)

In [6]:
print(f"Total number of experiments: {len(data):,}")

Total number of experiments: 80


In [7]:
df = pd.DataFrame.from_dict(data, orient='index')
df = df.reset_index(drop=True, inplace=False)
df = df.sort_values(by=['labeled', 'data_index'])
display(df)

Unnamed: 0,data_index,input_size,backbone_type,backbone_config,in_channels,labeled,smoothing,dropout,learning_rate,pretext,freeze,blockwise_learning_rates,best_epoch,train_loss,valid_loss,test_loss,train_accuracy,valid_accuracy,test_accuracy,train_f1,valid_f1,test_f1,noise,rotate
9,0,112,resnet,18.original,2,0.001,0.1,0.5,0.001,pirl,0,{},44,0.824,1.038,1.031,0.993,0.915,0.916,0.761,0.36,0.359,0.0,True
65,1,112,resnet,18.original,2,0.001,0.1,0.5,0.001,pirl,0,{},70,0.685,0.877,0.869,1.0,0.896,0.898,0.778,0.373,0.382,0.0,True
1,2,112,resnet,18.original,2,0.001,0.1,0.5,0.001,pirl,0,{},97,0.647,0.822,0.82,0.993,0.909,0.911,0.74,0.38,0.39,0.0,True
17,3,112,resnet,18.original,2,0.001,0.1,0.5,0.001,pirl,0,{},75,0.682,0.81,0.808,0.986,0.912,0.913,0.749,0.393,0.396,0.0,True
33,4,112,resnet,18.original,2,0.001,0.1,0.5,0.001,pirl,0,{},66,0.667,0.808,0.802,0.993,0.913,0.915,0.761,0.382,0.391,0.0,True
49,5,112,resnet,18.original,2,0.001,0.1,0.5,0.001,pirl,0,{},85,0.679,0.863,0.86,0.978,0.896,0.896,0.717,0.383,0.388,0.0,True
57,6,112,resnet,18.original,2,0.001,0.1,0.5,0.001,pirl,0,{},69,0.716,0.876,0.871,0.957,0.895,0.897,0.592,0.38,0.383,0.0,True
41,7,112,resnet,18.original,2,0.001,0.1,0.5,0.001,pirl,0,{},55,0.705,0.885,0.881,1.0,0.885,0.888,0.778,0.371,0.379,0.0,True
73,8,112,resnet,18.original,2,0.001,0.1,0.5,0.001,pirl,0,{},48,0.725,0.9,0.897,0.978,0.887,0.889,0.714,0.383,0.389,0.0,True
25,9,112,resnet,18.original,2,0.001,0.1,0.5,0.001,pirl,0,{},62,0.686,0.96,0.953,0.986,0.875,0.879,0.727,0.379,0.39,0.0,True


# Create Table

In [8]:
pivot_configs = {
    'values': ['test_f1', 'test_accuracy'],
    'index': ['backbone_type', 'backbone_config', 'pretext', 'noise', 'rotate', 'labeled'],
    'aggfunc': ['mean', 'std'],
}

df_pirl = df.loc[df['pretext'] == 'pirl'].copy()
pirl_table = df_pirl.pivot_table(**pivot_configs)
pirl_table.columns.names = ('statistic', 'metric')
pirl_table = pirl_table.unstack(level=['pretext', 'noise', 'rotate'])
display(pirl_table.style.background_gradient(cmap=plt.cm.viridis, axis=0))

df_scratch = df.loc[df['pretext'] == 'scratch'].copy()
scratch_table = df_scratch.pivot_table(**pivot_configs)
scratch_table.columns.names = ('statistic', 'metric')
scratch_table = scratch_table.unstack(level=['pretext', 'noise', 'rotate'])
display(scratch_table.style.background_gradient(cmap=plt.cm.plasma_r, axis=0))

table = pd.concat([pirl_table, scratch_table], axis=1)
table = pd.concat([pirl_table, scratch_table], axis=1)
table.sort_index(axis=1, level=['statistic', 'metric', 'pretext'], inplace=True)
display(table.style.background_gradient(cmap=plt.cm.coolwarm, axis=0))

Unnamed: 0_level_0,Unnamed: 1_level_0,statistic,mean,mean,std,std
Unnamed: 0_level_1,Unnamed: 1_level_1,metric,test_accuracy,test_f1,test_accuracy,test_f1
Unnamed: 0_level_2,Unnamed: 1_level_2,pretext,pirl,pirl,pirl,pirl
Unnamed: 0_level_3,Unnamed: 1_level_3,noise,0.0,0.0,0.0,0.0
Unnamed: 0_level_4,Unnamed: 1_level_4,rotate,True,True,True,True
backbone_type,backbone_config,labeled,Unnamed: 3_level_5,Unnamed: 4_level_5,Unnamed: 5_level_5,Unnamed: 6_level_5
resnet,18.original,0.001,0.9,0.385,0.013,0.01
resnet,18.original,0.005,0.932,0.547,0.004,0.02
resnet,18.original,0.01,0.947,0.632,0.002,0.015
resnet,18.original,0.05,0.959,0.727,0.003,0.013
resnet,18.original,0.1,0.964,0.766,0.001,0.014
resnet,18.original,0.25,0.971,0.817,0.001,0.005
resnet,18.original,0.5,0.974,0.857,0.001,0.008
resnet,18.original,1.0,0.976,0.882,0.001,0.009


IndexError: list index out of range

<pandas.io.formats.style.Styler at 0x7fe2e441d320>

Unnamed: 0_level_0,Unnamed: 1_level_0,statistic,mean,mean,std,std
Unnamed: 0_level_1,Unnamed: 1_level_1,metric,test_accuracy,test_f1,test_accuracy,test_f1
Unnamed: 0_level_2,Unnamed: 1_level_2,pretext,pirl,pirl,pirl,pirl
Unnamed: 0_level_3,Unnamed: 1_level_3,noise,0.0,0.0,0.0,0.0
Unnamed: 0_level_4,Unnamed: 1_level_4,rotate,True,True,True,True
backbone_type,backbone_config,labeled,Unnamed: 3_level_5,Unnamed: 4_level_5,Unnamed: 5_level_5,Unnamed: 6_level_5
resnet,18.original,0.001,0.9,0.385,0.013,0.01
resnet,18.original,0.005,0.932,0.547,0.004,0.02
resnet,18.original,0.01,0.947,0.632,0.002,0.015
resnet,18.original,0.05,0.959,0.727,0.003,0.013
resnet,18.original,0.1,0.964,0.766,0.001,0.014
resnet,18.original,0.25,0.971,0.817,0.001,0.005
resnet,18.original,0.5,0.974,0.857,0.001,0.008
resnet,18.original,1.0,0.976,0.882,0.001,0.009


# Line Plots

In [None]:
def plot(table: pd.DataFrame,
         metric: str, pretexts: list, noise: float or str, 
         model_name: tuple = ('vgg', '3a'),
         fig: mpl.figure.Figure = None, **kwargs):
    
    if fig is not None:
        ax = fig.axes[0]
    else:
        fig, ax = plt.subplots(1, 1, figsize=kwargs.get('figsize', (20, 10)))
    
    y_label = ' '.join([c.capitalize() for c in metric.split('_')])
    
    colors = dict(
        scratch='grey',
        denoising='forestgreen',
        rotation='orangered',
        jigsaw='darkkhaki',
        bigan='skyblue',
        pirl='rebeccapurple'
    )
    
    for pretext in pretexts:
        if pretext == 'scratch':
            label = 'no pretraining'
        else:
            label = pretext
            
        label = label + f" ({'.'.join(model_name)})"
        if isinstance(noise, float):
            label = label.rstrip(')') + f', p={noise:.2f})'
        
        if 'color' in kwargs.keys():
            color = kwargs.get('color')
        else:
            color = colors.get(pretext, 'black')
        
        if pretext == 'scratch':
            s = 250
            marker = 'x'
        else:
            s = 100
            marker = '^'
            
        # Multiindex column must be indexed in the following order:
        # (statistic, metric, pretext, noise)
        row_idx = model_name
        col_idx = (metric, pretext, noise)
        idx = table.loc[model_name, ('mean', ) + col_idx].index
        val = table.loc[model_name, ('mean', ) + col_idx].values
        std = table.loc[model_name, ('std' , ) + col_idx].values
        table.loc[model_name, ('mean', ) + col_idx].plot.line(ax=ax, label=label, color=color)
        ax.scatter(idx, val, marker=marker, s=s, color=color)
        ax.fill_between(idx, val-std, val+std, alpha=0.05, color=color)
        
    ax.grid(True)
    ax.legend(loc='lower right', fontsize=20)
    ax.set_xlabel('Labeled Data Proportion', fontsize=25)
    ax.set_ylabel(y_label, fontsize=25)
    ax.tick_params(axis='both', which='both', labelsize=25)
    ax.set_xticks([0.01, 0.05, 0.10 ,0.25, 0.50, 0.75, 1.00])
    for tick in ax.get_xticklabels():
        tick.set_rotation(45)

    x_min, x_max = kwargs.get('x_min', 0.0), kwargs.get('x_max', 1.1)
    y_min, y_max = kwargs.get('y_min'), kwargs.get('y_max')
    ax.set_xlim(x_min, x_max)
    if all([y is not None for y in (y_min, y_max)]):
        ax.set_ylim(y_min, y_max)
    
    # Set title
    title = f"{y_label}"
    ax.set_title(title, fontsize=30)
    
    return fig

In [None]:
for METRIC in ['test_accuracy', 'test_f1']:
    for MODEL_NAME in [('vgg', '3a'), ('vgg', '6a')]:
        fig = plot(table, METRIC, ['pirl'], noise=0.00, model_name=MODEL_NAME, color='royalblue')
        fig = plot(table, METRIC, ['pirl'], noise=0.10, model_name=MODEL_NAME, fig=fig, color='slateblue')
        fig = plot(table, METRIC, ['pirl'], noise=0.25, model_name=MODEL_NAME, fig=fig, color='indianred')
        #fig = plot(table, METRIC, ['denoising'], noise=0.10, model_name=MODEL_NAME, fig=fig, color='palegreen')
        #fig = plot(table, METRIC, ['denoising'], noise=0.25, model_name=MODEL_NAME, fig=fig, color='teal')
        fig = plot(table, METRIC, ['scratch'], noise='-', model_name=MODEL_NAME, y_min=0.3, y_max=1.0, fig=fig)
        plt.show(fig)
        del fig

# Save table

In [None]:
#save_dir = '../tables'
#os.makedirs(save_dir, exist_ok=True)
#table.stack(level=['statistic']).to_csv(os.path.join(save_dir, 'table.csv'), index=True)