In [None]:
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import random
import os
import json
import scipy
import scipy.stats

# Detectron colors
_COLORS = np.array([
    0.000, 0.447, 0.741,
    0.850, 0.325, 0.098,
    0.929, 0.694, 0.125,
    0.494, 0.184, 0.556,
    0.466, 0.674, 0.188
]).astype(np.float32).reshape((-1, 3))

# Random number generator seed
_RNG_SEED = 1

# Fix RNG seeds
random.seed(_RNG_SEED)
np.random.seed(_RNG_SEED)

# Directory where sweep summaries are stored
_DATA_DIR = '../data'

# Max flops constraint
_MAX_F = 0.129
# Max params constraint
_MAX_P = 0.856

In [None]:
def load_sweep(sweep_name, data_dir=None):
    """Loads a sweep summary from ``<data_dir>/<sweep_name>.json``.

    Args:
        sweep_name: base name of the summary file, without the ``.json`` suffix.
        data_dir: directory containing the summaries; when None (the default)
            the module-level ``_DATA_DIR`` is used.

    Returns:
        The parsed JSON content of the summary file.
    """
    if data_dir is None:
        data_dir = _DATA_DIR
    summary_path = os.path.join(data_dir, '{}.json'.format(sweep_name))
    # JSON text is UTF-8; don't depend on the platform's default encoding.
    with open(summary_path, 'r', encoding='utf-8') as f:
        return json.load(f)

In [None]:
_MIN_P = 0.024

def nw_is_valid_p(job):
    return _MIN_P < job['params'] * 1e-6 and job['params'] * 1e-6 < _MAX_P

In [None]:
def compute_norm_ws(cs, num_bins, c_range):
    """Computes normalized EDF weights.

    Every value in ``cs`` is assigned to one of ``num_bins`` equal-width bins
    over ``c_range`` and weighted by ``1 / (bin count * num_bins)``. Each bin
    therefore carries the same total mass, and the weights sum to 1.

    Args:
        cs: 1-D array-like of values (e.g. parameter counts) to weight.
        num_bins: number of equal-width bins.
        c_range: ``(lo, hi)`` range forwarded to ``np.histogram``; None uses
            the data's min/max.

    Returns:
        Float array with one weight per entry of ``cs``, in the same order.

    Raises:
        AssertionError: if a value lies outside the bin range or a bin is
            empty (the weights would not be well defined).
    """
    cs = np.asarray(cs)
    hist, edges = np.histogram(cs, bins=num_bins, range=c_range)
    # np.histogram silently drops out-of-range values; weighting them would be wrong.
    assert np.all((edges[0] <= cs) & (cs <= edges[-1])), 'values outside c_range'
    assert np.count_nonzero(hist) == num_bins
    # np.histogram's last bin is closed on the right. Searching only the interior
    # edges reproduces that, so a value equal to the upper edge maps to the last
    # bin instead of index num_bins (which would run off the end of ``hist``).
    inds = np.searchsorted(edges[1:-1], cs, side='right')
    return 1 / hist[inds] / num_bins

In [None]:
# Sweep summaries keyed by design-space name (Vanilla is loaded first, then ResNet).
sweeps = {name: load_sweep(name) for name in ('Vanilla', 'ResNet')}

In [None]:
print('Figure 6\n')

r, c = 1, 2
w, h = 4, 3
fig, axes = plt.subplots(nrows=r, ncols=c, figsize=(c * w, r * h))

# Settings shared by both panels.
num_bins = 10
dss = ['Vanilla', 'ResNet']


def _errs_and_norm_ws(sweep, num_bins):
    """Extracts params-normalized EDF inputs from one sweep.

    Keeps only jobs accepted by ``nw_is_valid_p``, sorts them by
    ``min_test_top1``, and weights each job with ``compute_norm_ws`` over
    ``num_bins`` params bins spanning ``(_MIN_P, _MAX_P)``.

    Args:
        sweep: iterable of job dicts with 'min_test_top1' and 'params' keys.
        num_bins: number of params bins used for the normalization.

    Returns:
        (errs, ps, ws): errors sorted ascending, the matching params in
        millions, and per-job weights summing to 1, all in the same order.
    """
    valid = [job for job in sweep if nw_is_valid_p(job)]  # single filtering pass
    errs = np.array([job['min_test_top1'] for job in valid])
    ps = np.array([job['params'] * 1e-6 for job in valid])
    order = np.argsort(errs)
    errs, ps = errs[order], ps[order]
    ws = compute_norm_ws(ps, num_bins, c_range=(_MIN_P, _MAX_P))
    assert np.isclose(np.sum(ws), 1.0)
    return errs, ps, ws


#################
# EDF shape
#################
ax = axes[0]
for j, ds in enumerate(dss):
    errs, ps, ws = _errs_and_norm_ws(sweeps[ds], num_bins)
    # Weighted EDF: cumulative weight up to and including each sorted error.
    ax.plot(
        errs, np.cumsum(ws),
        color=_COLORS[j], linewidth=2, alpha=0.8, label=ds
    )
