This notebook is for experiments probing things other than performance, such as checking conditions and assumptions.

In [1]:
# # for use in google colab!!
# !git clone https://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@github.com/edogariu/meta-opt
# !pip install ./meta-opt
# !pip install tensorflow-text ml_collections clu sentencepiece  # for WMT
# from google.colab import drive
# drive.mount('/content/drive')
# DIR_PREFIX = "drive/My Drive/meta-opt"

# # for extra one-time setup in colab
# !git clone https://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@github.com/edogariu/meta-opt
# !mkdir meta-opt/data
# !mkdir meta-opt/datasets
# !cp -r "meta-opt" "drive/My Drive/"
# !pip install kora -q  # library from https://stackoverflow.com/questions/62596466/how-can-i-run-notebooks-of-a-github-project-in-google-colab to help get ID
# from kora.xattr import get_id
# fid = get_id(f"{dir_prefix}meta_opt.ipynb")
# print("https://colab.research.google.com/drive/"+fid)

In [2]:
from time import perf_counter
from itertools import accumulate
from collections import defaultdict, deque
from copy import deepcopy
import tqdm
import matplotlib.pyplot as plt
import pickle as pkl

import numpy as np
import tensorflow as tf
import jax
import jax.numpy as jnp
import optax

from meta_opt.controllers.utils import append
from meta_opt.nn.trainer import create_train_state, train_step, reset_model, eval
from meta_opt.problems import mnist, cifar10, wmt

# Background code 
Code to reproducibly train the model and collect stats.

In [4]:
def set_seed(seed):
    if seed is None: 
        seed = np.random.randint()
        print('seed set to {}'.format(seed))
    np.random.seed(seed)
    tf.random.set_seed(seed)
    rng = jax.random.PRNGKey(seed)
    return rng, seed


def run_trial(seed, problem_name, optimizer, experiment_step_fn):
    """
    `experiment_step_fn` is a function mapping `tstate, batch, carry -> tstate, (loss, grads, stats, carry)` for each timestep

    to run regular experiment with no funny business or extra logging, one would set `experiment_step_fn` to be `gradient_descent` but also returning an empty dict
    """
    
    rng, seed = set_seed(seed)
    init_rng, rng = jax.random.split(rng)

    cfg = {'num_iters': NUM_ITERS, 'batch_size': BATCH_SIZE, 'eval_every': EVAL_EVERY, 'reset_every': RESET_EVERY, 'num_eval_iters': -1}

    # get dataset and model
    if 'MNIST' in problem_name:
        train_ds, test_ds, example_input, loss_fn, metric_fns = mnist.load_mnist(cfg, dataset_dir=f'{DIR_PREFIX}/datasets')
        # model = mnist.MLP([28 * 28, 100, 100, 10])
        model = mnist.MLP([28 * 28, 1, 10])
    elif 'CIFAR' in problem_name:
        train_ds, test_ds, example_input, loss_fn, metric_fns = cifar10.load_cifar10(cfg, dataset_dir=f'{DIR_PREFIX}/datasets')
        model = cifar10.VGG(stages=((32, 32), (64, 64), (128, 128)), layer_dims=[128, 10], drop_last_activation=True, dropout=0.1)
    elif 'WMT' in problem_name:
        train_ds, test_ds, example_input, loss_fn, metric_fns = wmt.load_wmt(cfg, dataset_dir=f'{DIR_PREFIX}/datasets')
        model = wmt.make_transformer(num_heads=8, num_layers=6, emb_dim=256, qkv_dim=256, mlp_dim=1024)
    else:
        raise NotImplementedError(problem_name)
    
    tstate = create_train_state(init_rng, model, example_input, optimizer, loss_fn, metric_fns=metric_fns)
    del init_rng

    args = {'seed': seed,
            'model': str(model),
            'params': sum(x.size for x in jax.tree_util.tree_leaves(tstate.params)),
            'dataset': problem_name,
            'num_iters': NUM_ITERS,
            'eval_every': EVAL_EVERY,
            'batch_size': BATCH_SIZE,
            'reset_every': RESET_EVERY}
    print(args['params'], 'params in the model')

    stats = defaultdict(dict)
    args['optimizer_args'] = deepcopy(tstate.opt_state.hyperparams)
    args['optimizer_args']['name'] = 'standard'
    stats['args'] = args

    carry = None
    t0 = perf_counter()
    for t, batch in enumerate(pbar := tqdm.tqdm(train_ds.as_numpy_iterator(), total=args['num_iters'])):
        t += 1
    
        if t % RESET_EVERY == 0:
            reset_rng, rng = jax.random.split(rng)
            tstate = reset_model(reset_rng, tstate)
            del reset_rng

        # --------------------------------------------------------------------------------
        tstate, (loss, grads, s, carry) = experiment_step_fn(tstate, batch, carry)  # this is what changes between different experiments
        # --------------------------------------------------------------------------------
        print(s['sequential_stability'])
        # update all the stats
        s['timestamp'] = perf_counter() - t0
        s['loss'] = loss
        if t % EVAL_EVERY == 0: 
            s['eval_loss'], s['eval_acc'] = 0., 0.
            n = 0
            for batch in test_ds.as_numpy_iterator():
                loss, acc = eval(tstate, batch)
                s['eval_loss'] += loss
                s['eval_acc'] += acc
                n += 1
            s['eval_loss'] /= n
            s['eval_acc'] /= n
            s['grad_sq_norm'] = sum(jax.tree_util.tree_flatten(jax.tree_map(lambda g: (g * g).sum(), grads))[0])
        stats[t] = s
    
    return dict(stats)

