In [1]:
import os
import re
import json
import glob
import matplotlib as mpl
import matplotlib.pyplot as plt

import numpy as np
import pandas as pd

from IPython.display import display

In [2]:
mpl.rcParams['font.family'] = 'DejaVu Sans'
pd.set_option('display.max_rows', 100)
pd.set_option('display.max_columns', 500)
pd.set_option('precision', 3)

# Load configurations and history

In [3]:
BACKBONES = ['alexnet.batch_norm', 'vggnet.16.batch_norm', 'resnet.18.original', 'resnet.50.original']

In [4]:
CHECKPOINT_ROOTS = []
for backbone in BACKBONES:
    CHECKPOINT_ROOTS += [f'../checkpoints/wm811k/classification_scratch/{backbone}']
    CHECKPOINT_ROOTS += [f'../checkpoints/wm811k/classification_pirl/{backbone}']

assert all ([os.path.isdir(d) for d in CHECKPOINT_ROOTS])

In [5]:
def get_configurations(configs, history):
    
    def get_pretext(configs):
        pretext = configs.get('pretext')
        return pretext if pretext is not None else 'scratch'
    
    d = {
        'data': configs.get('data'),
        'input_size': configs.get('input_size'),
        'backbone_type': configs.get('backbone_type'),
        'backbone_config': configs.get('backbone_config'),
            
        'label_proportion': configs.get('label_proportion'),
        'augmentation': configs.get('augmentation'),
            
        'learning_rate': configs.get('learning_rate'),
        'weight_decay': configs.get('weight_decay'),
        'balance': configs.get('balance'),
        'optimizer': configs.get('optimizer'),
        
        'pretext': get_pretext(configs),
        
        'best_epoch': history['epoch'],    
        'train_loss': history['loss']['train'],
        'valid_loss': history['loss']['valid'],
        'test_loss': history['loss']['test'],
        'rain_accuracy': history['accuracy']['train'],
        'valid_accuracy': history['accuracy']['valid'],
        'test_accuracy': history['accuracy']['test'],
        'train_f1': history['f1']['train'],
        'valid_f1': history['f1']['valid'],
        'test_f1': history['f1']['test'],
        'train_auprc': history['auprc']['train'],
        'valid_auprc': history['auprc']['valid'],
        'test_auprc': history['auprc']['test'],
    }
    
    d.update(
        {
            'projector_type': configs.get('projector_type'),
            'projector_size': configs.get('projector_size'),
        }
    )
    
    return d

In [16]:
data = {}

for ckpt_root in CHECKPOINT_ROOTS:
    
    # Find configuration files recursively
    config_files = glob.glob(os.path.join(ckpt_root, '**/configs.json'), recursive=True)
    config_files = [os.path.normpath(p) for p in config_files]
    
    for config_file in config_files:
        
        try:
            # Open configuration file
            with open(config_file, 'r') as fp:
                configs = json.load(fp)
            # Open history file
            ckpt_dir = os.path.dirname(config_file)
            history_file = os.path.join(ckpt_dir, 'last_history.json')
            with open(history_file, 'r') as fp:
                history = json.load(fp)
                
        except FileNotFoundError:
            continue
        
        data[ckpt_dir] = get_configurations(configs, history)

In [17]:
print(f"Total number of experiments: {len(data):,}")

Total number of experiments: 672


In [18]:
df = pd.DataFrame.from_dict(data, orient='index')
df = df.reset_index(drop=True, inplace=False)
df = df.sort_values(by=['label_proportion', 'pretext'])
display(df)

