In [1]:
# handle the system stuff, colab stuff, etc
import os
try:
    from google import colab  # for use in google colab!!
    !git clone https://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@github.com/edogariu/meta-opt
    !pip install -q ./meta-opt
    !pip install -q dill
    # !pip install -q jax[cuda12_pip]==0.4.20 -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html  # for disabling prealloc, see https://github.com/google/jax/discussions/19014
    # !pip install -q tensorflow-text ml_collections clu sentencepiece  # for WMT
    from google.colab import drive
    drive.mount('/content/drive')
    DIR = os.path.abspath("./drive/My Drive/meta-opt")
except: 
    DIR = os.path.abspath(".")
assert os.path.isdir(DIR)

# make sure we have the necessary folders
for subdir in ['data', 'figs', 'datasets']: 
    temp = os.path.join(DIR, subdir)
    if not os.path.isdir(temp): os.mkdir(temp)

# # for the one-time colab setup
# !git clone https://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@github.com/edogariu/meta-opt
# !cp -r "meta-opt" "drive/My Drive/"
# !pip install kora -q  # library from https://stackoverflow.com/questions/62596466/how-can-i-run-notebooks-of-a-github-project-in-google-colab to help get ID
# from kora.xattr import get_id
# fid = get_id(f"{dir_prefix}meta_opt.ipynb")
# print("https://colab.research.google.com/drive/"+fid)

from meta_opt.experiments import train_standard_opt, train_hgd, train_meta_opt
from meta_opt.experiments import print_stuff_and_load_checkpoint, process_results, bcolors

import re
import matplotlib.pyplot as plt
import numpy as np
import dill as pkl
import optax

# Run

In [2]:
# configuration and seeds for each trial
SEEDS = [0,]  # the length of this list is the number of trials we will run :)
CFG = {
    # training options
    'workload': 'MNIST',
    'num_iters': 15000,
    'eval_every': 50,
    'num_eval_iters': -1,
    'batch_size': 128,
    'reset_every': 3000,

    # experiment options
    'experiment_name': 'mnist_main',
    'load_checkpoint': False,
    'overwrite': False,  # whether to allow us to overwrite existing checkpoints or throw errors
    'directory': DIR,
}

results = print_stuff_and_load_checkpoint(CFG)  # save to temp var `_results` so that we dont immediately overwrite `results` in case somethings wrong

