In [None]:
%matplotlib inline

import matplotlib
import matplotlib.pyplot as plt
import numpy as np
import random
import os
import copy
import json
import scipy

# Detectron colors: five RGB triples stored as a (5, 3) float32 array,
# indexed by row when choosing a line/marker color
_COLORS = np.array([
    0.000, 0.447, 0.741,
    0.850, 0.325, 0.098,
    0.929, 0.694, 0.125,
    0.494, 0.184, 0.556,
    0.466, 0.674, 0.188
]).astype(np.float32).reshape((-1, 3))

# Random number generator seed
_RNG_SEED = 1

# Fix RNG seeds of both Python's `random` module and NumPy's global generator
random.seed(_RNG_SEED)
np.random.seed(_RNG_SEED)

# Directory where sweep summaries are stored (one '<sweep_name>.json' per sweep)
_DATA_DIR = '../data'

# Max flops constraint (compared against job['flops'] * 1e-9)
_MAX_F = 0.600
# Max params constraint (compared against job['params'] * 1e-6)
_MAX_P = 6.000

In [None]:
def load_sweep(sweep_name):
    """Reads and returns the parsed JSON summary of the sweep `sweep_name`.

    The summary is looked up as '<sweep_name>.json' inside `_DATA_DIR`.
    """
    file_name = '{}.json'.format(sweep_name)
    with open(os.path.join(_DATA_DIR, file_name), 'r') as summary_file:
        return json.load(summary_file)

In [None]:
def compute_norm_ws(cs, num_bins, c_range):
    """Computes normalized EDF weights.

    The complexities `cs` are histogrammed into `num_bins` equal-width bins
    over `c_range`; each sample then gets the weight
    1 / (number of samples in its bin) / num_bins, so every bin carries a
    total mass of 1 / num_bins and the weights sum to 1.

    Args:
        cs: array-like of complexity values (one per sample).
        num_bins: number of equal-width bins (int).
        c_range: (lower, upper) range passed to `np.histogram`; both ends are
            inclusive, matching `np.histogram`'s closed last bin.

    Returns:
        Array of per-sample weights, aligned with `cs`.

    Raises:
        ValueError: if any value of `cs` lies outside the histogram range
            (such values are not counted in any bin).
        AssertionError: if some bin is empty (the weights would not sum to 1).
    """
    cs = np.asarray(cs)
    hist, edges = np.histogram(cs, bins=num_bins, range=c_range)
    # Out-of-range values are ignored by np.histogram and would otherwise pick
    # up the count of an unrelated bin (or index past the end of `hist`)
    if cs.size and (cs.min() < edges[0] or cs.max() > edges[-1]):
        raise ValueError(
            'complexities must lie within [{}, {}]'.format(edges[0], edges[-1])
        )
    # np.digitize uses half-open bins, so a value exactly on the upper edge
    # maps to index num_bins; np.histogram counts it in the last bin instead
    inds = np.minimum(np.digitize(cs, bins=edges) - 1, num_bins - 1)
    assert np.count_nonzero(hist) == num_bins, 'every bin must be non-empty'
    return 1 / hist[inds] / num_bins

In [None]:
def compute_c_range_mins(sweeps, dss):
    """Computes the complexity range mins.

    For each complexity measure, takes the minimum over the jobs of every
    design space in `dss` and returns the largest of those minima (never
    below 0.0). Params are scaled by 1e-6 and flops by 1e-9.

    Returns:
        Tuple (max_min_ps, max_min_fs).
    """

    def _largest_min(key, scale):
        # Largest per-design-space minimum of job[key] * scale, floored at 0.0
        largest = 0.0
        for ds in dss:
            vals = np.array([job[key] * scale for job in sweeps[ds]])
            largest = max(min(vals), largest)
        return largest

    return _largest_min('params', 1e-6), _largest_min('flops', 1e-9)

In [None]:
def is_valid_p(job, min_p, max_p):
    """Checks that job['params'] * 1e-6 lies strictly between min_p and max_p."""
    num_params = job['params'] * 1e-6
    return min_p < num_params < max_p

