In [1]:
from meta_opt.train_loops import train_standard_opt, train_hgd, train_meta_opt
from meta_opt.utils.experiment_utils import make, save_checkpoint, process_results, bcolors, plot, get_final_cparams
from meta_opt import DIR

import re
import matplotlib.pyplot as plt
import numpy as np
import dill as pkl
import optax

# ==================================================
# Experiment setup: one trial is run for every entry of SEEDS.
SEEDS = list(range(5, 10))

NAME = 'mnist_fullbatch_baselines'

# CFG is assembled in two steps so the training knobs and the bookkeeping
# knobs stay visually separate; the resulting key order is unchanged.
CFG = {}

# training options
CFG.update({
    'workload': 'MNIST',
    'num_iters': 10000,
    'eval_every': -1,
    'num_eval_iters': -1,
    'batch_size': 512,
    'full_batch': True,
    'reset_every': 500,
})

# experiment options
CFG.update({
    'experiment_name': NAME,
    'load_checkpoint': False,
    # allow existing checkpoints to be overwritten instead of raising an error
    'overwrite': True,
    'directory': DIR + '/..',
})

def run(seeds, cfg):
    """Train every enabled optimizer once per seed and return the processed results.

    Args:
        seeds: iterable of seeds; one full sweep over the enabled optimizers
            is run for each of them.
        cfg: experiment configuration dict. It is mutated in place:
            ``cfg['seed']`` is overwritten with the current seed before each
            sweep, so it holds the last seed when this function returns.

    Returns:
        Whatever ``process_results(cfg, results)`` returns for the collected runs.

    Note:
        ``results`` comes from ``make(cfg)`` and is indexed by optimizer name
        with ``.append``, so it presumably behaves like a ``defaultdict(list)``
        -- confirm against ``make``.
    """
    results = make(cfg)

    # uncomment the ones to run, with correctly chosen hyperparameters
    for s in seeds:
        # Use the config that was passed in (not the notebook-level global) so
        # that `run` honours its `cfg` argument for every helper it calls.
        cfg['seed'] = s
        print(f'running with seed {s}')

        # # ours
        # opt = optax.inject_hyperparams(optax.adamw)(learning_rate=1e-3)
        # results['ours_1e-3'].append(train_meta_opt(cfg, counterfactual=True, H=32, HH=2, meta_optimizer=opt))

        # standard benchmarks
        benchmarks = {
            # 'sgd': optax.inject_hyperparams(optax.sgd)(learning_rate=0.4),
            # 'momentum': optax.chain(optax.add_decayed_weights(1e-4), optax.inject_hyperparams(optax.sgd)(learning_rate=0.1, momentum=0.9)),
            # 'adamw': optax.inject_hyperparams(optax.adamw)(learning_rate=1e-3, b1=0.9, b2=0.999, weight_decay=1e-4),
            'dog': optax.inject_hyperparams(optax.contrib.dog)(0.5),
            'dowg': optax.inject_hyperparams(optax.contrib.dowg)(0.5),
            'dadamw': optax.inject_hyperparams(optax.contrib.dadapt_adamw)(),
            'mechsgd': optax.contrib.mechanize(optax.inject_hyperparams(optax.sgd)(learning_rate=0.4)),
            'mechadamw': optax.contrib.mechanize(optax.inject_hyperparams(optax.adamw)(learning_rate=1e-3, b1=0.9, b2=0.999, weight_decay=1e-4)),
        }
        for k, opt in benchmarks.items():
            results[k].append(train_standard_opt(cfg, opt))

        # # other
        # results['hgd'].append(train_hgd(cfg, initial_lr=0.4, hypergrad_lr=1e-4))

        # checkpoint after every seed so a partial sweep is not lost
        save_checkpoint(cfg, results, checkpoint_name=f'seed {s}')
    processed_results = process_results(cfg, results)
    return processed_results
# ==================================================



In [None]:
# Launch the sweep: every enabled optimizer is trained once per seed in SEEDS.
# The training logs are printed below; `processed_results` is consumed by the plotting cell.
processed_results = run(SEEDS, CFG)

