# OpenAI Gym


In [1]:
import logging
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)  # set level to INFO for wordy
import matplotlib.pyplot as plt
from IPython.display import HTML

import numpy as np
import jax.numpy as jnp

import gymnasium as gym

from extravaganza.dynamical_systems import Gym

from extravaganza.sysid import Lifter
from extravaganza.controllers import NonlinearBPC, ConstantController, OfflineSysid
from extravaganza.observables import TimeDelayedObservation, FullObservation, PartialObservation
from extravaganza.rescalers import ADAM, D_ADAM, DoWG
from extravaganza.utils import ylim, render
from extravaganza.experiments import Experiment

# seeds for randomness. setting to `None` uses random seeds
SYSTEM_SEED = None  # passed to the `Gym` system and used as the experiment's `reset_seed`
CONTROLLER_SEED = None  # passed as `seed` to the `NonlinearBPC` controller
SYSID_SEED = None  # passed as `seed` to the `Lifter` sysid model

INFO: Created a temporary directory at /var/folders/5m/0xr906c130vdqvkm3g21n6wr0000gn/T/tmp8t0qzzyv
INFO: Writing /var/folders/5m/0xr906c130vdqvkm3g21n6wr0000gn/T/tmp8t0qzzyv/_remote_module_non_scriptable.py


## System
Here, we work with games in the OpenAI gym, such as `MountainCarContinuous-v0`, in which we supply a value in `[-1, 1]` as a control to push a car left or right up a mountain. The tricky thing about this environment is that you first have to push the car up the left to gain momentum, even though the goal is on the right.

Another good environment is the `CartPole` environment (left and right bumps on a cart to keep an inverted pendulum upright), whose continuous analog is also displayed below.

We make use of **only the reward signal or cost function** to train, not using any state observation.

