In [None]:
# # for use in google colab!!
# !git clone https://<GITHUB_TOKEN>@github.com/edogariu/meta-opt  # supply your own token at runtime (e.g. via getpass); never commit one
# !pip install ./meta-opt
# !pip install tensorflow-text ml_collections clu sentencepiece  # for WMT
# from google.colab import drive
# drive.mount('/content/drive')
# DIR_PREFIX = "drive/My Drive/meta-opt"

# # for extra one-time setup in colab
# !git clone https://<GITHUB_TOKEN>@github.com/edogariu/meta-opt  # supply your own token at runtime (e.g. via getpass); never commit one
# !mkdir meta-opt/data
# !mkdir meta-opt/datasets
# !cp -r "meta-opt" "drive/My Drive/"
# !pip install kora -q  # library from https://stackoverflow.com/questions/62596466/how-can-i-run-notebooks-of-a-github-project-in-google-colab to help get ID
# from kora.xattr import get_id
# fid = get_id(f"{DIR_PREFIX}/meta_opt.ipynb")
# print("https://colab.research.google.com/drive/"+fid)

In [None]:
from time import perf_counter
from collections import defaultdict
from copy import deepcopy
import tqdm
import matplotlib.pyplot as plt
import pickle as pkl

import numpy as np
import tensorflow as tf
import jax
import jax.numpy as jnp
import optax

from meta_opt.nn.trainer import create_train_state, gradient_descent, reset_model, eval
from meta_opt.problems import mnist, cifar10, wmt

In [1]:
def set_seed(seed):
    """Seed numpy, TensorFlow and JAX from a single integer.

    Args:
        seed: integer seed, or None to draw a fresh random seed (the drawn value
            is printed so the run can be reproduced later).

    Returns:
        (rng, seed): a JAX PRNGKey built from the seed, and the seed actually used.
    """
    if seed is None:
        # np.random.randint requires an explicit `low`; calling it with no
        # arguments raises TypeError. The range [0, 2**31 - 1) is valid for
        # np.random.seed (needs 0 <= seed < 2**32) and fits a 32-bit default int.
        seed = int(np.random.randint(0, 2**31 - 1))
        print('seed set to {}'.format(seed))
    np.random.seed(seed)
    tf.random.set_seed(seed)
    rng = jax.random.PRNGKey(seed)
    return rng, seed

def get_problem(seed, name, optimizer):
    """Build the dataset, model and initial train state for a named benchmark problem.

    `name` is matched by substring, tested in this order: 'MNIST', 'CIFAR', 'WMT'.
    The function reads the notebook-level globals NUM_ITERS, BATCH_SIZE, EVAL_EVERY,
    RESET_EVERY, PRINT_EVERY and DIR_PREFIX. NOTE(review): none of these are set in
    the visible cells (DIR_PREFIX only appears in the commented-out Colab cell);
    confirm a config cell defines them before this runs.

    Args:
        seed: integer seed, or None (forwarded to set_seed).
        name: problem name containing 'MNIST', 'CIFAR' or 'WMT'.
        optimizer: optimizer handed through to create_train_state.

    Returns:
        (tstate, train_ds, test_ds, rng, args): the initial train state, the two
        datasets returned by the problem's loader, the PRNG key left over after
        splitting off the init key, and a dict describing the run configuration.

    Raises:
        NotImplementedError: if `name` matches none of the known problems.
    """
    rng, seed = set_seed(seed)
    init_rng, rng = jax.random.split(rng)

    # Choose the data loader plus a deferred model constructor, so the model is
    # still built only after its dataset has been loaded.
    if 'MNIST' in name:
        load_data = mnist.load_mnist
        build_model = lambda: mnist.MLP([28 * 28, 100, 100, 10])
    elif 'CIFAR' in name:
        load_data = cifar10.load_cifar10
        build_model = lambda: cifar10.VGG(stages=((32, 32), (64, 64), (128, 128)),
                                          layer_dims=[128, 10],
                                          drop_last_activation=True,
                                          dropout=0.1)
    elif 'WMT' in name:
        load_data = wmt.load_wmt
        build_model = lambda: wmt.make_transformer(num_heads=8, num_layers=6, emb_dim=256,
                                                   qkv_dim=256, mlp_dim=1024)
    else:
        raise NotImplementedError(name)

    train_ds, test_ds, example_input, loss_fn, acc_fn = load_data(
        NUM_ITERS, BATCH_SIZE, dataset_dir=f'{DIR_PREFIX}/datasets')
    model = build_model()

    tstate = create_train_state(init_rng, model, example_input, optimizer, loss_fn, acc_fn=acc_fn)

    # total number of scalar parameters across every leaf of the params pytree
    num_params = sum(leaf.size for leaf in jax.tree_util.tree_leaves(tstate.params))
    args = dict(seed=seed,
                model=str(model),
                params=num_params,
                dataset=name,
                num_iters=NUM_ITERS,
                eval_every=EVAL_EVERY,
                batch_size=BATCH_SIZE,
                reset_every=RESET_EVERY,
                print_every=PRINT_EVERY)
    return tstate, train_ds, test_ds, rng, args