Unnamed: 0,data,input_size,backbone_type,backbone_config,label_proportion,augmentation,learning_rate,weight_decay,balance,optimizer,pretext,best_epoch,train_loss,valid_loss,test_loss,rain_accuracy,valid_accuracy,test_accuracy,train_f1,valid_f1,test_f1,train_auprc,valid_auprc,test_auprc,projector_type,projector_size
49,wm811k,96,alexnet,batch_norm,0.01,rotate+crop,0.01,0.001,True,sgd,pirl,100,0.710,0.789,0.790,0.933,0.902,0.901,0.936,0.662,0.673,0.251,0.499,0.490,mlp,128.0
55,wm811k,96,alexnet,batch_norm,0.01,rotate+crop,0.01,0.001,True,sgd,pirl,100,0.673,0.759,0.761,0.955,0.919,0.919,0.954,0.657,0.636,0.256,0.522,0.548,mlp,128.0
61,wm811k,96,alexnet,batch_norm,0.01,rotate+crop,0.01,0.001,True,sgd,pirl,100,0.687,0.747,0.740,0.944,0.922,0.927,0.942,0.697,0.701,0.192,0.516,0.558,mlp,128.0
67,wm811k,96,alexnet,batch_norm,0.01,rotate+crop,0.01,0.001,True,sgd,pirl,100,0.656,0.755,0.751,0.967,0.914,0.915,0.968,0.677,0.690,0.209,0.497,0.489,mlp,128.0
73,wm811k,96,alexnet,batch_norm,0.01,crop,0.01,0.001,True,sgd,pirl,100,0.605,0.675,0.680,0.982,0.949,0.947,0.983,0.717,0.707,0.085,0.527,0.496,mlp,128.0
...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...,...
527,wm811k,96,resnet,50.original,1.00,test,0.01,0.001,True,sgd,scratch,100,0.535,0.604,0.602,1.000,0.976,0.977,1.000,0.874,0.885,0.091,0.695,0.654,,
533,wm811k,96,resnet,50.original,1.00,rotate,0.01,0.001,True,sgd,scratch,100,0.592,0.643,0.637,0.973,0.948,0.951,0.973,0.836,0.836,0.545,0.611,0.559,,
539,wm811k,96,resnet,50.original,1.00,rotate,0.01,0.001,True,sgd,scratch,100,0.597,0.649,0.641,0.970,0.946,0.949,0.970,0.822,0.827,0.560,0.618,0.558,,
545,wm811k,96,resnet,50.original,1.00,rotate,0.01,0.001,True,sgd,scratch,100,0.598,0.646,0.640,0.970,0.945,0.949,0.970,0.826,0.830,0.564,0.572,0.580,,


# Pivot tables

In [21]:
pivot_configs = {
    'values': ['test_auprc'],
    'index': ['backbone_type', 'backbone_config', 'pretext', 'label_proportion', 'augmentation'],
    'aggfunc': ['median', 'std'],
}

#df_denoising = df.loc[df['pretext'] == 'denoising'].copy()
#denoising_table = df_denoising.pivot_table(**pivot_configs)
#denoising_table.columns.names = ('statistic', 'metric')
#denoising_table = denoising_table.unstack(level=['label_proportion'])
#display(denoising_table.style.background_gradient(cmap=plt.cm.Blues, axis=1))

df_pirl = df.loc[df['pretext'] == 'pirl'].copy()
pirl_table = df_pirl.pivot_table(**pivot_configs)
pirl_table.columns.names = ('statistic', 'metric')
pirl_table = pirl_table.unstack(level=['label_proportion'])
display(pirl_table.style.background_gradient(cmap=plt.cm.viridis, axis=1))

df_scratch = df.loc[df['pretext'] == 'scratch'].copy()
scratch_table = df_scratch.pivot_table(**pivot_configs)
scratch_table.columns.names = ('statistic', 'metric')
scratch_table = scratch_table.unstack(level=['label_proportion'])
display(scratch_table.style.background_gradient(cmap=plt.cm.plasma, axis=1))