ax.grid(alpha=0.4)
ax.set_xlabel('error | params', fontsize=16)
ax.set_ylabel('cumulative prob.', fontsize=16)
ax.set_xlim([5, 17.5])
ax.set_xticks([5, 7.5, 10, 12.5, 15, 17.5])
ax.legend(loc='lower right', prop={'size': 14})

###################
# RS efficiency
###################
# Re-seed so this panel's random draws do not depend on earlier cells.
random.seed(_RNG_SEED)
ks = [2 ** p for p in range(13)]
ks_x = np.log2(ks)

ax = axes[1]
for j, ds in enumerate(dss):
    errs, ps, ws = _errs_and_norm_ws(sweeps[ds], num_bins)
    cum_ws = np.cumsum(ws)
    # For each experiment size k, run len(errs) // k weighted random searches
    # of k models each (drawn with replacement) and keep each search's min error.
    k_errs = {}
    for k in ks:
        n = len(errs) // k
        k_errs[k] = [
            np.min(random.choices(population=errs, cum_weights=cum_ws, k=k))
            for _ in range(n)
        ]
    # n is 0 when k > len(errs): report NaN explicitly rather than letting
    # np.mean / np.std warn on an empty list (the plotted values are the same).
    mus = np.array([np.mean(k_errs[k]) if k_errs[k] else np.nan for k in ks])
    stds = np.array([np.std(k_errs[k]) if k_errs[k] else np.nan for k in ks])
    # Plot means and a +/- 2 std band
    ax.scatter(ks_x, mus, color=_COLORS[j], alpha=0.8, label=ds)
    ax.fill_between(
        ks_x, mus - 2 * stds, mus + 2 * stds,
        color=_COLORS[j], alpha=0.1
    )

ax.grid(alpha=0.4)
ax.set_xlabel('experiment size (log2)', fontsize=16)
ax.set_ylabel('error | params', fontsize=16)
ax.set_ylim([5, 10.0])
ax.set_xlim([-0.5, 12.5])
ax.legend(loc='upper right', prop={'size': 14})

plt.tight_layout();

In [None]:
print('Figure 7\n')

r, c = 1, 2
w, h = 4, 3
fig, axes = plt.subplots(nrows=r, ncols=c, figsize=(c * w, r * h))
ax_edf, ax_ks = axes

# ---------------------------------------------------------------------------
# Left panel: EDFs of ResNet errors for random subsamples of growing size n.
# ---------------------------------------------------------------------------
random.seed(_RNG_SEED)
ks = [10, 100, 1000, 10000]
ds = 'ResNet'
errs = [job['min_test_top1'] for job in sweeps[ds]]

for rank, k in enumerate(ks):
    subsample = sorted(random.sample(errs, k))
    ax_edf.plot(
        subsample, np.linspace(0, 1, len(subsample)),
        color=_COLORS[rank], linewidth=2, alpha=0.8, label='n = {}'.format(k),
        zorder=(3 - rank),  # smaller samples are drawn on top
    )

ax_edf.grid(alpha=0.4)
ax_edf.set_xlabel('error', fontsize=16)
ax_edf.set_ylabel('cumulative prob.', fontsize=16)
ax_edf.legend(loc='lower right', prop={'size': 14})
ax_edf.set_xlim([4.7, 12.3])
ax_edf.set_xticks([5, 6, 7, 8, 9, 10, 11, 12])

# ---------------------------------------------------------------------------
# Right panel: mean KS statistic between size-k subsamples and a reference
# sample of size ks[-1], averaged over num_trials draws per k.
# ---------------------------------------------------------------------------
random.seed(_RNG_SEED)
ks = [2 ** p for p in range(15)]
ds = 'ResNet'
errs = [job['min_test_top1'] for job in sweeps[ds]]
errs_all = random.sample(errs, ks[-1])

num_trials = 50
ks_stats = [
    np.mean([
        scipy.stats.ks_2samp(errs_all, random.sample(errs, k)).statistic
        for _ in range(num_trials)
    ])
    for k in ks
]

log_ks = np.log2(ks)
band = slice(7, 11)  # highlighted sample sizes: 2**7 .. 2**10

ax_ks.axhline(0, color=_COLORS[1], alpha=0.8, linewidth=2.5, linestyle='--')
ax_ks.plot(log_ks, ks_stats, color=_COLORS[0], linewidth=2.5, alpha=0.8)
ax_ks.fill_between(
    log_ks[band], np.zeros(len(ks))[band], ks_stats[band],
    color=_COLORS[0], alpha=0.15
)
ax_ks.set_ylim([-0.035, 0.8])

ax_ks.grid(alpha=0.4)
ax_ks.set_xlabel('sample size (log2)', fontsize=16)
ax_ks.set_ylabel('KS statistic (D)', fontsize=16)
ax_ks.set_xticks([0, 2, 4, 6, 8, 10, 12, 14])

plt.tight_layout();