In [1]:
# handle the system stuff, colab stuff, etc
import os
# project root: one level above the current working directory
DIR = os.path.abspath("../")

# make sure we have the necessary folders; `exist_ok=True` makes this a no-op
# when a folder is already there, so there is no check-then-create race
for subdir in ('data', 'figs', 'datasets'):
    temp = os.path.join(DIR, subdir)
    os.makedirs(temp, exist_ok=True)

from meta_opt.train_loops import train_standard_opt, train_hgd, train_meta_opt
from meta_opt.utils.experiment_utils import make, save_checkpoint, process_results, bcolors, plot, get_final_cparams
import meta_opt.configs as configs

import re
import matplotlib.pyplot as plt
import numpy as np
import dill as pkl
import optax

# ==================================================
# seeds (one per trial) and the shared experiment configuration
SEEDS = [0]

NAME = 'wmt_yeet'
CFG = dict(
    # ---- training options ----
    workload='WMT',
    num_iters=5000,
    eval_every=200,       # run evaluation every this many training steps
    num_eval_iters=-1,    # NOTE(review): presumably -1 means "use the whole eval set" -- confirm
    batch_size=32,
    full_batch=False,
    # far larger than num_iters, so a `t % reset_every == 0` check only fires at step 0
    reset_every=1_000_000_000,

    # ---- experiment options ----
    experiment_name=NAME,
    load_checkpoint=False,
    overwrite=True,       # whether to allow us to overwrite existing checkpoints or throw errors
    directory=DIR,
)

In [None]:
from time import perf_counter
from collections import defaultdict
import tqdm
import jax

from meta_opt.nn import reset_model, train_step, eval
from meta_opt.workloads import get_workload
from meta_opt.utils.pytree_utils import pytree_sq_norm
from meta_opt.utils.experiment_utils import get_opt_hyperparams

# -------------------------------------------------------------------------------------------------
# ------------------------------ Standard Optax Optimizers ----------------------------------------
# -------------------------------------------------------------------------------------------------

def train_standard_opt(cfg, optimizer):
    """Train a workload with a standard optimizer and record per-step statistics.

    Args:
        cfg: experiment configuration, forwarded to `get_workload`.
        optimizer: optimizer forwarded to `get_workload`; its hyperparameters
            are read back from `tstate.opt_state` via `get_opt_hyperparams`.

    Returns:
        A plain dict with an 'args' entry (the workload args, extended with
        'optimizer_args' and 'optimizer_name') and one entry per step index.
        Every step entry holds 'timestamp' (seconds since the loop started)
        and 'loss'; evaluation steps additionally hold the `eval_*` metrics,
        'param_sq_norm' and 'grad_sq_norm'.

    NOTE(review): `args['num_iters']` is only used as the progress-bar total;
    the loop runs until the training iterator is exhausted. Presumably
    `get_workload` already bounds the dataset -- confirm.
    """
    tstate, train_ds, test_ds, rng, args = get_workload(cfg, optimizer)

    args['optimizer_args'] = get_opt_hyperparams(tstate.opt_state)
    args['optimizer_name'] = 'standard'
    stats = defaultdict(dict)
    stats['args'] = args

    start_time = perf_counter()
    last_eval_step = None
    pbar = tqdm.tqdm(train_ds.as_numpy_iterator(), total=args['num_iters'])
    for step, batch in enumerate(pbar):

        # periodic re-initialisation of the model; always fires at step 0
        if step % args['reset_every'] == 0:
            reset_key, rng = jax.random.split(rng)
            tstate = reset_model(reset_key, tstate)

        tstate, (loss, grads) = train_step(tstate, batch)

        # per-step record: wall-clock offset and training loss
        record = {'timestamp': perf_counter() - start_time, 'loss': loss}

        # evaluate on the schedule, but never on the very first step
        is_eval_step = step % args['eval_every'] == 0 and step != 0
        if is_eval_step:
            # `eval` is the function imported from meta_opt.nn, not the builtin
            eval_metrics = eval(tstate, test_ds.as_numpy_iterator())
            record.update({f'eval_{name}': value for name, value in eval_metrics.items()})
            record['param_sq_norm'] = pytree_sq_norm(tstate.params)
            record['grad_sq_norm'] = pytree_sq_norm(grads)
            last_eval_step = step

        stats[step] = record

        # progress bar shows the current loss and the most recent eval loss
        train_loss_display = round(record['loss'].item(), 3)
        if last_eval_step is not None:
            eval_loss_display = round(stats[last_eval_step]['eval_loss'].item(), 3)
        else:
            eval_loss_display = 'N/A'
        pbar.set_postfix({'loss': train_loss_display, 'eval_loss': eval_loss_display})

    return dict(stats)

# single trial: first seed, plain SGD wrapped so its hyperparameters are injectable
CFG['seed'] = SEEDS[0]
train_standard_opt(
    cfg=CFG,
    optimizer=optax.inject_hyperparams(optax.sgd)(0.1),
)

  0%|                                                                                                                                                                       | 0/5000 [00:00<?, ?it/s]