def is_valid_f(job, min_f, max_f):
    """Checks that job['flops'] * 1e-9 lies strictly between min_f and max_f."""
    num_flops = job['flops'] * 1e-9
    return min_f < num_flops < max_f

In [None]:
# NAS sweeps: sweep summaries keyed by the design space names used in `dss`
sweeps_nas = {
    ds_name: load_sweep(sweep_name)
    for ds_name, sweep_name in (
        ('NASNet', 'NASNet_in'),
        ('Amoeba', 'Amoeba_in'),
        ('PNAS', 'PNAS_in'),
        ('ENAS', 'ENAS_in'),
        ('DARTS', 'DARTS_in'),
    )
}

In [None]:
# Standard DS sweeps: sweep summaries keyed by the design space names used in `dss`
sweeps_std = {
    ds_name: load_sweep(sweep_name)
    for ds_name, sweep_name in (
        ('NASNet', 'NASNet_in'),
        ('DARTS', 'DARTS_in'),
        ('ResNeXt-A', 'ResNeXt-A_in'),
        ('ResNeXt-B', 'ResNeXt-B_in'),
    )
}

In [None]:
# lr wd sweeps: sweep summaries keyed by the design space names used in `dss`
sweeps_lr_wd = {
    ds_name: load_sweep(sweep_name)
    for ds_name, sweep_name in (
        ('Vanilla', 'Vanilla_lr-wd_in'),
        ('ResNet', 'ResNet_lr-wd_in'),
        ('DARTS', 'DARTS_lr-wd_in'),
    )
}

In [None]:
print('Figure 16a\n')

# Number of complexity bins, the design spaces to compare, the row of
# _COLORS used for each design space, and the complexity measures (one panel each)
num_bins = 8
dss = ['NASNet', 'Amoeba', 'PNAS', 'ENAS', 'DARTS']
cols = [0, 1, 4, 2, 3]
cms = ['params', 'flops']

# Single row of panels, each w x h inches
r, c = 1, 2
w, h = 4, 3
fig, axes = plt.subplots(nrows=r, ncols=c, figsize=(c * w, r * h))

# Complexity range shared by all design spaces
min_p, min_f = compute_c_range_mins(sweeps_nas, dss)
max_p, max_f = _MAX_P, _MAX_F

# Per complexity measure: job filter, unit scale, and (min, max) range
cm_specs = {
    'params': (is_valid_p, 1e-6, (min_p, max_p)),
    'flops': (is_valid_f, 1e-9, (min_f, max_f)),
}

for i, cm in enumerate(cms):
    ax = axes[i]
    is_valid, scale, c_range = cm_specs[cm]
    for j, ds in enumerate(dss):
        # Keep only the jobs whose complexity falls inside the shared range
        jobs = [job for job in sweeps_nas[ds] if is_valid(job, *c_range)]
        errs = np.array([job['min_test_top1'] for job in jobs])
        cs = np.array([job[cm] * scale for job in jobs])
        # Order by error so that the cumulative weights trace out the EDF
        order = np.argsort(errs)
        errs, cs = errs[order], cs[order]
        ws = compute_norm_ws(cs, num_bins, c_range=c_range)
        assert np.isclose(np.sum(ws), 1.0)
        ax.plot(
            errs, np.cumsum(ws),
            color=_COLORS[cols[j]], linewidth=2, alpha=0.8, label=ds
        )
    ax.set_xlabel('error | {}'.format(cm), fontsize=16)
    ax.grid(alpha=0.4)
    ax.set_ylabel('cumulative prob.', fontsize=16)
    ax.set_xlim([27, 50])
    ax.legend(loc='lower right', prop={'size': 13})

plt.tight_layout();

In [None]:
print('Figure 16b\n')

