In [1]:
from collections import defaultdict
from copy import deepcopy
import tqdm
import matplotlib.pyplot as plt
import pickle as pkl

import numpy as np
import tensorflow as tf
import jax
import optax

from training.trainer import create_train_state, gradient_descent, reset_model, forward
from training.utils import cross_entropy, mse, MLP, CNN, load_mnist

from meta_opt import MetaOpt

### Todo
- add wall clock time to `stats`
- add MP

In [2]:
# import jax.numpy as jnp
# m = jnp.zeros((4,))
# tstate, train_ds, test_ds, rng, args = get_problem(0, optax.sgd(0.1))
# p = tstate.params

# def d_fn(_p):
#     n = _p.ndim
#     s = [1,] * (n + 1)
#     s[0] = 4
#     return jnp.tile(_p, s)

# d = jax.tree_map(d_fn, p)

# control = jax.tree_map(lambda s: (m[:, *[None for _ in range(s.ndim - 1)]] * s).sum(axis=0), d)
# print(control['dense 0']['kernel'].shape)

In [3]:
# --- Experiment configuration (shared by every training function below) ---
SEED = 1               # seed passed to the training functions in the Run section

# data pipeline: both values are forwarded to load_mnist
NUM_ITERS = 8_000
BATCH_SIZE = 2_048

# cadences, all measured in training iterations
EVAL_EVERY = 100       # evaluate on the test set
PRINT_EVERY = 1_000    # print averaged train/eval losses
RESET_EVERY = 4_000    # re-initialize the model

def set_seed(seed):
    """Seed NumPy, TensorFlow and JAX from a single integer.

    Args:
        seed: integer seed, or None to draw a random one (it is printed so the
            run can be reproduced later).

    Returns:
        (rng, seed): the `jax.random.PRNGKey` built from the seed, and the
        integer seed that was actually used.
    """
    if seed is None:
        # np.random.randint needs explicit bounds (calling it with no arguments raises TypeError).
        # Staying below 2**31 keeps the value inside a signed 32-bit int and inside
        # the [0, 2**32 - 1] range accepted by np.random.seed.
        seed = int(np.random.randint(0, 2**31 - 1))
        print('seed set to {}'.format(seed))
    np.random.seed(seed)
    tf.random.set_seed(seed)
    rng = jax.random.PRNGKey(seed)
    return rng, seed

# MNIST
def get_problem(seed, optimizer):
    """Assemble the MNIST problem: datasets, an initialized MLP train state, and run metadata.

    Args:
        seed: forwarded to `set_seed` (None makes `set_seed` draw a random seed).
        optimizer: optimizer handed to `create_train_state`.

    Returns:
        (tstate, train_ds, test_ds, rng, args), where `rng` is the key left over
        after splitting off the initialization key and `args` is a dict
        recording the run configuration.
    """
    rng, seed = set_seed(seed)

    # the dataset loader also supplies the loss function and input dimensions
    train_ds, test_ds, loss_fn, input_dims = load_mnist(NUM_ITERS, BATCH_SIZE)

    # one key is spent on parameter initialization, the other goes back to the caller
    init_rng, rng = jax.random.split(rng)
    model = MLP([28 * 28, 100, 100, 10])
    tstate = create_train_state(init_rng, model, input_dims, optimizer, loss_fn)

    # bookkeeping: what was run, then the global schedule it was run with
    args = {'seed': seed, 'model': str(model), 'dataset': 'MNIST'}
    args.update({
        'num_iters': NUM_ITERS,
        'eval_every': EVAL_EVERY,
        'batch_size': BATCH_SIZE,
        'reset_every': RESET_EVERY,
        'print_every': PRINT_EVERY,
    })

    return tstate, train_ds, test_ds, rng, args

# Standard Optimizers