table = pd.concat([pirl_table, scratch_table], axis=0)
table.sort_index(axis=0, level=['backbone_type', 'backbone_config', 'augmentation', 'pretext'])
table.sort_index(axis=1, level=['statistic', 'metric', 'label_proportion'], inplace=True)
display(table.style.background_gradient(cmap=plt.cm.coolwarm, axis=1))

Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,statistic,median,median,median,median,median,median,std,std,std,std,std,std
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,metric,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc
Unnamed: 0_level_2,Unnamed: 1_level_2,Unnamed: 2_level_2,label_proportion,0.01,0.05,0.1,0.25,0.5,1.0,0.01,0.05,0.1,0.25,0.5,1.0
backbone_type,backbone_config,pretext,augmentation,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3
alexnet,batch_norm,pirl,crop,0.517,0.551,0.582,0.609,0.581,0.574,0.032,0.053,0.029,0.012,0.055,0.021
alexnet,batch_norm,pirl,cutout,0.483,0.595,0.611,0.627,0.662,0.649,0.019,0.036,0.033,0.029,0.02,0.024
alexnet,batch_norm,pirl,noise,0.448,0.596,0.564,0.61,0.597,0.611,0.011,0.054,0.042,0.01,0.026,0.015
alexnet,batch_norm,pirl,rotate,0.538,0.573,0.616,0.615,0.582,0.573,0.039,0.063,0.076,0.022,0.014,0.018
alexnet,batch_norm,pirl,rotate+crop,0.519,0.518,0.516,0.544,0.509,0.506,0.037,0.048,0.038,0.061,0.013,0.015
alexnet,batch_norm,pirl,shift,0.539,0.595,0.584,0.663,0.625,0.608,0.034,0.055,0.058,0.062,0.062,0.012
resnet,18.original,pirl,crop,0.548,0.63,0.652,0.677,0.581,0.665,0.02,0.045,0.037,0.031,0.088,0.051
resnet,18.original,pirl,cutout,0.537,0.605,0.616,0.638,0.648,0.617,0.03,0.037,0.043,0.028,0.016,0.03
resnet,18.original,pirl,noise,0.483,0.579,0.614,0.66,0.628,0.633,0.037,0.016,0.017,0.053,0.013,0.042
resnet,18.original,pirl,rotate,0.545,0.572,0.623,0.637,0.611,0.596,0.048,0.052,0.052,0.067,0.022,0.015


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,statistic,median,median,median,median,median,median,std,std,std,std,std,std
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,metric,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc
Unnamed: 0_level_2,Unnamed: 1_level_2,Unnamed: 2_level_2,label_proportion,0.01,0.05,0.1,0.25,0.5,1.0,0.01,0.05,0.1,0.25,0.5,1.0
backbone_type,backbone_config,pretext,augmentation,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3
alexnet,batch_norm,scratch,rotate,0.562,0.541,0.5,0.511,0.512,0.529,0.028,0.045,0.02,0.024,0.006,0.019
alexnet,batch_norm,scratch,test,0.425,0.502,0.522,0.527,0.548,0.617,0.039,0.05,0.035,0.051,0.036,0.033
resnet,18.original,scratch,rotate,0.498,0.532,0.517,0.535,0.586,0.585,0.04,0.028,0.047,0.028,0.017,0.009
resnet,18.original,scratch,test,0.436,0.536,0.563,0.559,0.588,0.659,0.029,0.039,0.048,0.017,0.034,0.052
resnet,50.original,scratch,rotate,0.506,0.492,0.525,0.532,0.597,0.569,0.08,0.055,0.031,0.036,0.017,0.015
resnet,50.original,scratch,test,0.454,0.53,0.518,0.598,0.677,0.65,0.051,0.018,0.061,0.017,0.048,0.051
vggnet,16.batch_norm,scratch,rotate,0.55,0.575,0.551,0.547,0.557,0.588,0.041,0.026,0.041,0.051,0.032,0.004