# Number of complexity bins, the design spaces to compare, the row of
# _COLORS used for each design space, and the complexity measures (one panel each)
num_bins = 8
dss = ['NASNet', 'Amoeba', 'PNAS', 'ENAS', 'DARTS']
cols = [0, 1, 4, 2, 3]
cms = ['params', 'flops']

# Single row of panels, each w x h inches
r, c = 1, 2
w, h = 4, 3
fig, axes = plt.subplots(nrows=r, ncols=c, figsize=(c * w, r * h))

# Re-seed so the sampled experiments do not depend on earlier cells
random.seed(_RNG_SEED)
# Experiment sizes: 1, 2, 4, ..., 32
ks = [2 ** p for p in range(6)]

# Complexity range shared by all design spaces
min_p, min_f = compute_c_range_mins(sweeps_nas, dss)
max_p, max_f = _MAX_P, _MAX_F

# Per complexity measure: job filter, unit scale, and (min, max) range
cm_specs = {
    'params': (is_valid_p, 1e-6, (min_p, max_p)),
    'flops': (is_valid_f, 1e-9, (min_f, max_f)),
}

for i, cm in enumerate(cms):
    ax = axes[i]
    is_valid, scale, c_range = cm_specs[cm]
    for j, ds in enumerate(dss):
        # Keep only the jobs whose complexity falls inside the shared range
        jobs = [job for job in sweeps_nas[ds] if is_valid(job, *c_range)]
        errs = np.array([job['min_test_top1'] for job in jobs])
        cs = np.array([job[cm] * scale for job in jobs])
        order = np.argsort(errs)
        errs, cs = errs[order], cs[order]
        ws = compute_norm_ws(cs, num_bins, c_range=c_range)
        assert np.isclose(np.sum(ws), 1.0)
        cum_ws = np.cumsum(ws)
        # For each size k, draw len(errs) // k weighted samples of k errors
        # and record the best (minimum) error of every sample
        k_errs = {}
        for k in ks:
            n = len(errs) // k
            k_errs[k] = [
                np.min(random.choices(population=errs, cum_weights=cum_ws, k=k))
                for _ in range(n)
            ]
        # Plot the mean best error per size with a +/- 2 std band
        mus = np.array([np.mean(k_errs[k]) for k in ks])
        stds = np.array([np.std(k_errs[k]) for k in ks])
        ax.scatter(
            np.log2(ks), mus,
            color=_COLORS[cols[j]], alpha=0.8, label=ds
        )
        ax.fill_between(
            np.log2(ks), mus - 2 * stds, mus + 2 * stds,
            color=_COLORS[cols[j]], alpha=0.05
        )
    ax.set_ylabel('error | {}'.format(cm), fontsize=16)
    ax.grid(alpha=0.4)
    ax.set_xlabel('experiment size (log2)', fontsize=16)
    ax.set_ylim([27, 50])
    ax.legend(loc='upper right', prop={'size': 13})

plt.tight_layout();

In [None]:
print('Figure 16c\n')

# Number of complexity bins, the design spaces to compare, the row of
# _COLORS used for each design space, and the complexity measures (one panel each)
num_bins = 5
dss = ['NASNet', 'DARTS', 'ResNeXt-A', 'ResNeXt-B']
cols = [0, 3, 1, 2]
cms = ['params', 'flops']

# Single row of panels, each w x h inches
r, c = 1, 2
w, h = 4, 3
fig, axes = plt.subplots(nrows=r, ncols=c, figsize=(c * w, r * h))

# Complexity range shared by all design spaces
min_p, min_f = compute_c_range_mins(sweeps_std, dss)
max_p, max_f = _MAX_P, _MAX_F

# Per complexity measure: job filter, unit scale, and (min, max) range
cm_specs = {
    'params': (is_valid_p, 1e-6, (min_p, max_p)),
    'flops': (is_valid_f, 1e-9, (min_f, max_f)),
}