In [4]:
def train_standard_opt(seed, optimizer):
    """Train the MNIST problem with a fixed optax optimizer and record per-iteration stats.

    Args:
        seed: forwarded to `get_problem`.
        optimizer: optimizer forwarded to `get_problem`. The resulting
            `tstate.opt_state` must expose `.hyperparams`, which is copied into
            the run metadata (the Run cell wraps optimizers in
            `optax.inject_hyperparams` for this).

    Returns:
        A dict with the run configuration under 'args' and, for every 1-based
        iteration `t`, a dict holding 'loss' and (every EVAL_EVERY iterations)
        'eval_loss'.
    """
    tstate, train_ds, test_ds, rng, args = get_problem(seed, optimizer)

    stats = defaultdict(dict)
    args['optimizer_args'] = deepcopy(tstate.opt_state.hyperparams)
    args['optimizer_args']['name'] = 'standard'
    stats['args'] = args

    pbar = tqdm.tqdm(train_ds.as_numpy_iterator(), total=len(train_ds))
    for t, batch in enumerate(pbar, start=1):  # t is 1-based
        if t % RESET_EVERY == 0:
            reset_rng, rng = jax.random.split(rng)
            tstate = reset_model(reset_rng, tstate)
            del reset_rng

        tstate, (loss, grads) = gradient_descent(tstate, batch)

        # update all the stats
        s = {'loss': loss}
        if t % EVAL_EVERY == 0:
            # mean of `forward` over the test batches
            s['eval_loss'] = sum(forward(tstate, test_batch) for test_batch in test_ds.as_numpy_iterator()) / len(test_ds)
        stats[t] = s

        # print if we gotta
        if t % PRINT_EVERY == 0:
            # window is (t - PRINT_EVERY, t]: with 1-based t this covers exactly the last
            # PRINT_EVERY iterations, including the one (and the eval) that just finished
            window = [stats[i] for i in range(t - PRINT_EVERY + 1, t + 1) if i in stats]
            avg_train_loss = np.mean([w['loss'] for w in window if 'loss' in w])
            avg_eval_loss = np.mean([w['eval_loss'] for w in window if 'eval_loss' in w])
            print(f'iters {t - PRINT_EVERY} - {t}')
            print(f'\tavg train loss: {avg_train_loss}')
            print(f'\tavg eval loss: {avg_eval_loss}')
        pbar.set_postfix({'loss': round(s['loss'].item(), 3)})

    return dict(stats)

# Meta-Opt

In [5]:
def train_meta_opt(seed, m_method: str, meta_lr: float, H: int, HH: int, initial_lr: float):
    """Train the MNIST problem with SGD plus a `MetaOpt` controller and record per-iteration stats.

    Args:
        seed: forwarded to `get_problem`.
        m_method: forwarded to `MetaOpt`; when it equals 'scalar' the flattened
            `meta_opt.cstate.M` is additionally logged every iteration under 'M'.
        meta_lr: forwarded to `MetaOpt`.
        H: forwarded to `MetaOpt`.
        HH: forwarded to `MetaOpt`.
        initial_lr: learning rate of the base `optax.sgd` optimizer.

    Returns:
        A dict with the run configuration under 'args' and, for every 1-based
        iteration `t`, a dict holding 'loss', optionally 'eval_loss' (every
        EVAL_EVERY iterations) and optionally 'M'.
    """
    optimizer = optax.sgd(learning_rate=initial_lr)
    tstate, train_ds, test_ds, rng, args = get_problem(seed, optimizer)

    stats = defaultdict(dict)
    args['optimizer_args'] = {'name': 'meta',
                              'initial_lr': initial_lr,
                              'm_method': m_method,
                              'meta_lr': meta_lr,
                              'H': H,
                              'HH': HH
                              }
    stats['args'] = args

    meta_opt = MetaOpt(tstate, H=H, HH=HH, meta_lr=meta_lr, delta=1e-5, m_method=m_method)

    # NOTE: `stats` must not be re-created here, otherwise the 'args' entry stored above is lost
    pbar = tqdm.tqdm(train_ds.as_numpy_iterator(), total=len(train_ds))
    for t, batch in enumerate(pbar, start=1):  # t is 1-based
        if t % RESET_EVERY == 0:
            reset_rng, rng = jax.random.split(rng)
            tstate = reset_model(reset_rng, tstate)
            meta_opt = meta_opt.episode_reset()
            del reset_rng

        tstate, (loss, grads) = gradient_descent(tstate, batch)
        tstate = meta_opt.meta_step(tstate, grads, batch)

        # update all the stats
        s = {'loss': loss}
        if t % EVAL_EVERY == 0:
            # mean of `forward` over the test batches
            s['eval_loss'] = sum(forward(tstate, test_batch) for test_batch in test_ds.as_numpy_iterator()) / len(test_ds)
        if m_method == 'scalar':
            s['M'] = meta_opt.cstate.M.reshape(-1)
        stats[t] = s

        # print if we gotta
        if t % PRINT_EVERY == 0:
            # window is (t - PRINT_EVERY, t]: with 1-based t this covers exactly the last
            # PRINT_EVERY iterations, including the one (and the eval) that just finished
            window = [stats[i] for i in range(t - PRINT_EVERY + 1, t + 1) if i in stats]
            avg_train_loss = np.mean([w['loss'] for w in window if 'loss' in w])
            avg_eval_loss = np.mean([w['eval_loss'] for w in window if 'eval_loss' in w])
            print(f'iters {t - PRINT_EVERY} - {t}')
            print(f'\tavg train loss: {avg_train_loss}')
            print(f'\tavg eval loss: {avg_eval_loss}')
        pbar.set_postfix({'loss': round(s['loss'].item(), 3)})

    return dict(stats)

