In [None]:
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import os
import json
import seaborn as sns

# Detectron colors: one RGB triple (floats in [0, 1]) per row, stored as float32.
_COLORS = np.array(
    [
        [0.000, 0.447, 0.741],
        [0.850, 0.325, 0.098],
        [0.929, 0.694, 0.125],
        [0.494, 0.184, 0.556],
        [0.466, 0.674, 0.188],
    ],
    dtype=np.float32,
)

# Directory where sweep summaries are stored
_DATA_DIR = '../data'

In [None]:
def compute_norm_ws(cs, num_bins, c_range):
    """Computes normalized EDF weights.

    Bins ``cs`` into ``num_bins`` equal-width bins over ``c_range`` and gives
    each sample the weight ``1 / (count of its bin) / num_bins``. Every bin
    therefore carries a total mass of ``1 / num_bins`` and the weights sum to 1.

    Args:
        cs: array-like of complexity values, one per sample.
        num_bins: number of equal-width bins.
        c_range: ``(low, high)`` range of the bins. Every value of ``cs`` must
            lie in the closed interval ``[low, high]`` (``np.histogram`` counts
            ``high`` in the last bin).

    Returns:
        ``np.ndarray`` of per-sample weights aligned with ``cs``.

    Raises:
        ValueError: if any value of ``cs`` lies outside ``c_range`` or is NaN.
        AssertionError: if some bin is empty (its weight would be undefined).
    """
    cs = np.asarray(cs)
    lo, hi = c_range
    # np.histogram silently drops out-of-range values, and np.digitize would
    # give them index -1 (wrapping to the last bin) or num_bins (IndexError).
    # The comparison is also False for NaN, so NaN is rejected here too.
    if not np.all((cs >= lo) & (cs <= hi)):
        raise ValueError(
            'compute_norm_ws: all values must lie within c_range={}'.format(c_range))
    hist, edges = np.histogram(cs, bins=num_bins, range=c_range)
    # np.digitize maps a value equal to the upper edge to len(edges), but
    # np.histogram counts it in the last (right-closed) bin; clamp so each
    # sample indexes the same bin it was counted in.
    inds = np.minimum(np.digitize(cs, bins=edges) - 1, num_bins - 1)
    assert np.count_nonzero(hist) == num_bins
    return 1 / hist[inds] / num_bins

In [None]:
def load_sweep(sweep_name):
    """Loads a sweep summary.

    Opens ``<_DATA_DIR>/<sweep_name>.json`` and returns the parsed JSON value.
    """
    file_name = '{}.json'.format(sweep_name)
    with open(os.path.join(_DATA_DIR, file_name), 'r') as handle:
        return json.load(handle)

In [None]:
# Source: http://kaiminghe.com/ilsvrc15/ilsvrc2015_deep_residual_learning_kaiminghe.pdf
# Single reported error per model (presumably ILSVRC top-5 error in %, per the
# cited slides -- verify). Insertion order matters: panel (a) plots the entries
# in reverse insertion order.
point_ests = dict(
    ResNet=3.57,
    VGG=7.3,
    ZFNet=11.7,
    AlexNet=16.4,
)

In [None]:
# Source: https://github.com/facebookresearch/ResNeXt
# Each model maps to a list of (complexity, error) points. The complexity unit
# is not stated here (presumably FLOPs-based, per the `_fs` suffix -- verify).
curve_ests_fs = {
    'ResNet': list(zip([8.0, 16.0, 23.0], [24, 22.5, 22.1])),
    'ResNeXt': list(zip([8.0, 16.0, 31.0], [22.1, 21.1, 20.4])),
}

In [None]:
# Load each sweep summary once, keyed by its sweep name.
sweeps = {name: load_sweep(name) for name in ('ResNet-B', 'ResNeXt-B')}

In [None]:
# Open window on a job's parameter count, expressed in millions.
_MIN_P = 0.023
_MAX_P = 0.856

def is_valid_p(job):
    """Returns True iff job['params'] (scaled to millions) is strictly inside (_MIN_P, _MAX_P)."""
    params_m = job['params'] * 1e-6
    return _MIN_P < params_m < _MAX_P

In [None]:
print('Figure 1\n')