In [None]:
def _evaluate_on_test_set(tstate, test_ds):
    """Average the (loss, accuracy) returned by `eval` over one full pass of `test_ds`."""
    total_loss, total_acc, num_batches = 0., 0., 0
    for test_batch in test_ds.as_numpy_iterator():
        batch_loss, batch_acc = eval(tstate, test_batch)
        total_loss += batch_loss
        total_acc += batch_acc
        num_batches += 1
    if num_batches == 0:
        # previously this was a bare ZeroDivisionError from `/= 0`
        raise ValueError('test_ds yielded no batches; cannot compute eval metrics')
    return total_loss / num_batches, total_acc / num_batches


def _mean_or_nan(values):
    """np.mean of `values`, or NaN (without numpy's empty-slice RuntimeWarning) when empty."""
    return np.mean(values) if values else float('nan')


def train_standard_opt(seed, problem_name, optimizer):
    """Train `problem_name` with a standard (non-meta) optimizer and record per-step stats.

    Reads the notebook-level globals RESET_EVERY, EVAL_EVERY and PRINT_EVERY.
    Every RESET_EVERY steps the model is reset via reset_model. Every EVAL_EVERY
    steps the full test set is evaluated and the squared gradient norm is
    recorded. Every PRINT_EVERY steps the average train/eval loss over the
    PRINT_EVERY steps ending at the current one is printed.

    Args:
        seed: integer seed, or None (forwarded to get_problem).
        problem_name: benchmark name understood by get_problem.
        optimizer: optimizer handed to get_problem / create_train_state.

    Returns:
        dict with key 'args' (run configuration, including 'optimizer_args') and
        one entry per step t = 1..len(train_ds). Each step entry holds 'timestamp'
        and 'loss'. Steps where t % EVAL_EVERY == 0 additionally hold
        'eval_loss', 'eval_acc' and 'grad_sq_norm'.
    """
    tstate, train_ds, test_ds, rng, args = get_problem(seed, problem_name, optimizer)

    stats = defaultdict(dict)
    # `hyperparams` on the optimizer state presumably comes from wrapping the optimizer
    # with optax.inject_hyperparams -- TODO confirm. Deep-copied so the 'name' tag
    # below does not mutate the live optimizer state.
    args['optimizer_args'] = deepcopy(tstate.opt_state.hyperparams)
    args['optimizer_args']['name'] = 'standard'
    stats['args'] = args

    t0 = perf_counter()
    pbar = tqdm.tqdm(train_ds.as_numpy_iterator(), total=args['num_iters'])
    for t, batch in enumerate(pbar, start=1):
        if t % RESET_EVERY == 0:
            reset_rng, rng = jax.random.split(rng)
            tstate = reset_model(reset_rng, tstate)

        tstate, (loss, grads) = gradient_descent(tstate, batch)

        # update all the stats
        s = {'timestamp': perf_counter() - t0, 'loss': loss}
        if t % EVAL_EVERY == 0:
            s['eval_loss'], s['eval_acc'] = _evaluate_on_test_set(tstate, test_ds)
            # jax.tree_map is deprecated/removed in recent JAX; use jax.tree_util.tree_map
            s['grad_sq_norm'] = sum(jax.tree_util.tree_leaves(
                jax.tree_util.tree_map(lambda g: (g * g).sum(), grads)))
        stats[t] = s

        # print if we gotta
        if t % PRINT_EVERY == 0:
            # Window is the PRINT_EVERY steps ending at (and including) step t. The old
            # range(t - PRINT_EVERY, t) skipped the current step, so an eval done at
            # step t was reported one window late.
            window = [stats[i] for i in range(t - PRINT_EVERY + 1, t + 1) if i in stats]
            train_losses = [w['loss'] for w in window if 'loss' in w]
            eval_losses = [w['eval_loss'] for w in window if 'eval_loss' in w]
            print(f'iters {t - PRINT_EVERY} - {t}')
            print(f'\tavg train loss: {_mean_or_nan(train_losses)}')
            print(f'\tavg eval loss: {_mean_or_nan(eval_losses)}')
        pbar.set_postfix({'loss': round(s['loss'].item(), 3)})

    return dict(stats)