# Hypergradient Descent

In [6]:
def train_hgd(seed, initial_lr: float, hypergrad_lr: float):
    """Train the MNIST problem with SGD whose learning rate is adapted by hypergradient descent.

    After every step the learning rate stored in `tstate.opt_state.hyperparams`
    is moved by `hypergrad_lr` times the dot product of the current and previous
    gradients.

    Args:
        seed: forwarded to `get_problem`.
        initial_lr: starting learning rate of the `optax.sgd` optimizer.
        hypergrad_lr: step size applied to the learning-rate update.

    Returns:
        A dict with the run configuration under 'args' and, for every 1-based
        iteration `t`, a dict holding 'loss', 'lr' and (every EVAL_EVERY
        iterations) 'eval_loss'.
    """
    # inject_hyperparams exposes the learning rate in opt_state so it can be edited between steps
    optimizer = optax.inject_hyperparams(optax.sgd)(learning_rate=initial_lr)
    tstate, train_ds, test_ds, rng, args = get_problem(seed, optimizer)

    stats = defaultdict(dict)
    args['optimizer_args'] = {'name': 'hgd',
                              'initial_lr': initial_lr,
                              'hypergrad_lr': hypergrad_lr,
                              }
    stats['args'] = args

    prev_grads = None
    pbar = tqdm.tqdm(train_ds.as_numpy_iterator(), total=len(train_ds))
    for t, batch in enumerate(pbar, start=1):  # t is 1-based
        if t % RESET_EVERY == 0:
            reset_rng, rng = jax.random.split(rng)
            tstate = reset_model(reset_rng, tstate)
            del reset_rng
            # the last gradient was taken at the pre-reset parameters; pairing it with a
            # gradient of the re-initialized model would yield a meaningless hypergradient
            prev_grads = None

        tstate, (loss, grads) = gradient_descent(tstate, batch)
        if prev_grads is not None:
            # hypergradient = -<g_t, g_{t-1}>, summed over all parameter leaves
            hypergrad = -sum([(g1 * g2).sum() for g1, g2 in zip(jax.tree_util.tree_leaves(grads), jax.tree_util.tree_leaves(prev_grads))])
            tstate.opt_state.hyperparams['learning_rate'] -= hypergrad_lr * hypergrad
        prev_grads = grads

        # update all the stats
        s = {'loss': loss}
        s['lr'] = tstate.opt_state.hyperparams['learning_rate'].item()
        if t % EVAL_EVERY == 0:
            # mean of `forward` over the test batches
            s['eval_loss'] = sum(forward(tstate, test_batch) for test_batch in test_ds.as_numpy_iterator()) / len(test_ds)
        stats[t] = s

        # print if we gotta
        if t % PRINT_EVERY == 0:
            # window is (t - PRINT_EVERY, t]: with 1-based t this covers exactly the last
            # PRINT_EVERY iterations, including the one (and the eval) that just finished
            window = [stats[i] for i in range(t - PRINT_EVERY + 1, t + 1) if i in stats]
            avg_train_loss = np.mean([w['loss'] for w in window if 'loss' in w])
            avg_eval_loss = np.mean([w['eval_loss'] for w in window if 'eval_loss' in w])
            print(f'iters {t - PRINT_EVERY} - {t}')
            print(f'\tavg train loss: {avg_train_loss}')
            print(f'\tavg eval loss: {avg_eval_loss}')
        pbar.set_postfix({'loss': round(s['loss'].item(), 3)})

    return dict(stats)

