In [1]:
from collections import defaultdict, deque
from typing import List, Tuple
import tqdm
import matplotlib.pyplot as plt

import numpy as np
import jax
import jax.numpy as jnp
import optax
from flax import struct
import tensorflow as tf

from controllers._base import ControllerState
from controllers.utils import append

from training.trainer import TrainState, reset_model, create_train_state, forward, forward_and_backward, apply_gradients
from training.hgd import HGDState, hypergrad_step
from training.utils import cross_entropy, mse, load_mnist, MLP, CNN

## Things we keep track of
- states: the model parameters at the current iteration
- disturbances: the gradients from previous iterations
- evolve functions: the update maps, built from the learning rate and cost function of the current iteration

In [47]:
def gd(params, lr, cost_fn):
    """Take one step of plain gradient descent.

    Args:
        params: pytree of parameters.
        lr: step size multiplying the gradient.
        cost_fn: callable mapping `params` to a scalar cost.

    Returns:
        A tuple `(new_params, grad, cost)` where `cost` and `grad` are evaluated
        at the *input* `params` and `new_params = params - lr * grad` leaf-wise.
    """
    cost, grad = jax.value_and_grad(cost_fn)(params)
    # `jax.tree_util.tree_map` is the stable spelling; the top-level
    # `jax.tree_map` alias is deprecated and absent from newer JAX releases.
    new_params = jax.tree_util.tree_map(lambda p, g: p - lr * g, params, grad)
    return (new_params, grad, cost,)

# NOTE: intentionally left un-jitted. Jitting this would require marking
# `slice_size` (argument index 2) as static, e.g. via `static_argnums`.
def slice_pytree(pytree, start_idx, slice_size):
    """Slice every leaf of `pytree` along its leading axis.

    Args:
        pytree: pytree whose leaves are arrays.
        start_idx: index along axis 0 at which each slice starts.
        slice_size: number of entries to keep along axis 0.

    Returns:
        A pytree with the same structure whose leaves are
        `jax.lax.dynamic_slice_in_dim(leaf, start_idx, slice_size)`.
    """
    def take_window(leaf):
        return jax.lax.dynamic_slice_in_dim(leaf, start_idx, slice_size)

    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
    return jax.tree_util.tree_map(take_window, pytree)

In [25]:
class MetaOptGPCState(ControllerState):
    """Controller state for the meta-optimizer: the control matrices `M` plus
    the hyperparameters and optimizer used to train them.

    NOTE(review): `tx` and `opt_state` are passed to the constructor in
    `create` but are not declared here; presumably they are fields of
    `ControllerState` — confirm.
    """
    M: jnp.ndarray  # pytree of disturbance-feedback control matrices
    
    H: int = struct.field(pytree_node=False)  # history of the controller, how many past disturbances to use for control
    HH: int = struct.field(pytree_node=False)  # history of the system, how many hallucination steps to take
    lr: float  # learning rate of the optimizer that updates M
    
    @classmethod
    def create(cls, 
               params,
               m_method: str,
               H: int,
               HH: int,
               lr: float = 0.008,):
        """Build a state whose `M` is all zeros and whose optimizer is SGD.

        Args:
            params: pytree of model parameters; `M` mirrors its structure.
            m_method: 'scalar' (one coefficient per history step, shaped to
                broadcast against the parameter) or 'diagonal' (one
                coefficient per parameter entry per history step).
            H: number of past disturbances used by the controller.
            HH: number of hallucination steps.
            lr: learning rate of the SGD optimizer applied to `M`.

        Raises:
            NotImplementedError: if `m_method` is neither 'scalar' nor 'diagonal'.
        """

        def make_m(p):
            # Leading axis of length H indexes the history step.
            if m_method == 'scalar':
                shape = (1,) * p.ndim
            elif m_method == 'diagonal':
                shape = p.shape
            else:
                raise NotImplementedError(m_method)
            return jnp.zeros((H, *shape))
        
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
        M = jax.tree_util.tree_map(make_m, params)
        tx = optax.sgd(learning_rate=lr)  # M optimizer
        opt_state = tx.init(M)
        
        return cls(M=M,
                   H=H, HH=HH, 
                   lr=lr, tx=tx, opt_state=opt_state)