# Layout: r rows x c columns of panels, each w x h (figsize units, inches).
r, c = 1, 3
w, h = 3, 3
fig, axes = plt.subplots(
    r, c,
    figsize=(c * w, r * h),
    gridspec_kw=dict(width_ratios=[w] * c),
)

# Font sizes shared by all panels.
title_font_size, axis_font_size = 16, 15
tick_font_size, legend_font_size = 12, 12.5

##########################################
# Point estimates
##########################################
# Bars are drawn in reverse insertion order of point_ests.
xs = list(point_ests)[::-1]
ys = [point_ests[name] for name in xs]

# NOTE(review): seaborn >= 0.13 warns when `palette` is passed without `hue`;
# its suggested replacement (hue=xs, legend=False) requires seaborn >= 0.13.
ax = sns.barplot(
    x=xs,
    y=ys,
    palette=sns.color_palette("RdBu_r", 6, desat=1.0),
    ax=axes[0],
    alpha=1.0,
)
ax.set_title('(a) point estimates', fontsize=title_font_size)

ax.set_ylim([0, 17.5])
ax.grid(alpha=0.4)
ax.set_ylabel('error', fontsize=axis_font_size)
ax.tick_params(axis='x', labelsize=tick_font_size, rotation=-20)

##########################################
# Curve estimates
##########################################
ms = ['ResNet', 'ResNeXt']

ax = axes[1]
for i, m in enumerate(ms):
    xs = [x for (x, y) in curve_ests_fs[m]]
    ys = [y for (x, y) in curve_ests_fs[m]]
    ax.plot(
        xs, ys, label=m,
        # 1 - i gives ResNet palette color 1 and ResNeXt color 0, the same
        # assignment used in the distribution-estimates panel.
        color=_COLORS[1 - i], alpha=0.8, linewidth=2.5,
        marker='o', markersize=8
    )

ax.grid(alpha=0.4)
ax.set_xlabel('complexity', fontsize=axis_font_size)
ax.set_ylabel('error', fontsize=axis_font_size)
ax.legend(loc='upper right', prop={'size' : legend_font_size})
ax.set_title('(b) curve estimates', fontsize=title_font_size)
ax.set_xlim([5, 35])
ax.set_ylim([20, 25])

# Labels go through set_*ticklabels: passing them as the second positional
# argument of set_*ticks is only accepted (as `labels`) on matplotlib >= 3.5.
# Deriving the labels from the tick values keeps the two lists in sync.
curve_xticks = [5, 10, 15, 20, 25, 30, 35]
curve_yticks = [20, 21, 22, 23, 24, 25]
ax.set_xticks(curve_xticks)
ax.set_xticklabels(['{:.1f}'.format(t) for t in curve_xticks])
ax.set_yticks(curve_yticks)
ax.set_yticklabels(['{:.1f}'.format(t) for t in curve_yticks])

##########################################
# Distribution estimates
##########################################
dss = ['ResNet-B', 'ResNeXt-B']
lbs = ['ResNet', 'ResNeXt']

ax = axes[2]
for i, ds in enumerate(dss):
    # Filter once so errs and ps come from the same jobs, in the same order.
    valid_jobs = [job for job in sweeps[ds] if is_valid_p(job)]
    errs = np.array([job['min_test_top1'] for job in valid_jobs])
    ps = np.array([job['params'] * 1e-6 for job in valid_jobs])
    # Sort by error so the running sum of weights traces the (weighted) EDF.
    inds = np.argsort(errs)
    errs, ps = errs[inds], ps[inds]
    # Reweight so each complexity bin contributes equal mass ("error | complexity").
    ws = compute_norm_ws(ps, num_bins=40, c_range=(_MIN_P, _MAX_P))
    assert np.isclose(np.sum(ws), 1.0)
    ax.plot(
        errs, np.cumsum(ws),
        color=_COLORS[1 - i], linewidth=2.5, alpha=0.8, label=lbs[i]
    )

ax.grid(alpha=0.4)
ax.set_title('(c) distribution estimates', fontsize=title_font_size)
ax.set_xlabel('error | complexity', fontsize=axis_font_size)
ax.set_ylabel('cumulative prob.', fontsize=axis_font_size)
ax.set_xlim([4.5, 12.5])
ax.legend(loc='lower right', prop={'size' : legend_font_size})

plt.tight_layout();