# Run

In [7]:
# List every global whose name contains 'stats' (i.e. results already in this kernel).
print('currently available stats:')
# snapshot the names first: the loop variable itself becomes a new global while iterating
for name in [key for key in list(globals().keys()) if 'stats' in key]:
    print('\t', name)

currently available stats:


In [None]:
# Choose the experiment(s) to run by (un)commenting; each call stores its result in a `*_stats` global.
# sgd_stats = train_standard_opt(SEED, optax.inject_hyperparams(optax.sgd)(learning_rate=0.2)) 
# adam_stats = train_standard_opt(SEED, optax.inject_hyperparams(optax.adam)(learning_rate=0.001)) 
scalar_mo_stats = train_meta_opt(
    SEED,
    m_method='scalar',
    meta_lr=0.015,
    H=8,
    HH=3,
    initial_lr=0.2,
)
# diagonal_mo_stats = train_meta_opt(SEED, 'diagonal', meta_lr=1., H=4, HH=2, initial_lr=0.2)
# hgd_stats = train_hgd(SEED, initial_lr=0.2, hypergrad_lr=1e-4)

  3%|█▊                                                                     | 201/8000 [00:11<09:02, 14.36it/s, loss=0.173]

# Load/Save

In [None]:
# name = 'scalar_mo_stats'

# assert name in globals()
# with open(f'./data/{name}.pkl', 'wb') as f:
#     pkl.dump(globals()[name], f)
#     print(f'dumped {name} to {f}')

In [None]:
# sgd_stats = pkl.load(open('./data/sgd_stats.pkl', 'rb'))
# adam_stats = pkl.load(open('./data/adam_stats.pkl', 'rb'))

# Plot

In [None]:
# Statistics to draw: one subplot per entry, looked up by key in each iteration's stats dict.
things_to_plot = [
                    'loss', 
                    'eval_loss', 
                    'lr',
                    ]
# Experiments to overlay: each name `e` is resolved to the global `<e>_stats`.
experiments_to_plot = [
                       'sgd', 
                       'adam', 
                       # 'hgd', 
                       # 'diagonal_mo', 
                       'scalar_mo',
                      ]

# squeeze=False always yields a 2-D array of Axes, so a single entry in
# things_to_plot still gives something iterable for the zip below
fig, ax = plt.subplots(len(things_to_plot), 1, figsize=(10, 24), squeeze=False)
ax = ax[:, 0]
for e in experiments_to_plot:
    # only a missing `<e>_stats` global means "not run / not loaded yet";
    # any other error should surface instead of being silently swallowed
    try:
        s = globals()[e + '_stats']
    except KeyError:
        print(f'{e} has no stats loaded, skipping')
        continue
    # every key except the 'args' metadata entry is an iteration index
    ts = [int(k) for k in s.keys() if k != 'args']
    for _ax, p in zip(ax, things_to_plot):
        # keep only the iterations at which statistic `p` was recorded
        _ts = [t for t in ts if p in s[t]]
        _vals = [s[t][p] for t in _ts]
        if len(_ts) == 0:
            print(f'{e} has no statistic "{p}"')
            continue
        _ax.set_title(p)
        _ax.set_xlabel('iteration')
        _ax.plot(_ts, _vals, label=e)
        _ax.legend()
plt.show()