@jax.jit
def compute_control(M, disturbances):
    """Compute the control signal from the control matrices and disturbances.

    For every pair of matching leaves, multiplies `m * d` elementwise (with
    broadcasting) and sums over axis 0.

    Args:
        M: pytree of control coefficients.
        disturbances: pytree with the same structure as `M`.

    Returns:
        A pytree with the same structure, each leaf reduced over its leading axis.
    """
    # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
    return jax.tree_util.tree_map(lambda m, d: (m * d).sum(axis=0), M, disturbances)

def _compute_loss(M, H, HH, initial_params, 
                  disturbances,  # past H + HH disturbances
                  evolve_fns,  # past HH evolve functions, starting at the one that would have been used to evolve `initial_params`
                  cost_fn):
    params = initial_params
    for h in range(HH):
        params = jax.tree_map(lambda p, c: p + c, evolve_fns[h](params), compute_control(M, slice_pytree(disturbances, h, H)))
    loss = cost_fn(params)
    return loss

# Gradient of the rollout loss with respect to `M` only. Because `argnums` is a
# tuple, the returned gradient is a 1-tuple `(dLoss/dM,)`.
_grad_fn = jax.grad(_compute_loss, argnums=(0,))

# NOTE: not jitted (a `@jax.jit` decorator was previously tried and disabled here).
def update(cstate: MetaOptGPCState,
           initial_params,  # params from HH steps ago
           disturbances,  # past H + HH disturbances
           evolve_fns,  # past HH evolve functions, starting at the one that would have been used to evolve `initial_params`
           cost_fn
          ):
    """Take one optimizer step on the control matrices `cstate.M`.

    Differentiates the rollout loss with respect to `M`, feeds the gradient to
    `cstate.tx`, and returns a copy of `cstate` with the new `M` and optimizer
    state.

    Args:
        cstate: current controller state.
        initial_params: parameters at the start of the rollout.
        disturbances: pytree of stacked past disturbances.
        evolve_fns: evolve functions used for the rollout.
        cost_fn: cost evaluated at the end of the rollout.

    Returns:
        The updated controller state.
    """
    # `_grad_fn` is built with a tuple `argnums`, so it returns a 1-tuple.
    # Unwrap it so the gradient pytree has the same structure as `cstate.M`
    # and `cstate.opt_state` (required by stateful optax optimizers).
    (m_grads,) = _grad_fn(cstate.M, cstate.H, cstate.HH, initial_params, disturbances, evolve_fns, cost_fn)
    updates, new_opt_state = cstate.tx.update(m_grads, cstate.opt_state, cstate.M)
    new_M = optax.apply_updates(cstate.M, updates)
    return cstate.replace(M=new_M, opt_state=new_opt_state)

In [26]:
class MetaOpt:
    """Meta-optimizer that adds a learned control term to externally computed
    gradient-descent iterates, and trains that controller online.

    It keeps fixed-length histories of parameters, gradients and evolve
    functions, and a `MetaOptGPCState` holding the control matrices.
    """
    param_history: Tuple  # last HH parameter pytrees returned by `meta_step`
    grad_history: jnp.ndarray  # pytree; each leaf stacks the last H + HH gradients along axis 0
    evolve_fn_history: Tuple  # last HH evolve functions
    cstate: MetaOptGPCState  # controller state (control matrices + their optimizer)
    delta: float  # shrinkage applied to params when the control is added
    t: int  # number of `meta_step` calls so far

    def __init__(self,
                 initial_params,
                 H: int, HH: int,
                 meta_lr: float, delta: float,
                 m_method: str):
        """Initialise empty histories and a zero controller.

        Args:
            initial_params: pytree of parameters; fixes the shapes of the histories.
            H: number of past disturbances the controller uses.
            HH: number of hallucination (rollout) steps.
            meta_lr: learning rate for the controller's optimizer.
            delta: shrinkage factor; controlled params are `(1 - delta) * p + control`.
            m_method: 'scalar' or 'diagonal' (see `MetaOptGPCState.create`).
        """
        assert m_method in ['scalar', 'diagonal']

        self.param_history = (None,) * HH
        # `jax.tree_util.tree_map` replaces the deprecated `jax.tree_map` alias.
        self.grad_history = jax.tree_util.tree_map(lambda p: jnp.zeros((H + HH, *p.shape)), initial_params)
        self.evolve_fn_history = (None,) * HH
        self.delta = delta
        self.t = 0

        self.cstate = MetaOptGPCState.create(initial_params, m_method, H, HH, lr=meta_lr)

    def meta_step(self, 
                  params,  # params after a step of gd
                  grads,  # grads from the step of gd that resulted in `params`
                  lr, cost_fn,  # lr and cost fn from step of gd that resulted in `params`
                 ):
        """Apply the controller to `params`, update the controller, and record history.

        For the first `H + HH` calls the histories are only being filled and
        `params` is returned unchanged. Afterwards the control computed from
        the last `H` stored gradients is added to `params`, and the controller
        takes one optimizer step.

        Returns:
            The (possibly controlled) parameters.
        """
        if self.t >= self.cstate.H + self.cstate.HH:
            # The current `grads` is appended only after this branch, so the
            # window covers the H most recent *previously stored* gradients.
            recent_grads = slice_pytree(self.grad_history, self.cstate.HH, self.cstate.H)  # use past H disturbances
            control = compute_control(self.cstate.M, recent_grads)
            params = jax.tree_util.tree_map(lambda p, c: (1 - self.delta) * p + c, params, control)
            self.cstate = update(self.cstate, self.param_history[0], self.grad_history, self.evolve_fn_history, jax.tree_util.Partial(cost_fn))

        self.param_history = append(self.param_history, params)
        self.grad_history = jax.tree_util.tree_map(lambda h, g: append(h, g), self.grad_history, grads)

        # Store the gd map for this step's lr / cost_fn for later rollouts.
        self.evolve_fn_history = append(self.evolve_fn_history, jax.tree_util.Partial(lambda p: gd(p, lr, cost_fn)[0]))
        self.t += 1
        return params