for ax, cm in zip(axes, cms):
    keep, scale, c_range = cm_specs[cm]
    for col, ds in zip(cols, dss):
        # Keep only the jobs whose complexity falls inside the shared range
        jobs = [job for job in sweeps_std[ds] if keep(job, *c_range)]
        errs = np.array([job['min_test_top1'] for job in jobs])
        cs = np.array([job[cm] * scale for job in jobs])
        # Order by error so that the cumulative weights trace out the EDF
        order = np.argsort(errs)
        errs, cs = errs[order], cs[order]
        ws = compute_norm_ws(cs, num_bins, c_range=c_range)
        assert np.isclose(np.sum(ws), 1.0)
        ax.plot(
            errs, np.cumsum(ws),
            color=_COLORS[col], linewidth=2, alpha=0.8, label=ds
        )
    ax.set_xlabel('error | {}'.format(cm), fontsize=16)
    ax.grid(alpha=0.4)
    ax.set_ylabel('cumulative prob.', fontsize=16)
    ax.set_xlim([27, 50])
    ax.legend(loc='lower right', prop={'size': 13})

plt.tight_layout();

In [None]:
print('Figure 16d\n')

# Single row of panels, each w x h inches
r, c = 1, 3
w, h = 4, 3
fig, axes = plt.subplots(nrows=r, ncols=c, figsize=(c * w, r * h))
# Flat view of the panels so each design space is addressed by one index;
# this avoids relying on a row index `i` left over from earlier cells
axes_flat = axes.ravel()

# Hyperparameter keys inside job['optim'], their axis labels, and the
# design spaces to show (one panel each)
hps = ['base_lr', 'wd']
lbs = ['learning rate (log10)', 'weight decay (log10)']
dss = ['Vanilla', 'ResNet', 'DARTS']

# Default (learning rate, weight decay) setting, highlighted in every panel
def_pt = [5 * 1e-2, 5 * 1e-5]
def_pt_log = np.log10(def_pt)

# Axis limits (log10) and highlight styling are the same for every panel
xlim_log = np.log10([0.001, 1.0])
ylim_log = np.log10([0.00001, 0.01])
def_pt_alpha = 0.8
pr_col = _COLORS[1]

for j, ds in enumerate(dss):
    ax = axes_flat[j]
    sweep = sweeps_lr_wd[ds]
    xs = [job['optim'][hps[0]] for job in sweep]
    ys = [job['optim'][hps[1]] for job in sweep]
    # Use log10 scale
    xs_log = np.log10(xs)
    ys_log = np.log10(ys)
    # Compute relative ranks: the double argsort gives each job's rank by
    # error (lowest error first), rescaled to (0, 1]
    errs = [job['min_test_top1'] for job in sweep]
    ranks = np.argsort(np.argsort(errs))
    ranks += 1
    ranks_rel = ranks / (len(ranks))
    # Plot relative ranks
    s = ax.scatter(xs_log, ys_log, c=ranks_rel, alpha=0.4, cmap='viridis', rasterized=True)
    ax.set_xlabel(lbs[0], fontsize=16)
    # Only the left-most panel carries the y label
    if j == 0:
        ax.set_ylabel('{}'.format(lbs[1]), fontsize=16)
    ax.set_xlim(xlim_log)
    ax.set_ylim(ylim_log)
    ax.grid(alpha=0.4)
    # Show default setting: a marker plus dashed guides from the lower axis limits
    ax.scatter(def_pt_log[0], def_pt_log[1], color=pr_col, alpha=def_pt_alpha)
    ax.plot(
        np.linspace(xlim_log[0], def_pt_log[0], 10), [def_pt_log[1] for _ in range(10)],
        color=pr_col, alpha=def_pt_alpha, linestyle='--', linewidth=2.5
    )
    ax.plot(
        [def_pt_log[0] for _ in range(10)], np.linspace(ylim_log[0], def_pt_log[1], 10),
        color=pr_col, alpha=def_pt_alpha, linestyle='--', linewidth=2.5
    )
    ax.set_title(ds, fontsize=16)

# One colorbar for all panels, taken from the last scatter
fig.colorbar(s, ax=axes_flat.tolist());