MountainCar | CartPole
- | - 
![mountaincar.gif](https://www.gymlibrary.dev/_images/mountain_car.gif) | ![cart_pole.gif](https://www.gymlibrary.dev/_images/cart_pole.gif)

## Hyperparameters

In [4]:
# experiment name, and the pickle path derived from it
name = 'gymtesty'
filename = f'../logs/{name}.pkl'

def get_experiment_args():
    """Build the keyword arguments that configure one experiment run.

    Every hyperparameter (experiment, system, lift/sysid and controller) is
    defined inside this function, so the function body is the complete record
    of the configuration.

    Returns:
        dict: with the keys 'make_system', 'make_controllers', 'observable',
        'num_trials', 'T', 'reset_condition', 'reset_seed',
        'use_multiprocessing' and 'render_every'.
    """
    # ======================================================================================
    # ==========================    EXPERIMENT HYPERPARAMETERS    ==========================
    # ======================================================================================

    total_steps = 20000  # total timesteps (returned as 'T')
    sysid_steps = 10000  # number of timesteps to just sysid for our methods
    trial_count = 1
    reset_when = lambda t: False  # when to reset the system; this condition is False for every t
    multiprocess = False
    render_period = None

    # ======================================================================================
    # ============================    SYSTEM HYPERPARAMETERS    ============================
    # ======================================================================================

    control_dim = 1
    state_dim = 4
    gym_env = 'CartPoleContinuous-v1'
    # gym_env = 'MountainCarContinuous-v0'
    make_system = lambda: Gym(env_name=gym_env, repeat=1, max_episode_len=600, seed=SYSTEM_SEED)

    # ======================================================================================
    # ==========================    LIFT/SYSID HYPERPARAMETERS    ==========================
    # ======================================================================================

    lifted_dim = 16
    hh = 3  # only referenced by the commented-out TimeDelayedObservation alternative below

    # alternative observables:
    #     observable = TimeDelayedObservation(hh = hh, state_dim=state_dim, control_dim=control_dim, use_states=True, use_costs=True, use_controls=True, use_time=False)
    #     observable = PartialObservation(obs_dim = 3, state_dim=state_dim, seed=SYSTEM_SEED)
    observable = FullObservation(state_dim=state_dim)
    obs_dim = observable.obs_dim

    # batch size is capped at 128 and at `sysid_steps - 3`
    sysid_batch = min(sysid_steps - 3, 128)
    sysid_args = {
        'obs_dim': obs_dim,
        'control_dim': control_dim,

        'max_traj_len': int(1e6),
        'exploration_scale': 0.4,

        'depth': 4,
        'h': 2,
        'num_epochs': 250,
        'batch_size': sysid_batch,
        'lr': 0.002,

        'seed': SYSID_SEED,
    }

    # ======================================================================================
    # ==========================    CONTROLLER HYPERPARAMETERS    ==========================
    # ======================================================================================

    def adam_factory(lr):
        # zero-argument factory: each call constructs a new ADAM rescaler with this rate
        return lambda: ADAM(lr, betas=(0.9, 0.999))

    memory_len = 5  # controller memory length (# of w's to use on inference)
    rescaler_factories = (
        adam_factory(0.001),  # M
        adam_factory(0.00),   # M0
        adam_factory(0.001),  # K
    )

    nonlinear_bpc_args = {
        'h': memory_len,
        'method': 'REINFORCE',
        'initial_scales': (0.2, 0.01, 0.),  # M, M0, K   (uses M0's scale for REINFORCE)
        'rescalers': rescaler_factories,
        'bounds': None,
        'initial_u': jnp.zeros(control_dim),
        'decay_scales': False,
        'use_tanh': False,
        'use_stabilizing_K': True,
        'seed': CONTROLLER_SEED
    }

    # each entry maps a method name to a factory taking the system and returning a controller
    make_controllers = {
        # 'Linear': lambda sys: OfflineSysid(NonlinearBPC(sysid=Lifter(method='identity', state_dim=state_dim, **sysid_args), **nonlinear_bpc_args), sysid_steps),
        'Lifted': lambda sys: OfflineSysid(
            NonlinearBPC(
                sysid=Lifter(method='nn', state_dim=lifted_dim, **sysid_args),
                **nonlinear_bpc_args
            ),
            sysid_steps
        )
    }

    return {
        'make_system': make_system,
        'make_controllers': make_controllers,
        'observable': observable,
        'num_trials': trial_count,
        'T': total_steps,
        'reset_condition': reset_when,
        'reset_seed': SYSTEM_SEED,
        'use_multiprocessing': multiprocess,
        'render_every': render_period,
    }

## actually run the thing :)

In [5]:
# run
# Create the experiment under `name`, then call it with the `get_experiment_args`
# function itself (not the dict it returns); the result is kept in `stats`.
experiment = Experiment(name)
stats = experiment(get_experiment_args)

INFO: (EXPERIMENT) --------------------------------------------------
INFO: (EXPERIMENT) ----------------- TRIAL 0 -----------------------
INFO: (EXPERIMENT) --------------------------------------------------

INFO: (EXPERIMENT): testing Lifted
  gym.logger.warn(f"Box bound precision lowered by casting to {self.dtype}")
  0%|                                                                                                                      | 0/20000 [00:00<?, ?it/s]INFO: (EXPERIMENT): reset at t=0!
 50%|██████████████████████████████████████▎                                      | 9960/20000 [00:21<00:22, 454.31it/s, control=-.0161, cost=0.0102]INFO: (SYSID WRAPPER) ending exploration at timestep 10000
INFO: (LIFTER): ending sysid phase at step 9999
INFO: training!
INFO: mean loss for past 25 epochs was {'linearization': 426260.12459935894, 'simplification': 2.465606093406677, 'reconstruction': 0.0, 'controllability': 0.0}
INFO: mean loss for past 25 epochs was {'linearization': nan, 

In [None]:
# save args and stats!  --  note that to save the args, we actually save the `get_experiment_args` function itself,
#                           so we can print its source code later to see the hyperparameters we chose
# experiment.save(filename)

## Visualization
We keep track of the useful information through `Stats` objects, which can `register()` a variable to keep track of (which it does via calls to `update()`) and which can be aggregated via `Stats.aggregate()` for mean and variance statistics. 

We define below a plotting arrangement that plots all the desired quantities from both the system and controller.

In [None]:
def plot_gym(experiment: Experiment):
    """Plot system and controller statistics for every method of an experiment.

    The figure has two columns. The first three rows are shared by all methods:

    * row 0: norm of the true states | norm of the disturbances
    * row 1: instantaneous costs     | cumulative-mean costs
    * row 2: norm of the reconstructed states | the four nn losses

    Below those, each method gets its own panel (two per row) showing the first
    component of the terms '-K @ state', 'M \\cdot w' and 'M0'.

    Args:
        experiment: an experiment that has already been run. `experiment.stats`
            is iterated as a mapping from method name to a stats object that
            exposes `plot(ax, key, label=..., ...)`; entries that are `None`
            are skipped with a warning.

    Raises:
        AssertionError: if `experiment.stats` is `None`.
    """
    assert experiment.stats is not None, 'cannot plot the results of an experiment that hasnt been run'
    all_stats = experiment.stats

    # clear plot and calc nrows
    plt.clf()
    n = 3  # number of rows shared by all methods
    nrows = n + (len(all_stats) + 1) // 2  # plus one panel per method, two panels per row
    fig, ax = plt.subplots(nrows, 2, figsize=(16, 6 * nrows))

    # plot stats
    for i, (method, stats) in enumerate(all_stats.items()):
        if stats is None:
            logging.warning('{} had no stats'.format(method))
            continue

        stats.plot(ax[0, 0], 'true states', label=method, plot_norm=True)
        stats.plot(ax[1, 0], 'costs', label=method)
        stats.plot(ax[1, 1], 'costs', label=method, plot_cummean=True)

        stats.plot(ax[2, 0], 'states', label=method, plot_norm=True)  # norm of the "state"
        stats.plot(ax[2, 1], 'linearization', label='linearization')  # various nn losses
        stats.plot(ax[2, 1], 'simplification', label='simplification')
        stats.plot(ax[2, 1], 'reconstruction', label='reconstruction')
        stats.plot(ax[2, 1], 'controllability', label='controllability')

        # per-method panel: index i fills the rows below the shared ones, left to right
        i_ax = ax[n + i // 2, i % 2]
        stats.plot(ax[0, 1], 'disturbances', label=method, plot_norm=True)
        stats.plot(i_ax, '-K @ state', label='-K @ state', plot_idx=0)
        # raw strings: `\c` is not a valid escape sequence, so a plain literal warns on
        # newer Pythons; the key and label bytes are unchanged
        stats.plot(i_ax, r'M \cdot w', label=r'M \cdot w', plot_idx=0)
        stats.plot(i_ax, 'M0', label='M0', plot_idx=0)
        i_ax.set_title('u decomp for {}'.format(method))
        i_ax.legend()

    # set titles and legends and limits and such
    # (note: `ylim()` is so useful! because sometimes one thing blows up and then autoscale messes up all plots)
    _ax = ax[0, 0]; _ax.set_title('true states'); _ax.legend()
    _ax = ax[0, 1]; _ax.set_title('disturbances'); _ax.legend()
    _ax = ax[1, 0]; _ax.set_title('instantaneous costs'); _ax.legend()
    _ax = ax[1, 1]; _ax.set_title('avg costs'); _ax.legend(); ylim(_ax, 0, 3)
    _ax = ax[2, 0]; _ax.set_title('reconstructed states'); _ax.legend()
    _ax = ax[2, 1]; _ax.set_title('nn losses'); _ax.legend()

### Plot

In [None]:
# draw the static figure for the experiment that was run above
plot_gym(experiment)

### Dynamic Plot

In [None]:
# dynamic plot
# NOTE(review): 'xs', 'fs' and 'us' look like keys of recorded quantities, with 'us' driving a slider -- confirm against `render`
anim = render(experiment, 'xs', 'fs', sliderkey='us', save_path=None, duration=5, fps=30)
# presumably `anim` is a matplotlib animation, in which case HTML5 export needs an ffmpeg writer -- TODO confirm
vid = anim.to_html5_video()
# last expression of the cell, so the notebook displays the embedded video
HTML(vid)