# Run Things

In [52]:
def run(seed, dim, use_meta_opt):
    """Train an MLP with gradient descent on a synthetic regression task.

    Each iteration draws a fresh batch `x` of standard-normal inputs and
    regresses the model output onto `f(x) = sign(x) * sqrt(|x|)` with a
    mean-squared-error cost.

    Args:
        seed: seeds both NumPy's global RNG (data) and the JAX PRNG key (model init).
        dim: input dimension, also used for every layer width of the MLP.
        use_meta_opt: if True, post-process every gd step with `MetaOpt`.

    Returns:
        List with the cost at each iteration, evaluated before that iteration's update.
    """
    np.random.seed(seed)

    def f(x): return jnp.abs(x) ** 0.5 * jnp.sign(x)
    # def f(x): return x

    num_iters, batch_size = 500, 256
    
    inputs = jnp.array(np.random.randn(num_iters, batch_size, dim))
    model = MLP([dim, dim, dim, dim])
    tstate = create_train_state(jax.random.PRNGKey(seed), model, [dim,], optax.sgd(learning_rate=0., momentum=0.), None)
    params = tstate.params
    lr = 0.01

    def batch_cost(x, p):
        """MSE between the model output on batch `x` and the target `f(x)`."""
        return ((tstate.apply_fn({'params': p}, x) - f(x)) ** 2).mean()

    if use_meta_opt: metaopt = MetaOpt(params, H=3, HH=1, meta_lr=20., delta=0., m_method='diagonal')

    losses = []
    for i in tqdm.trange(num_iters):
        # Bind this iteration's batch into the cost function. A lambda closing
        # over a loop variable is late-bound, so cost/evolve functions kept in
        # MetaOpt's history would silently be evaluated on the *current* batch
        # rather than the batch they were created with.
        cost_fn = jax.tree_util.Partial(batch_cost, inputs[i])
        params, grad, loss = gd(params, lr, cost_fn)
        losses.append(loss)
        if use_meta_opt: params = metaopt.meta_step(params, grad, lr, cost_fn)
    return losses

In [53]:
seed, dim = 2, 500
losses = run(seed, dim, False)  # plain gradient descent baseline
mo_losses = run(seed, dim, True)  # gradient descent + meta-optimizer; the plotting cell below reads this

 29%|███████████████████████████▍                                                                  | 146/500 [00:02<00:06, 52.40it/s]

KeyboardInterrupt



In [None]:
# Compare the per-iteration loss of plain gd against gd with the meta-optimizer.
fig, ax = plt.subplots()
ax.plot(range(len(losses)), losses, label='gd')
ax.plot(range(len(mo_losses)), mo_losses, label='gd + meta-opt')
ax.set(xlabel='iteration', ylabel='loss', title='Training loss')
ax.legend();