Unnamed: 0_level_0,Unnamed: 1_level_0,Unnamed: 2_level_0,statistic,median,median,median,median,median,median,std,std,std,std,std,std
Unnamed: 0_level_1,Unnamed: 1_level_1,Unnamed: 2_level_1,metric,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc,test_auprc
Unnamed: 0_level_2,Unnamed: 1_level_2,Unnamed: 2_level_2,label_proportion,0.01,0.05,0.1,0.25,0.5,1.0,0.01,0.05,0.1,0.25,0.5,1.0
backbone_type,backbone_config,pretext,augmentation,Unnamed: 4_level_3,Unnamed: 5_level_3,Unnamed: 6_level_3,Unnamed: 7_level_3,Unnamed: 8_level_3,Unnamed: 9_level_3,Unnamed: 10_level_3,Unnamed: 11_level_3,Unnamed: 12_level_3,Unnamed: 13_level_3,Unnamed: 14_level_3,Unnamed: 15_level_3
alexnet,batch_norm,pirl,crop,0.517,0.551,0.582,0.609,0.581,0.574,0.032,0.053,0.029,0.012,0.055,0.021
alexnet,batch_norm,pirl,cutout,0.483,0.595,0.611,0.627,0.662,0.649,0.019,0.036,0.033,0.029,0.02,0.024
alexnet,batch_norm,pirl,noise,0.448,0.596,0.564,0.61,0.597,0.611,0.011,0.054,0.042,0.01,0.026,0.015
alexnet,batch_norm,pirl,rotate,0.538,0.573,0.616,0.615,0.582,0.573,0.039,0.063,0.076,0.022,0.014,0.018
alexnet,batch_norm,pirl,rotate+crop,0.519,0.518,0.516,0.544,0.509,0.506,0.037,0.048,0.038,0.061,0.013,0.015
alexnet,batch_norm,pirl,shift,0.539,0.595,0.584,0.663,0.625,0.608,0.034,0.055,0.058,0.062,0.062,0.012
resnet,18.original,pirl,crop,0.548,0.63,0.652,0.677,0.581,0.665,0.02,0.045,0.037,0.031,0.088,0.051
resnet,18.original,pirl,cutout,0.537,0.605,0.616,0.638,0.648,0.617,0.03,0.037,0.043,0.028,0.016,0.03
resnet,18.original,pirl,noise,0.483,0.579,0.614,0.66,0.628,0.633,0.037,0.016,0.017,0.053,0.013,0.042
resnet,18.original,pirl,rotate,0.545,0.572,0.623,0.637,0.611,0.596,0.048,0.052,0.052,0.067,0.022,0.015


# Comparison against different input sizes

In [None]:
CKPT_ROOTS = [
    '../checkpoints.56/classification_scratch/resnet.18.original/',
    '../checkpoints.84/classification_scratch/resnet.18.original/',
    '../checkpoints/classification_scratch/resnet.18.original/'
]

assert all([os.path.isdir(root) for root in CKPT_ROOTS])

In [None]:
data = {}

for ckpt_root in CKPT_ROOTS:
    
    # Find configuration files recursively
    config_files = glob.glob(os.path.join(ckpt_root, '**/*/configs.json'), recursive=True)
    config_files = [os.path.normpath(p) for p in config_files]
    
    for config_file in config_files:
        
        try:
            # Open configuration file
            with open(config_file, 'r') as fp:
                configs = json.load(fp)
            # Open history file
            ckpt_dir = os.path.dirname(config_file)
            history_file = os.path.join(ckpt_dir, 'best_history.json')
            with open(history_file, 'r') as fp:
                history = json.load(fp)
                
        except FileNotFoundError:
            continue
        
        data[ckpt_dir] = get_configurations(configs, history)

In [None]:
print(f"Total number of experiments: {len(data):,}")

In [None]:
df = pd.DataFrame.from_dict(data, orient='index')
df = df.reset_index(drop=True, inplace=False)
df = df.sort_values(by=['labeled', 'data_index'])
display(df)

In [None]:
pivot_configs = {
    'values': ['test_f1'],
    'index': ['labeled', 'input_size'],
    'aggfunc': ['mean', 'std'],
}

df_scratch = df.loc[df['pretext'] == 'scratch'].copy()
scratch_table = df_scratch.pivot_table(**pivot_configs)
scratch_table.columns.names = ('statistic', 'metric')
scratch_table = scratch_table.unstack(level=['input_size'])
display(scratch_table.style.background_gradient(cmap=plt.cm.Spectral_r, axis=0))

# Line plots

