In [1]:
from meta_opt.experiments import train_standard_opt, train_hgd, train_meta_opt
from meta_opt.experiments import print_stuff_and_load_checkpoint, process_results, bcolors, plot
from meta_opt import DIR

import itertools
import dill as pkl
import optax

# Run

In [39]:
# configuration and seeds for each trial
SEEDS = [0,]  # the length of this list is the number of trials we will run :)
CFG = {
    # training options
    'workload': 'MNIST',
    'num_iters': 12000,
    'eval_every': 200,
    'num_eval_iters': -1,
    'batch_size': 128,
    'reset_every': 3000,

    # experiment options
    'experiment_name': 'mnist_metaopt_sweep',
    'load_checkpoint': True,
    'overwrite': True,  # whether to allow us to overwrite existing checkpoints or throw errors
    'directory': DIR,
}

# NOTE: this assigns directly to `results` (there is no temporary variable), so re-running this cell
# replaces whatever trials are currently held in memory -- anything not yet written to disk is lost.
# Presumably the helper loads the saved checkpoint when CFG['load_checkpoint'] is True; TODO confirm.
results = print_stuff_and_load_checkpoint(CFG)

In [None]:
# FOR SWEEPING HYPERPARAMS OF BASELINES
# `results` maps each run key to a list holding one trial per completed seed, appended in SEEDS order.
# It may already contain trials loaded from the checkpoint. A config that already has a trial for the
# current seed is skipped rather than trained (and appended) a second time; to force a re-run of a
# config, `del results[key]` first. The raw checkpoint is rewritten after every finished config, so
# interrupting this cell loses at most the config that was in progress.
RAW_CHECKPOINT_PATH = '{}/data/{}_raw.pkl'.format(CFG['directory'], CFG['experiment_name'])


def save_raw_checkpoint(results, filename=RAW_CHECKPOINT_PATH):
    """Pickle the whole `results` mapping to `filename`, overwriting any previous file."""
    with open(filename, 'wb') as f:
        pkl.dump(results, f)


def run_sweep(results, seed_idx, configs, make_key, train):
    """Train each hyperparameter tuple in `configs` and append the trial to `results[make_key(*config)]`.

    Args:
        results: mapping from run key to a list of trials (one per completed seed); mutated in place.
        seed_idx: position of the current seed in SEEDS; keys already holding more than
            `seed_idx` trials are skipped because this seed is already done for them.
        configs: list of hyperparameter tuples.
        make_key: builds the run key from one config tuple.
        train: runs training for one config tuple and returns the trial to store.
    """
    for i, config in enumerate(configs):
        key = make_key(*config)
        progress = f'({i+1}/{len(configs)})'
        if len(results.get(key, [])) > seed_idx:
            print(key, progress, 'already in checkpoint, skipping')
            continue
        print(key, progress)
        results.setdefault(key, []).append(train(*config))
        # persist after every config so an interrupted sweep keeps its finished runs
        save_raw_checkpoint(results)


for seed_idx, seed in enumerate(SEEDS):
    CFG['seed'] = seed

    # # SGD + momentum + weight decay sweep
    # run_sweep(
    #     results, seed_idx,
    #     configs=list(itertools.product(
    #         [0.01, 0.1, 0.2, 0.4],   # lrs
    #         [0.0, 0.9, 0.95, 0.99],  # momentums
    #         [0, 1e-5, 1e-4, 1e-3],   # wds
    #     )),
    #     make_key=lambda lr, m, wd: f'sgd{lr}+m{m}+wd{wd}',
    #     train=lambda lr, m, wd: train_standard_opt(CFG, optax.chain(optax.add_decayed_weights(wd), optax.inject_hyperparams(optax.sgd)(learning_rate=lr, momentum=m))),
    # )

    # # adam + weight decay sweep
    # run_sweep(
    #     results, seed_idx,
    #     configs=list(itertools.product(
    #         [1e-4, 4e-4, 1e-3],      # lrs
    #         [0.9, 0.99],             # b1s
    #         [0.9, 0.99, 0.999],      # b2s
    #         [0, 1e-5, 1e-4, 1e-3],   # wds
    #     )),
    #     make_key=lambda lr, b1, b2, wd: f'adam({lr},{b1},{b2})+wd{wd}',
    #     train=lambda lr, b1, b2, wd: train_standard_opt(CFG, optax.inject_hyperparams(optax.adamw)(learning_rate=lr, b1=b1, b2=b2, weight_decay=wd)),
    # )

    # # HGD sweep
    # run_sweep(
    #     results, seed_idx,
    #     configs=list(itertools.product(
    #         [0.01, 0.1, 0.2, 0.4],                 # lrs
    #         [1e-5, 1e-4, 1e-3, 1e-2, 1e-1, 1e0],   # meta_lrs
    #     )),
    #     make_key=lambda lr, meta_lr: f'hgd({lr},{meta_lr})',
    #     train=lambda lr, meta_lr: train_hgd(CFG, lr, meta_lr),
    # )

    # metaopt_noadam sweep (the meta-optimizer is plain SGD)
    run_sweep(
        results, seed_idx,
        configs=list(itertools.product(
            [1e-4, 4e-4, 1e-3, 4e-3, 1e-2],  # lrs (meta-optimizer learning rate)
            [32,],                           # Hs   (Hs = [32, 8, 1] also tried)
            [2, 3],                          # HHs
            [True, False,],                  # do_counterfactuals
        )),
        make_key=lambda lr, H, HH, cf: f'metaopt_noadam({lr},{H},{HH})+cf{cf}',
        train=lambda lr, H, HH, cf: train_meta_opt(CFG, counterfactual=cf, H=H, HH=HH, meta_optimizer=optax.inject_hyperparams(optax.sgd)(learning_rate=lr)),
    )

    # # metaopt sweep (the meta-optimizer is Adam)
    # run_sweep(
    #     results, seed_idx,
    #     configs=list(itertools.product(
    #         [4e-4,],      # lrs
    #         [64,],        # Hs   (Hs = [32, 8, 1] also tried)
    #         [2, 3, 4],    # HHs
    #         [False,],     # do_counterfactuals
    #     )),
    #     make_key=lambda lr, H, HH, cf: f'metaopt({lr},{H},{HH})+cf{cf}',
    #     train=lambda lr, H, HH, cf: train_meta_opt(CFG, counterfactual=cf, H=H, HH=HH, meta_optimizer=optax.inject_hyperparams(optax.adam)(learning_rate=lr, b1=0.9, b2=0.999)),
    # )

    assert len(results) > 0
    save_raw_checkpoint(results)
    print(f'{bcolors.OKBLUE}{bcolors.BOLD}Saved checkpoint for seed {seed} to {RAW_CHECKPOINT_PATH}{bcolors.ENDC}')