using [93m[1mcpu[0m for jax
results will be stored at: [96m[1m/Users/evandogariu/Desktop/meta-opt/notebooks/../data/mnist_fullbatch_baselines_*.pkl[0m
we will [91m[1mNOT[0m try to load experiment checkpoint first
starting the experiment from scratch :)
[91m[1mnote: using full_batch means we will never eval[0m
running with seed 5
89610 params in the model!


  0%|                                                                                                                | 0/10000 [00:00<?, ?it/s]2024-05-28 04:11:29.486520: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
100%|█████████████████████████████████████████████████████████████████| 10000/10000 [00:27<00:00, 365.09it/s, loss=0.01, eval_loss=N/A, lr=0.5]2024-05-28 04:11:56.833015: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence



89610 params in the model!


  0%|                                                                                                                | 0/10000 [00:00<?, ?it/s]2024-05-28 04:11:56.923924: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
100%|████████████████████████████████████████████████████████████████| 10000/10000 [00:28<00:00, 354.68it/s, loss=0.002, eval_loss=N/A, lr=0.5]2024-05-28 04:12:25.083367: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence



89610 params in the model!


  0%|                                                                                                                | 0/10000 [00:00<?, ?it/s]2024-05-28 04:12:25.188219: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
100%|██████████████████████████████████████████████████████████████████████▉| 9985/10000 [00:27<00:00, 373.03it/s, loss=0, eval_loss=N/A, lr=1]2024-05-28 04:12:52.547256: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
100%|██████████████████████████████████████████████████████████████████████| 10000/10000 [00:27<00:00, 364.99it/s, loss=0, eval_loss=N/A,

89610 params in the model!


  0%|                                                                                                                | 0/10000 [00:00<?, ?it/s]2024-05-28 04:12:52.675514: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
100%|██████████████████████████████████████████████████████████████████████| 10000/10000 [00:27<00:00, 366.29it/s, loss=0, eval_loss=N/A, lr=0]2024-05-28 04:13:19.931407: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence



89610 params in the model!


  0%|                                                                                                                | 0/10000 [00:00<?, ?it/s]2024-05-28 04:13:20.213582: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
100%|██████████████████████████████████████████████████████████████████████▊| 9969/10000 [00:28<00:00, 357.20it/s, loss=0, eval_loss=N/A, lr=0]2024-05-28 04:13:48.653666: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
100%|██████████████████████████████████████████████████████████████████████| 10000/10000 [00:28<00:00, 351.02it/s, loss=0, eval_loss=N/A,

[94m[1mSaved checkpoint seed 5 to /Users/evandogariu/Desktop/meta-opt/notebooks/../data/mnist_fullbatch_baselines_raw.pkl[0m
running with seed 6
89610 params in the model!


  0%|                                                                                                                | 0/10000 [00:00<?, ?it/s]2024-05-28 04:13:51.883541: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
100%|████████████████████████████████████████████████████████████████| 10000/10000 [00:26<00:00, 380.74it/s, loss=0.008, eval_loss=N/A, lr=0.5]2024-05-28 04:14:18.099494: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence



89610 params in the model!


  0%|                                                                                                                | 0/10000 [00:00<?, ?it/s]2024-05-28 04:14:18.207867: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
100%|████████████████████████████████████████████████████████████████▊| 9972/10000 [00:26<00:00, 398.50it/s, loss=0.002, eval_loss=N/A, lr=0.5]2024-05-28 04:14:44.543589: W tensorflow/core/framework/local_rendezvous.cc:404] Local rendezvous is aborting with status: OUT_OF_RANGE: End of sequence
100%|████████████████████████████████████████████████████████████████| 10000/10000 [00:26<00:00, 378.96it/s, loss=0.002, eval_loss=N/A, l

89610 params in the model!


  0%|                                                                                                                | 0/10000 [00:00<?, ?it/s]2024-05-28 04:14:44.655002: W tensorflow/core/kernels/data/cache_dataset_ops.cc:858] The calling iterator did not fully read the dataset being cached. In order to avoid unexpected truncation of the dataset, the partially cached contents of the dataset  will be discarded. This can happen if you have an input pipeline similar to `dataset.cache().take(k).repeat()`. You should use `dataset.take(k).cache().repeat()` instead.
 40%|████████████████████████████▋                                          | 4036/10000 [00:11<00:15, 375.97it/s, loss=0, eval_loss=N/A, lr=1]

In [None]:
# plot
# Selects which result series are handed to `plot`. Keys are the result names
# filled in by `run`; values are presumably the legend labels -- confirm against `plot`.
# Uncomment an entry to include it (only entries that were actually run will exist).
keys_to_plot = {
    # 'sgd': 'sgd',
    # 'momentum': 'momentum',
    # 'hgd': 'hgd',
    # 'adamw': 'adamw',
    'dog': 'dog',
    'dowg': 'dowg',
    # 'dadamw': 'dadamw',
    # 'ours_1e-3': 'ours',
    # 'mechsgd': 'mechsgd',
    # 'mechadamw': 'mechadamw'  # key must match the 'mechadamw' entry of run()'s benchmarks dict
}
# alternative: a regex string instead of a dict (presumably matched against result keys -- confirm against `plot`)
# keys_to_plot = '.*ours.*'

# Selects which recorded quantities are plotted; presumably metric name -> panel title (confirm against `plot`).
plots_to_make = {
              'loss': 'Train Loss',
              # 'M': 'Learned Coefficients',
              # 'grad_sq_norm': 'Sq Grad Norm',
              # 'proj_grad_sq_norm': 'Proj Sq Grad Norm',
}

# Optional: reload previously processed results from disk instead of using the
# `processed_results` produced by the run cell.
# processed_results = pkl.load(open('{}/data/{}_processed.pkl'.format(CFG['directory'], CFG['experiment_name']), 'rb'))
# NOTE(review): `baselines` and `append_results` are not defined/imported in the visible notebook;
# define or import them before uncommenting the next line.
# for b in baselines: processed_results = append_results(processed_results, b)
    
(fig, ax), anim = plot(None, processed_results, keys_to_plot, plots_to_make, 
                       anim_bounds=None, smoothing=None, highlight_baselines=True, fontsize=20, legend_location='upper left')
# zoom the first panel in on small loss values
ax[0].set_ylim(0, 0.005)
# ax[0].set_xlim(5000, 9500)
# ax[1].set_xlim(0, 31)
# ax[1].set_ylim(-0.1, 0.005)
# ax[1].legend(loc='lower right', fontsize=20)
# plt.savefig('{}/figs/{}.pdf'.format(CFG['directory'], 'mnist_fullbatch_simple'))
plt.show()