def run_experiment(seeds, problem_name, optimizer, experiment_step_fn, fname=f'{DIR_PREFIX}/data/{NAME}_sequentialstability_raw.pkl'):
    results = []
    
    for s in seeds:
        results.append(run_trial(s, NAME, optax.inject_hyperparams(optax.sgd)(0.1), sequential_stability))
    
        if len(results) > 0:
            pkl.dump(results, open(fname, 'wb'))
            print(f'Saved checkpoint for seed #{s}')
            
    return results


# Sequential Stability

A time-varying linear dynamical system with dynamics $A_1, \ldots, A_T$ is $(\kappa, \gamma)$-sequentially stable if  for all intervals $I = [r, s]\subseteq [T]$,
$$
\left \|\prod_{t=s}^{r} A_t\right \| \le \kappa^2 (1-\gamma)^{|I|}$$
We check if the LTV in meta-opt is indeed sequentially stable.

In [5]:
MAX_LEN = 16  # for computational reasons, we will only compute with lengths up to this value
DELTA = 0.001  # (1-delta) decay factor for state

In [None]:
@jax.jit
def forward_and_backward_with_hessian(tstate, batch):
    if tstate.rng is not None:
        next_key, dropout_key = jax.random.split(tstate.rng)
        tstate = tstate.replace(rng=next_key)
    else: dropout_key = None
    def loss_fn(params):
        yhat = tstate.apply_fn({'params': params}, batch['x'], train=True, rngs={'dropout': dropout_key})
        loss = tstate.loss_fn(yhat, batch['y'])
        return loss
    loss, grads = jax.value_and_grad(loss_fn)(tstate.params)

    p, td = jax.tree_util.tree_flatten(tstate.params)
    def loss_fn_from_flat(params_flat):  # for hessian computation
        q = []
        n = 0
        for v in p:
            d = np.prod(v.shape)
            q.append(params_flat[n: n + d].reshape(v.shape))
            n += d
        params = jax.tree_util.tree_unflatten(td, q)
        return loss_fn(params)
        
    hessians = jax.hessian(loss_fn_from_flat)(jnp.concatenate([_p.reshape(-1) for _p in p], axis=0))
    return tstate, (loss, grads, hessians)

@jax.jit
def sequential_stability(tstate, batch, carry):
    # the vanilla stuff, but also computing hessian
    stats = {}
    tstate, (loss, grads, hessians) = forward_and_backward_with_hessian(tstate, batch)
    tstate = tstate.apply_gradients(grads=grads)

    # use hessian to compute transition matrix and append to the buffer. note that this is using batch averages
    def f(H, eta, delta, carry):
        I = jnp.eye(H.shape[0])
        A = jnp.block([[(1 - delta) * I, 0 * I, -eta * I], [I, 0 * I, 0 * I], [H, -H, 0 * I]])  # transition matrix for this step
        carry = A @ append(carry, jnp.eye(A.shape[0]))  # append an entry of 1 to the right, then left multiply each entry by A. this dynamically handles the cumprod
        spectral_norms = jnp.linalg.norm(carry, axis=(1, 2), ord=2)
        return carry, spectral_norms
    
    H = hessians # + 2 * beta * jnp.eye(hessians.shape[0])  # TODO CHECK THIS!!!
    if carry is None: carry = jnp.zeros((MAX_LEN, H.shape[0] * 3, H.shape[0] * 3))
    carry, spectral_norms = f(H, tstate.opt_state.hyperparams['learning_rate'], DELTA, carry)
    print(carry.shape, spectral_norms.shape, spectral_norms)
    stats['sequential_stability'] = spectral_norms

    return tstate, (loss, grads, stats, carry)

results = run_experiment(SEEDS, NAME, optax.inject_hyperparams(optax.sgd)(0.1), sequential_stability)

805 params in the model


  0%|                                                                                                                                                                        | 0/100 [00:00<?, ?it/s]

(16, 2415, 2415) (16,) Traced<ShapedArray(float32[16])>with<DynamicJaxprTrace(level=1/0)>


#### 