metaopt_noadam(0.0001,32,2)+cfTrue (1/20)
89610 params in the model!
32 params in the controller {'M': 32, 'M_ema': 0}


100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 12000/12000 [03:14<00:00, 61.84it/s, loss=0.01, eval_loss=0.09, M=-0.40301365]


metaopt_noadam(0.0001,32,2)+cfFalse (2/20)
89610 params in the model!
32 params in the controller {'M': 32, 'M_ema': 0}


 50%|███████████████████████████████████████████████████████▌                                                       | 6000/12000 [01:54<01:54, 52.33it/s, loss=0.009, eval_loss=0.095, M=-0.30445707]


metaopt_noadam(0.0001,32,3)+cfTrue (3/20)
89610 params in the model!
32 params in the controller {'M': 32, 'M_ema': 0}


 93%|██████████████████████████████████████████████████████████████████████████████████████████████████████▉        | 11131/12000 [04:10<00:19, 43.48it/s, loss=0.055, eval_loss=0.087, M=-0.3925252]

In [48]:
# Set LOAD_PROCESSED = True to reload previously processed results from
# {directory}/data/{experiment_name}_processed.pkl instead of recomputing them from `results`.
LOAD_PROCESSED = False
PROCESSED_PATH = '{}/data/{}_processed.pkl'.format(CFG['directory'], CFG['experiment_name'])
if LOAD_PROCESSED:
    # `with` closes the handle (a bare `pkl.load(open(...))` leaks it).
    # NOTE: unpickling can execute arbitrary code -- only load files this project wrote itself.
    with open(PROCESSED_PATH, 'rb') as f:
        processed_results = pkl.load(f)
else:
    processed_results = process_results(CFG, results)

