In [102]:
from collections import defaultdict, deque
import tqdm
import matplotlib.pyplot as plt

import numpy as np
import jax
import jax.numpy as jnp
import optax
from flax import struct
import tensorflow as tf

from controllers._base import ControllerState

from training.trainer import TrainState, reset_model, create_train_state, forward, forward_and_backward, apply_gradients
from training.hgd import HGDState, hypergrad_step
from training.utils import cross_entropy, mse, load_mnist, MLP, CNN

In [109]:
# define model, optimizer, and train state
rng = jax.random.PRNGKey(1)
model = MLP([28 * 28, 100, 100, 10])
optimizer = optax.inject_hyperparams(optax.sgd)(learning_rate=0.01, momentum=0.)
init_rng, rng = jax.random.split(rng)
tstate = create_train_state(init_rng, model, [28, 28], optimizer, None)
del init_rng

## Define 

In [119]:
def a(z):
    return z ** 2

for i in range(5):
    x = x.at[i].set(jax.tree_util.Partial(a))

TypeError: JAX only supports number and bool dtypes, got dtype <class 'jax._src.tree_util.Partial'> in array

## Define GPC Controller for Meta-Opt

- "state" means parameters at current iteration
- "cost_fn" means function to optimize at current iteration, can be MSE on current batch or something
- "disturbance" means gradient at previous iteration

In [104]:
class MetaOptGPCState(ControllerState):
    M: jnp.ndarray  # disturbance-feedback control matrices
    state_history: jnp.ndarray  # state history, for hallucinations
    cost_fn_history: deque  # history of cost functions, for hallucinations
    disturbance_history: jnp.ndarray  # disturbance history
    
    state_dim: int = struct.field(pytree_node=False)
    control_dim: int = struct.field(pytree_node=False)
    H: int = struct.field(pytree_node=False)  # history of the controller, how many past disturbances to use for control
    HH: int = struct.field(pytree_node=False)  # history of the system, how many hallucination steps to take
    t: int  # time counter (for decaying learning rate)
    lr: float
    decay_lr: bool = struct.field(pytree_node=False)
    
    @classmethod
    def create(cls, 
               state_dim: int,
               control_dim: int,
               m_method: str,
               H: int,
               HH: int,
               lr: float = 0.008,
               decay_lr: bool = True,):
        
        if m_method == 'scalar': M = jnp.zeros((H, 1))
        elif m_method == 'diagonal': M = jnp.zeros((H, state_dim))
        elif m_method == 'full': M = jnp.zeros((H, control_dim, state_dim))
        else: raise NotImplementedError(m_method)
        
        state_history = jnp.zeros((HH, state_dim))  # past HH states ordered increasing in time
        disturbance_history = jnp.zeros((H + HH, state_dim))  # Past H + HH noises ordered increasing in time
        cost_fn_history = deque([], maxlen=HH)  # past HH cost functions ordered increasing in time
        tx = optax.inject_hyperparams(optax.sgd)(learning_rate=lr)  # M optimizer
        opt_state = tx.init(M)
        
        return cls(M=M, state_history=state_history, disturbance_history=disturbance_history, cost_fn_history=cost_fn_history,
                   state_dim=state_dim, control_dim=control_dim, H=H, HH=HH, t=0,
                   lr=lr, decay_lr=decay_lr, tx=tx, opt_state=opt_state)

@jax.jit
def _compute_control(M, disturbances):
    if len(M.shape) == 3: control = jnp.tensordot(M, disturbances, axes=([0, 2], [0, 1]))
    else: control = (M * disturbances).sum(axis=0)
    return control

@jax.jit
def _compute_loss(cstate: MetaOptGPCState, 
                  intial_state,
                  step_fn,  # takes in [tstate, cost_fn] and outputs [new_tstate,]. most likely via gradient descent
                  cost_fn,  # cost function to evaluate hallucinated final state with
                 ):
    """FINSIH"""
    def _evolve(tstate, h):
        return cstate.A @ state + _compute_control(cstate.M, jax.lax.dynamic_slice_in_dim(cstate.disturbance_history, h, cstate.H)), None
    final_tstate, _ = jax.lax.scan(_evolve, intial_tstate, jnp.arange(cstate.HH - 1))
    return quad_loss(final_state, _action(final_state, cstate.HH - 1))

def update(cstate: MetaOptGPCState,
           params,
           disturbance,
           cost_fn,
          ):
    
    updates, new_opt_state = cstate.tx.update(grads, cstate.opt_state, cstate.M)
    M = optax.apply_updates(cstate.M, updates[0])

    

    return cstate.replace(M=M, opt_state=new_opt_state, disturbance_history=disturbance_history, t=cstate.t+1)   
    

def get_control(cstate: MetaOptGPCState):
    return _compute_control(cstate.M, jax.lax.dynamic_slice_in_dim(cstate.disturbance_history, -cstate.H, cstate.H))

def reset(cstate: MetaOptGPCState):
    state_history = deque([], maxlen=cstate.HH)  # past HH states ordered increasing in time
    disturbance_history = jnp.zeros((cstate.H + cstate.HH, cstate.state_dim))  # Past H + HH noises ordered increasing in time
    cost_fn_history = deque([], maxlen=cstate.HH)  # past HH cost functions ordered increasing in time
    return cstate.replace(state_history=state_history, disturbance_history=disturbance_history, cost_fn_history=cost_fn_history)

IndentationError: expected an indented block after function definition on line 58 (1056905210.py, line 65)

In [None]:
## what fns will we need

def gd_step(tstate, batch):  # take gd update step w current opt on batch
    return tstate

def gpc_cost(cstate, ...):  # used to take derivative w.r.t. M's in order to update gpc controller. should HALLUCINATE!!!
    return cost

def update(cstate, params, disturbance, cost_fn):
    # append params
    # compute update using cstate's disturbance and cost fn histories
    # append disturbance, cost fn, and 


In [None]:
# how meta opt train loop should go
for step in range(tsteps):
    tstate = gd_step(tstate, batch)  # maybe make this to take input arbitrary functions, not just batches
    tstate = (1-delta) * tstate + get_control(...)  # use disturbances up to and including `step-1`, but NOT the gradient we just calculated above
    gpc = update_gpc(...)  # hallucainate. use cost fns during history for GD steps, but compute final state loss with cost from 2 lines ago
    # append grad to disturbances!
    