# Neural Network Training

In [10]:
import logging
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)  # set level to INFO for wordy
import matplotlib.pyplot as plt
from IPython.display import HTML

import numpy as np
import jax.numpy as jnp
import torch.optim as optim

from extravaganza.dynamical_systems import LinearRegression, MNIST

from extravaganza.observables import Observable, TimeDelayedObservation, FullObservation, Trajectory
from extravaganza.sysid import Lifter, LiftedController, OfflineSysid
from extravaganza.controllers import LQR, HINF, BPC, GPC, RBPC, EvanBPC, ConstantController
from extravaganza.rescalers import ADAM, D_ADAM, DoWG, FIXED_RESCALE
from extravaganza.stats import Stats
from extravaganza.utils import ylim, render, append, opnorm, dare_gain, least_squares
from extravaganza.experiments import Experiment

# Seeds for the three independent sources of randomness (system, controller, sysid).
# Leaving a seed as `None` makes that component draw a random seed instead.
SYSTEM_SEED = CONTROLLER_SEED = SYSID_SEED = None

## System
Here, we tune the hyperparameters of a gradient descent algorithm that trains a model. We can train either a linear regression model or an MLP/CNN model on MNIST. Any optimizer can be used, such as `SGD` or `Adam`, and any of its hyperparameters can be tuned, such as `lr` or `momentum`. 

The optimizer is specified by the `make_optimizer` argument. How the control is applied to that optimizer (i.e. which hyperparameters are tuned, and how) is specified by the `apply_control` argument.

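At the moment, we apply a 2-dimensional control $u = (u_0, u_1)$ that sets 2 parameters of the learning rate schedule:
$$\eta_t := \frac{u_0}{1 + u_1 \cdot \sqrt{t}},$$
where $u_0$ is the initial learning rate, $u_1$ is a decay rate, and $t$ is the step count within the current episode (`system.episode_t`). Both $u_0$ and $u_1$ are clipped at $0$ before the schedule is applied.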

## Hyperparameters

In [14]:
name = 'nn'
# path the experiment results would be saved to via `experiment.save(filename)`
filename = f'../logs/{name}.pkl'

def get_experiment_args():
    """Assemble the keyword arguments for an `Experiment` run of the LR-tuning task.

    Systems and controllers are returned as factories (lambdas) rather than
    instances so they can be rebuilt from scratch for every trial.

    Returns:
        dict with keys 'make_system', 'make_controllers', 'observable',
        'num_trials', 'T', 'reset_condition', 'reset_seed',
        'use_multiprocessing' and 'render_every'.
    """
    # --------------------------------------------------------------------------------------
    # ------------------------    EXPERIMENT HYPERPARAMETERS    ----------------------------
    # --------------------------------------------------------------------------------------

    num_trials = 1
    T = 2000  # total timesteps
    T0 = 1000  # number of timesteps to just sysid for our methods
    reset_condition = lambda t: t % 20 == 0  # when to reset the system (which means fresh LR/MNIST model params)
    use_multiprocessing = False  # unsure if this works in jupyter notebooks
    render_every = None

    # --------------------------------------------------------------------------------------
    # --------------------------    SYSTEM HYPERPARAMETERS    ------------------------------
    # --------------------------------------------------------------------------------------

    initial_lr = 0.1
    initial_decay = 0.8
    initial_control = jnp.array([initial_lr, initial_decay])
    du = initial_control.shape[0]

    make_optimizer = lambda model: optim.SGD(model.parameters(), lr=initial_lr)

    def apply_control(control, system):
        """Set the optimizer's learning rate from the 2-d control ``(u0, u1)``.

        Implements ``lr = u0 / (1 + u1 * sqrt(t))`` with ``t = system.episode_t``.
        Both components are clipped at 0 first. The literal is deliberately the
        first argument to ``max``: ``max(0., nan)`` returns ``0.``, so a NaN control
        degrades to a zero learning rate / zero decay instead of a NaN lr.
        """
        u0 = max(0., control[0].item())  # initial learning rate, clipped at 0
        u1 = max(0., control[1].item())  # decay rate, clipped at 0
        system.opt.param_groups[0]['lr'] = u0 / (1 + u1 * system.episode_t ** 0.5)

    make_system = lambda : LinearRegression(make_optimizer, apply_control,
                                            dataset = 'generated', 
                                            repeat = 20,
                                            eval_every = 1, seed=SYSTEM_SEED)   

    # make_system = lambda : MNIST(make_optimizer, apply_control,
    #                              model_type = 'MLP', batch_size = 64,
    #                              repeat = 5,
    #                              eval_every=None, seed=SYSTEM_SEED)   # best is something like (0.5, 0.05) or (0.2, 0)

    hh = 3
    observable = TimeDelayedObservation(hh = hh, control_dim=du, time_embedding_dim=8,
                                        use_states=False, use_cost_diffs=False,
                                        use_costs=True, use_controls=True, use_time=True)
    do = observable.obs_dim  # dimension of observations to lift from

    # --------------------------------------------------------------------------------------
    # ------------------------    LIFT/SYSID HYPERPARAMETERS    ----------------------------
    # --------------------------------------------------------------------------------------

    dl = 4  # dimension of state to lift to

    bounds = [(0, 1), (0, 1)]
    exploration_args = {'scales': 0.1, 'bounds': bounds, 'avg_len': 3,}
    sysid_args = {
        'obs_dim': do,
        'control_dim': du,

        'exploration_args': {'random 1.0': exploration_args,
                             # 'impulse 0.25': exploration_args,
                            },

        'method': 'nn',
        'AB_method': 'regression',
        'deterministic': False,
        'isometric': True,
        
        'sigma': 0,
        'depth': 8,
        'num_iters': 16000,
        'batch_size': 256,
        'lifter_lr': 0.001,
        'hh': hh,
        'initial_control': initial_control,

        'seed': SYSID_SEED,
    }

    # --------------------------------------------------------------------------------------
    # ------------------------    CONTROLLER HYPERPARAMETERS    ----------------------------
    # --------------------------------------------------------------------------------------

    h = 5 # controller memory length (# of w's to use on inference)
    m_update_rescaler = lambda : ADAM(alpha=0.00, betas=(0.9, 0.999), use_bias_correction=True)
    m0_update_rescaler = lambda : ADAM(alpha=0.004, betas=(0.9, 0.999), use_bias_correction=True)
    k_update_rescaler = lambda : ADAM(alpha=0.004, betas=(0.9, 0.999), use_bias_correction=True)
    # m_update_rescaler = lambda : FIXED_RESCALE(alpha=0.0)
    # m0_update_rescaler = lambda : FIXED_RESCALE(alpha=0.01)
    # k_update_rescaler = lambda : FIXED_RESCALE(alpha=0.0)

    nonlinear_bpc_args = {
        'h': h,  
        'method': 'REINFORCE',
        'initial_scales': (0, 0.01, 0),  # M, M0, K   (uses M0's scale for REINFORCE)
        'rescalers': (m_update_rescaler, m0_update_rescaler, k_update_rescaler),
        # 'bounds': bounds,
        # NOTE(review): the controller starts from u = 0 while the system itself is
        # initialised at `initial_control`; confirm whether EvanBPC's output is an
        # absolute control or an offset. `apply_control` clips u at 0, so an absolute
        # u near 0 means a learning rate near 0.
        'initial_u': jnp.zeros(du),
        'decay_scales': False,
        'use_tanh': False,
        'use_stabilizing_K': True,
        'seed': CONTROLLER_SEED
    }
    
    # this is a bit of a mess at the minute, but here goes: 
    #         - `OfflineSysid` is a wrapper to do sysid phase followed by control,
    #         - `LiftedController` is a wrapper that lifts states before passing to the controller, and 
    #         - `EvanBPC` is the controller (can be replaced with `extravaganza.controllers.RBPC` as well)
    # I currently use lambdas as object generators to make them from scratch easily, but soon i will switch to actual
    # generators or using deepcopies or something :)
    #
    # Each alternative controller below is commented out with exactly one `#` per line
    # (including continuation lines), so an entry can be re-enabled by uncommenting it.
    make_controllers = {
        # '{}/{}'.format(*[round(v.item(), 2) for v in initial_control]): lambda sys: ConstantController(initial_control, do),
        # 'Lifted LQR': lambda sys: OfflineSysid(lambda sysid: LiftedController(controller=LQR(sysid.A, sysid.B), lifter=sysid),
        #                                        sysid=Lifter(state_dim=dl, **sysid_args), T0=T0),
        # 'Lifted HINF': lambda sys: OfflineSysid(lambda sysid: LiftedController(controller=HINF(sysid.A, sysid.B), lifter=sysid),
        #                                         sysid=Lifter(state_dim=dl, **sysid_args), T0=T0),
        # 'Lifted GPC': lambda sys: OfflineSysid(lambda sysid: LiftedController(controller=GPC(sysid.A, sysid.B, decay=False, lr_scale=0.01, H=10), lifter=sysid),
        #                                        sysid=Lifter(state_dim=dl, **sysid_args), T0=T0),
        'Lifted EvanBPC': lambda sys: OfflineSysid(lambda sysid: LiftedController(controller=EvanBPC(sysid.A, sysid.B, **nonlinear_bpc_args), lifter=sysid),
                                          sysid=Lifter(state_dim=dl, **sysid_args), T0=T0)
    }
    experiment_args = {    
        'make_system': make_system,
        'make_controllers': make_controllers,
        'observable': observable,
        'num_trials': num_trials,
        'T': T,
        'reset_condition': reset_condition,
        'reset_seed': None,
        'use_multiprocessing': use_multiprocessing,
        'render_every': render_every,
    }
    return experiment_args

## actually run the thing :)

In [15]:
# run the experiment. Note that `Experiment` is called with the hyperparameter
# factory `get_experiment_args` itself (not its result); per-method results are
# read back later through `experiment.stats` by the plotting cell.
experiment = Experiment(name)
stats = experiment(get_experiment_args)

INFO: (EXPERIMENT) --------------------------------------------------
INFO: (EXPERIMENT) ----------------- TRIAL 0 -----------------------
INFO: (EXPERIMENT) --------------------------------------------------

INFO: (EXPERIMENT): testing Lifted EvanBPC
INFO: (EXPLORER) generating exploration control sequences using ['random'] w.p. [1.]
100%|███████████████████████████████████████████████████████████████████████████████████████████████████████████| 5000/5000 [00:10<00:00, 494.84it/s]
INFO: (LIFTER): we will be linearizing in latent dimension 9 and linearly project down to embedding dimension 4
INFO: (LIFTER): we are imposing simplification as a hard constraint on the latent space via isometric NN
INFO: (LIFTER): using "regression" method to get the AB matrices during each training step
  0%|                                                                                                                       | 0/2000 [00:00<?, ?it/s]INFO: (EXPERIMENT): reset at t=0!
  1%|▋               

 40%|█████████████████████████▏                                     | 800/2000 [00:51<01:14, 16.08it/s, control=[0.15492101 0.661502  ], cost=0.0886]INFO: (EXPERIMENT): reset at t=800!
 41%|█████████████████████████▍                                    | 820/2000 [00:53<01:15, 15.69it/s, control=[0.16999403 0.8429009 ], cost=0.00921]INFO: (EXPERIMENT): reset at t=820!
 42%|██████████████████████████                                    | 840/2000 [00:54<01:22, 14.14it/s, control=[0.09977766 0.88834006], cost=0.00921]INFO: (EXPERIMENT): reset at t=840!
 43%|██████████████████████████▋                                   | 860/2000 [00:55<01:07, 16.94it/s, control=[0.         0.94923395], cost=0.00921]INFO: (EXPERIMENT): reset at t=860!
 44%|███████████████████████████▎                                  | 880/2000 [00:57<01:07, 16.61it/s, control=[0.         0.82991743], cost=0.00921]INFO: (EXPERIMENT): reset at t=880!
 45%|████████████████████████████▎                                  | 900/2

regression (ret) :
||A||_op = 0.8092865943908691
||B||_F = 5.75260591506958
||A-BK||_op = 0.5834712386131287
eig(A) = [0.17203853 0.17203853 0.08198473 0.0704478 ]
svd(B) = [4.4987288 3.5852356]

moments :
||A||_op = 1.1270601749420166
||B||_F = 2.048187494277954
||A-BK||_op = 0.7723686695098877
eig(A) = [1.0015647  0.51330096 0.21545346 0.09838992]
svd(B) = [2.0478857  0.03514256]



 50%|████████████████████████████▌                            | 1000/2000 [02:03<2:27:45,  8.87s/it, control=[-0.01509719  0.01926161], cost=0.00921]INFO: (EXPERIMENT): reset at t=1000!
 50%|██████████████████████████████████████▌                                      | 1001/2000 [02:03<2:02:04,  7.33s/it, control=[nan nan], cost=inf]ERROR: (EXPERIMENT): state None or cost inf diverged
 50%|███████████████████████████████████████▌                                       | 1001/2000 [02:03<02:03,  8.09it/s, control=[nan nan], cost=inf]
INFO: 
ERROR: (EXPERIMENT): none of the trials succeeded.


In [13]:
# save args and stats!  --  note that to save the args, we actually save the `get_experiment_args` function. we can print the 
#                           source code later to see the hyperparameters we chose
# experiment.save(filename)

## Visualization
We keep track of the useful information through `Stats` objects. A `Stats` object can `register()` a variable to track and records its values through calls to `update()`. Several `Stats` objects can be combined with `Stats.aggregate()` to obtain mean and variance statistics. 

Below, we define a plotting layout that shows all the desired quantities from both the system and the controller.

In [None]:
def plot(experiment: Experiment, u_idx: int = 1):
    """Plot system- and controller-side statistics for every method of a finished run.

    Layout: `n = 4` rows of shared panels (learning rate, disturbances, costs,
    train/val losses, reconstructed states, NN losses), followed by one
    "u decomposition" panel per method, laid out two per row.

    Args:
        experiment: an `Experiment` that has already been run; `experiment.stats`
            maps method name -> `Stats` (a method whose stats are None is skipped
            with a warning).
        u_idx: index passed as `plot_idx` to the u-decomposition panels
            (presumably the control component; defaults to 1 as before).
    """
    assert experiment.stats is not None, 'cannot plot the results of an experiment that hasnt been run'
    # imported once here rather than on every loop iteration
    from extravaganza.sysid import LOSS_WEIGHTS

    all_stats = experiment.stats
    
    # clear plot and calc nrows
    plt.clf()
    n = 4  # number of rows of shared panels before the per-method panels
    nrows = n + (len(all_stats) + 1) // 2
    fig, ax = plt.subplots(nrows, 2, figsize=(16, 6 * nrows))

    # plot stats
    for i, (method, stats) in enumerate(all_stats.items()):
        if stats is None: 
            logging.warning('{} had no stats'.format(method))
            continue
            
        stats.plot(ax[0, 0], 'lrs', label=method)
        stats.plot(ax[1, 0], 'costs', label=method)
        stats.plot(ax[1, 1], 'costs', label=method, plot_cummean=True)
        stats.plot(ax[2, 0], 'avg train losses since reset', label=method)
        stats.plot(ax[2, 1], 'avg val losses since reset', label=method)        
        
        stats.plot(ax[3, 0], 'states', label=method, plot_norm=True)  # norm of the "state"
        for k in LOSS_WEIGHTS.keys(): stats.plot(ax[3, 1], k, label=k)  # various nn losses
            
        i_ax = ax[n + i // 2, i % 2]  # this method's own u-decomposition panel
        stats.plot(ax[0, 1], 'disturbances', label=method, plot_norm=True)
        # raw strings: '\c' is an invalid escape sequence in a normal string literal
        stats.plot(i_ax, '-K @ state', label='-K @ state', plot_idx=u_idx)
        stats.plot(i_ax, r'M \cdot w', label=r'M \cdot w', plot_idx=u_idx)
        stats.plot(i_ax, 'M0', label='M0', plot_idx=u_idx)
        i_ax.set_title('u decomp for {}'.format(method))
        i_ax.legend()

    # set titles and legends and limits and such
    # (note: `ylim()` is so useful! because sometimes one thing blows up and then autoscale messes up all plots)
    _ax = ax[0, 0]; _ax.set_title('learning rate'); _ax.legend()
    _ax = ax[0, 1]; _ax.set_title('disturbances'); _ax.legend()
    _ax = ax[1, 0]; _ax.set_title('instantaneous costs'); _ax.legend()
    _ax = ax[1, 1]; _ax.set_title('avg costs'); _ax.legend(); ylim(_ax, 0, 10000)
    _ax = ax[2, 0]; _ax.set_title('avg train losses since reset'); _ax.legend()
    _ax = ax[2, 1]; _ax.set_title('avg val losses since reset'); _ax.legend()
    _ax = ax[3, 0]; _ax.set_title('reconstructed states'); _ax.legend()
    _ax = ax[3, 1]; _ax.set_title('nn losses'); _ax.legend()  
plot(experiment)

### Dynamic Plot

#### dynamic plot
# Build an animation from the 'lrs' and 'train losses' stats of the run
# (see `extravaganza.utils.render` for what `sliderkey`/`duration` control),
# then display it inline as an HTML5 video.
anim = render(
    experiment,
    'lrs',
    'train losses',
    sliderkey='lrs',
    save_path=None,
    duration=5,
)
vid = anim.to_html5_video()
HTML(vid)