using [93m[1mcpu[0m for jax
results will be stored at: [96m[1m/Users/evandigiorno/Desktop/meta-opt/data/mnist_main_*.pkl[0m
we will [91m[1mNOT[0m try to load experiment checkpoint first
starting the experiment from scratch :)


In [None]:
# uncomment the ones to run, with correctly chosen hyperparameters
for s in SEEDS:
    CFG['seed'] = s
    
    # ours
    results['cf_3_adam'].append(train_meta_opt(CFG, counterfactual=True, H=32, HH=3, meta_optimizer=adam_meta_opt))
    results['ncf_3_adam'].append(train_meta_opt(CFG, counterfactual=False, H=32, HH=3, meta_optimizer=adam_meta_opt))

    # standard benchmarks
    benchmarks = {
        # 'sgd': optax.inject_hyperparams(optax.sgd)(learning_rate=0.1),
        # 'sgd_wd': optax.chain(optax.add_decayed_weights(1e-5), optax.inject_hyperparams(optax.sgd)(learning_rate=0.1)),
        # 'momentum': optax.inject_hyperparams(optax.sgd)(learning_rate=0.05, momentum=0.9),
        # 'adam': optax.inject_hyperparams(optax.adam)(learning_rate=1e-3),
        # 'adamw': optax.inject_hyperparams(optax.adamw)(learning_rate=1e-3, b1=0.9, b2=0.999, weight_decay=1e-5),
        # 'rmsprop': optax.inject_hyperparams(optax.rmsprop)(learning_rate=1e-3),
        # 'rsqrt': optax.inject_hyperparams(optax.adamw)(learning_rate=wmt.rsqrt_lr_schedule(0.001, 1000), b1=0.9, b2=0.98, eps=1e-9, weight_decay=1e-5),  # lr schedule for WMT transformer
    }
    for k, opt in benchmarks.items(): results[k].append(train_standard_opt(CFG, opt))

    # other
    # results['hgd'].append(train_hgd(CFG, initial_lr=0.1, hypergrad_lr=1e-4))

    assert len(results) > 0
    filename = '{}/data/{}_raw.pkl'.format(CFG['directory'], CFG['experiment_name'])
    with open(filename, 'wb') as f:
        pkl.dump(results, f)
        print(f'{bcolors.OKBLUE}{bcolors.BOLD}Saved checkpoint for seed {s} to {filename}{bcolors.ENDC}')

# Plot

In [None]:
# ----------------------------------------
# plot a particular set of experiments
# ----------------------------------------
# keys_to_plot = [
#     'sgd_wd',
#     'momentum',
#     # 'adam_0.001',
#     # 'rmsprop_0.001',
#     'scalar_adam_0.001_initial',
#     # 'scalar_0.00004'
#     # 'scalar_ema',
#     # 'diagonal_short',
#     # 'diagonal_ema',
#     ]

# ----------------------------------------
# OR just plot em all
# ----------------------------------------
# keys_to_plot = ''  # specific regex
keys_to_plot = '.*'  # anything

In [6]:
# Plot
processed_results = process_results(CFG, results)

fig, ax = plt.subplots(len(processed_results), 1, figsize=(10, 24))
Ms = {}

for i, stat_key in enumerate(processed_results.keys()):
    ax[i].set_title(stat_key)
    for experiment_name in processed_results[stat_key].keys():
        if (isinstance(keys_to_plot, list) and experiment_name not in keys_to_plot) or (isinstance(keys_to_plot, str) and not re.match(keys_to_plot, experiment_name)): 
            print(f'skipped {experiment_name}')
            continue
        ts, avgs, stds = processed_results[stat_key][experiment_name]['t'], processed_results[stat_key][experiment_name]['avg'], processed_results[stat_key][experiment_name]['std']
        if avgs.ndim == 2:  # how to handle stats that are vectors (such as the Ms for scalar meta-opt)
            Ms[experiment_name] = avgs
            ax[i].plot(ts, avgs.sum(axis=-1), label=experiment_name)
            stds = ((stds ** 2).sum(axis=-1)) ** 0.5
            ax[i].fill_between(ts, avgs.sum(axis=-1) - 1.96 * stds, avgs.sum(axis=-1) + 1.96 * stds, alpha=0.2)
            # for j in range(avgs.shape[1]):
            #     ax[i].plot(ts, avgs[:, j], label=f'{experiment_name} {str(j)}')
            #     ax[i].fill_between(ts, avgs[:, j] - 1.96 * stds[:, j], avgs[:, j] + 1.96 * stds[:, j], alpha=0.2)
        else:
            # if stat_key in ['loss', 'grad_sq_norm']:
            #     n = 20
            #     kernel = np.array([1 / n,] * n)
            #     avgs = np.convolve(avgs, kernel)[n // 2:n // 2 + avgs.shape[0]]
            #     stds = np.convolve(stds ** 2, kernel ** 2)[n // 2:n // 2 + stds.shape[0]] ** 0.5
            ax[i].plot(ts, avgs, label=experiment_name)
            ax[i].fill_between(ts, avgs - 1.96 * stds, avgs + 1.96 * stds, alpha=0.2)
    ax[i].legend()


# ax[1].set_ylim(-0.1, 2.5)
# ax[2].set_ylim(-0.1, 0.7)
# ax[3].set_ylim(0.5, 0.9)
# ax[4].set_ylim(-0.1, 40)
# ax[5].set_ylim(-0.05, 0.05)
# plt.savefig(f'{DIR}/figs/{CFG['experiment_name']}.pdf')

[92m[1mSaved processed results to /Users/evandigiorno/Desktop/meta-opt/data/test_processed.pkl[0m


ValueError: Number of rows must be a positive integer, not 0

<Figure size 1000x2400 with 0 Axes>

## Animate
Animate the values taken by the $\{M_h\}_{h=1}^H$ coefficients during training. Each $M_h$ multiplies a disturbance from $h$ training steps ago (i.e. 0 is most recent in this plot).

In [None]:
import matplotlib.animation as animation
from IPython.display import HTML

for v in Ms.values(): assert v.shape == list(Ms.values())[0].shape

downsample_factor = 100
T, H = v.shape
ymin, ymax = -0.12, 0.012
name = CFG['workload']

fig = plt.figure()  # initializing a figure in which the graph will be plotted
ax = plt.axes(xlim =(0, H), ylim=(ymin, ymax))  # marking the x-axis and y-axis
ax.set_xlabel('number of steps in the past')
ax.set_ylabel('M coefficient')

# initializing a line variable
ls = {}
for k in Ms.keys():
    ls[k], = ax.plot([], [], lw = 3, label=k)
legend = ax.legend()

# data which the line will contain (x, y)
def init():
    for l in ls.values(): l.set_data([], [])
    return list(ls.values())

def animate(i):
    for k, M in Ms.items():
        x, y = range(0, H), M[i * downsample_factor]
        ls[k].set_data(x, y[::-1])
        # line.set_label(i)
    # legend.get_texts()[0].set_text(i * downsample_factor) #Update label each at frame
    ax.set_title(f'timestep #{i * downsample_factor} of meta-opt on {name}')
    return list(ls.values())

anim = animation.FuncAnimation(fig, animate, init_func = init,
                     frames = T // downsample_factor, interval = downsample_factor, blit = True)
plt.close()
h = HTML(anim.to_html5_video())
display(h)

#### 

#### 