[92m[1mSaved processed results to /Users/evandigiorno/Desktop/meta-opt/data/mnist_metaopt_sweep_processed.pkl[0m


# Print

In [43]:
# Rank every run by the last entry of its averaged eval-accuracy series, best run first.
# `sorted` is stable, so reversing its ascending output lists tied runs in reverse insertion order.
s = list(reversed(sorted(
    ((run, stats['avg'][-1]) for run, stats in processed_results['eval_acc'].items()),
    key=lambda pair: pair[1],
)))
for k, v in s:
    print(k, '\t', v)

metaopt(0.0004,64,2)+cfTrue 	 0.97626203
metaopt(0.0004,32,4)+cfTrue 	 0.97596157
metaopt(0.0001,8,2)+cfTrue 	 0.9748598
metaopt(0.0004,64,3)+cfTrue 	 0.97475964
metaopt(0.0004,32,2)+cfFalse 	 0.97475964
metaopt(0.001,8,3)+cfTrue 	 0.9744591
metaopt(0.0004,8,3)+cfTrue 	 0.9744591
metaopt(0.0001,32,2)+cfTrue 	 0.974359
metaopt(0.0004,32,3)+cfTrue 	 0.97425884
metaopt(0.0001,32,3)+cfTrue 	 0.97425884
metaopt(0.0001,8,3)+cfTrue 	 0.9740585
metaopt(0.0004,32,2)+cfTrue 	 0.97375804
metaopt(0.001,32,3)+cfTrue 	 0.9734575
metaopt(0.0001,32,2)+cfFalse 	 0.9732572
metaopt(0.0001,8,2)+cfFalse 	 0.9730569
metaopt(0.0004,64,4)+cfTrue 	 0.9729567
metaopt(0.0004,8,2)+cfTrue 	 0.9728566
metaopt(0.001,8,2)+cfFalse 	 0.9723558
metaopt(0.001,32,2)+cfFalse 	 0.9723558
metaopt(0.0004,8,4)+cfFalse 	 0.9720553
metaopt(0.001,32,2)+cfTrue 	 0.9719551
metaopt(0.0004,1,2)+cfFalse 	 0.971855
metaopt(0.0004,8,2)+cfFalse 	 0.9717548
metaopt(0.0001,32,4)+cfTrue 	 0.9717548
metaopt(0.001,1,2)+cfFalse 	 0.9714543
met

# Plot

In [52]:
# ----------------------------------------
# plot a particular set of experiments
# ----------------------------------------
# keys_to_plot = [t[0] for t in s[:25]]  # top-25 run keys from the ranking `s` built in the Print section

# ----------------------------------------
# OR just plot em all
# ----------------------------------------
keys_to_plot = '.*0.0004,64,.*True'  # specific regex (unescaped '.' matches any character)
# keys_to_plot = '.*'  # anything

SAVE_FIG = False  # set True to write the figure to {DIR}/figs/{experiment_name}.pdf

# `Ms` is reused by the Animate section below.
(fig, ax), Ms = plot(processed_results, keys_to_plot)
# ax[1].set_ylim(0.1, 0.15)
ax[2].set_ylim(0.9, 1.0)
# ax[3].set_ylim(0.5, 0.9)
# ax[4].set_ylim(-0.1, 40)
# ax[5].set_ylim(-0.05, 0.05)
if SAVE_FIG:
    # Double quotes around the f-string: reusing ' inside it is a SyntaxError before Python 3.12.
    # Assumes `fig` is the matplotlib Figure returned by `plot` -- TODO confirm.
    fig.savefig(f"{DIR}/figs/{CFG['experiment_name']}.pdf")

## Animate
Animate the values taken by the $\{M_h\}_{h=1}^H$ coefficients during training. Each $M_h$ multiplies the disturbance from $h$ training steps ago. The x-axis of the animation is zero-indexed, so $x=0$ corresponds to the most recent disturbance ($h=1$).

In [54]:
import matplotlib.animation as animation
import matplotlib.pyplot as plt  # used below for the figure; not imported anywhere else in this notebook
from IPython.display import HTML, display
from copy import deepcopy
import numpy as np

downsample_factor = 200  # how many timesteps to move forward every animation step (also the frame interval in ms)
ymin, ymax = -0.4, 0.1   # fixed y-limits of the animated M coefficients

# `Ms` comes from `plot(...)` in the Plot section: per run key, a pair (logged timesteps, logged M values).
# Presumably v[1] has shape (num_logs, H) -- `.shape[1]` is read below; TODO confirm.
_Ms = {k: (np.array(v[0]), v[1]) for k, v in Ms.items()}
H_max = max([v[1].shape[1] for v in _Ms.values()])
T = CFG['num_iters']
name = CFG['workload']


def latest_value_at(ts, vals, t):
    """Return the entry of `vals` logged at the latest timestep <= t (the first entry if none is <= t).

    Assumes `ts` is sorted ascending. Unlike `argmax(ts > t) - 1`, this also returns the LAST
    entry (not the first) once `t` is past every logged timestep.
    """
    idx = np.searchsorted(ts, t, side='right') - 1
    return vals[max(0, idx)]


anim_data = []  # each entry is a dictionary containing the M values for that animation step
for t in range(0, T, downsample_factor):
    anim_data.append({k: latest_value_at(ts, vals, t) for k, (ts, vals) in _Ms.items()})

fig = plt.figure()
ax = plt.axes(xlim=(0, H_max), ylim=(ymin, ymax))
ax.set_xlabel('number of steps in the past')
ax.set_ylabel('M coefficient')

# one (initially empty) line per run key
ls = {}
for k in _Ms.keys():
    ls[k], = ax.plot([], [], lw=3, label=k)
legend = ax.legend()


def init():
    """Clear every line; FuncAnimation calls this before drawing the first frame."""
    for l in ls.values():
        l.set_data([], [])
    return list(ls.values())


def animate(i):
    """Draw frame `i`: each line shows its run's M coefficients at timestep i * downsample_factor."""
    for k, M in anim_data[i].items():
        # reversed so that x = 0 is the last stored coefficient
        ls[k].set_data(range(0, len(M)), M[::-1])
    ax.set_title(f'timestep #{i * downsample_factor} of meta-opt on {name}')
    return list(ls.values())


# one frame per entry of anim_data (T // downsample_factor would drop the last one when T is not a multiple)
anim = animation.FuncAnimation(fig, animate, init_func=init,
                               frames=len(anim_data), interval=downsample_factor, blit=True)
plt.close(fig)  # keep the static figure from rendering; only the video is shown
h = HTML(anim.to_html5_video())
display(h)

#### 

#### 