In [None]:
def plot(table: pd.DataFrame,
         metric: str, pretexts: list, noise: float or str, 
         model_name: tuple = ('vgg', '3a'),
         fig: mpl.figure.Figure = None, **kwargs):
    
    if fig is not None:
        ax = fig.axes[0]
    else:
        fig, ax = plt.subplots(1, 1, figsize=kwargs.get('figsize', (20, 10)))
    
    y_label = ' '.join([c.capitalize() for c in metric.split('_')])
    
    colors = dict(
        scratch='grey',
        denoising='forestgreen',
        rotation='orangered',
        jigsaw='darkkhaki',
        bigan='skyblue',
        pirl='rebeccapurple'
    )
    
    for pretext in pretexts:
        if pretext == 'scratch':
            label = 'no pretraining'
        else:
            label = pretext
            
        label = label + f" ({'.'.join(model_name)})"
        if isinstance(noise, float):
            label = label.rstrip(')') + f', p={noise:.2f})'
        
        if 'color' in kwargs.keys():
            color = kwargs.get('color')
        else:
            color = colors.get(pretext, 'black')
        
        if pretext == 'scratch':
            s = 250
            marker = 'x'
        else:
            s = 100
            marker = '^'
            
        # Multiindex column must be indexed in the following order:
        # (statistic, metric, pretext, noise)
        row_idx = model_name
        col_idx = (metric, pretext, noise)
        idx = table.loc[model_name, ('mean', ) + col_idx].index
        val = table.loc[model_name, ('mean', ) + col_idx].values
        std = table.loc[model_name, ('std' , ) + col_idx].values
        table.loc[model_name, ('mean', ) + col_idx].plot.line(ax=ax, label=label, color=color)
        ax.scatter(idx, val, marker=marker, s=s, color=color)
        ax.fill_between(idx, val-std, val+std, alpha=0.05, color=color)
        
    ax.grid(True)
    ax.legend(loc='lower right', fontsize=20)
    ax.set_xlabel('Labeled Data Proportion', fontsize=25)
    ax.set_ylabel(y_label, fontsize=25)
    ax.tick_params(axis='both', which='both', labelsize=25)
    ax.set_xticks([0.01, 0.05, 0.10 ,0.25, 0.50, 0.75, 1.00])
    for tick in ax.get_xticklabels():
        tick.set_rotation(45)

    x_min, x_max = kwargs.get('x_min', 0.0), kwargs.get('x_max', 1.1)
    y_min, y_max = kwargs.get('y_min'), kwargs.get('y_max')
    ax.set_xlim(x_min, x_max)
    if all([y is not None for y in (y_min, y_max)]):
        ax.set_ylim(y_min, y_max)
    
    # Set title
    title = f"{y_label}"
    ax.set_title(title, fontsize=30)
    
    return fig

In [None]:
for METRIC in ['test_accuracy', 'test_f1']:
    for MODEL_NAME in [('vgg', '3a'), ('vgg', '6a')]:
        fig = plot(table, METRIC, ['pirl'], noise=0.00, model_name=MODEL_NAME, color='royalblue')
        fig = plot(table, METRIC, ['pirl'], noise=0.10, model_name=MODEL_NAME, fig=fig, color='slateblue')
        fig = plot(table, METRIC, ['pirl'], noise=0.25, model_name=MODEL_NAME, fig=fig, color='indianred')
        #fig = plot(table, METRIC, ['denoising'], noise=0.10, model_name=MODEL_NAME, fig=fig, color='palegreen')
        #fig = plot(table, METRIC, ['denoising'], noise=0.25, model_name=MODEL_NAME, fig=fig, color='teal')
        fig = plot(table, METRIC, ['scratch'], noise='-', model_name=MODEL_NAME, y_min=0.3, y_max=1.0, fig=fig)
        plt.show(fig)
        del fig

# Save table

In [None]:
#save_dir = '../tables'
#os.makedirs(save_dir, exist_ok=True)
#table.stack(level=['statistic']).to_csv(os.path.join(save_dir, 'table.csv'), index=True)