In [1]:
from copy import deepcopy
import tqdm
import matplotlib.pyplot as plt
import gymnasium as gym

import numpy as np
import jax.numpy as jnp
import torch

from dynamical_systems import LDS, COCO, LinearRegression, MNIST
from controllers import LQR, GPC, BPC, RBPC, LiftedBPC
from lifters import NoLift, RandomLift, LearnedLift
from sysid import SysID
from rescalers import IDENTITY, FIXED_RESCALE, EMA_RESCALE, ADAM, D_ADAM, DoWG



In [2]:
from controllers import Controller
from dynamical_systems import DynamicalSystem
from stats import Stats

def run_trial(controller: Controller, 
              system: DynamicalSystem, 
              T: int, 
              reset_every: int = None,
              reset_seed: int = None,
              wordy: bool = True,
              du: int = None):
    """
    Run `controller` in closed loop against `system` for up to `T` steps.

    At every step the current control is applied with `system.interact(control)`,
    which returns `(cost, state)`; the controller then produces the next control
    via `controller.get_control(cost, state)`.

    Parameters
    ----------
    controller : object with `get_control(cost, state)` and a `stats` attribute.
        If it has an `initial_control` attribute, that is used as the first control.
    system : object with `interact(control)`, `reset(seed)` and a `stats` attribute.
    T : number of interaction steps.
    reset_every : if not None, `system.reset(reset_seed)` is called on every step `t`
        with `t % reset_every == 0` (this includes `t == 0`).
    reset_seed : value forwarded to `system.reset`.
    wordy : if True, iterate with a tqdm progress bar.
    du : control dimension; only used to build a zero initial control when the
        controller has no `initial_control` attribute.

    Returns
    -------
    `(system_stats, controller_stats)` as deep copies, or `(None, None)` if the state
    contained a NaN, or the cost was NaN or exceeded 1e20.

    Raises
    ------
    ValueError
        If the controller has no `initial_control` and `du` was not given.
    """
    # initial control
    if hasattr(controller, 'initial_control'):
        control = controller.initial_control
    elif du is not None:
        control = jnp.zeros(du)
    else:
        raise ValueError('controller has no `initial_control`; pass `du` to start from a zero control')

    # run trial
    pbar = tqdm.trange(T) if wordy else range(T)
    for t in pbar:
        if reset_every is not None and t % reset_every == 0:
            print('reset!')
            system.reset(reset_seed)

        cost, state = system.interact(control)  # state will be `None` for unobservable systems

        # Check for divergence BEFORE updating the controller, so it never sees NaN/huge values.
        # NaN must be tested explicitly: `nan > 1e20` is False.
        state_diverged = state is not None and bool(jnp.any(jnp.isnan(state)))
        cost_diverged = bool(jnp.isnan(cost)) or bool(cost > 1e20)
        if state_diverged or cost_diverged:
            print('WARNING: state {} or cost {} diverged'.format(state, cost))
            return None, None

        control = controller.get_control(cost, state)

    return deepcopy(system.stats), deepcopy(controller.stats)

In [26]:
from dynamical_systems import DynamicalSystem
from utils import set_seed
from typing import Tuple

class MountainCar(DynamicalSystem):
    """
    Gymnasium's `MountainCarContinuous-v0` wrapped as a `DynamicalSystem`.

    The cost of a state is `max(0, 0.45 - state[0])`. Every call to `interact` applies
    the given control for up to `repeat` environment steps and records the first state
    component, the control and the cost in `self.stats` under 'xs', 'us' and 'fs'.
    """
    def __init__(self, 
                 repeat: int = 1,
                 render: bool = False,
                 seed: int = None):
        """
        Parameters
        ----------
        repeat : number of environment steps taken per `interact` call (must be >= 1).
        render : if True the environment is created with `render_mode='human'`.
        seed : forwarded to `set_seed` and to the environment reset.

        Raises
        ------
        ValueError
            If `repeat` is smaller than 1 (`interact` would have no cost to return).
        """
        if repeat < 1:
            raise ValueError('`repeat` must be >= 1, got {}'.format(repeat))
        set_seed(seed)  # for reproducibility
        self.repeat = repeat
        self.render = render
        
        # env
        self.env = gym.make('MountainCarContinuous-v0', render_mode='human' if render else None)
        self.cost_fn = lambda state: max(0., 0.45 - state[0].item())
        self.reset(seed)  # single env reset; also sets `initial_state`, `state` and `done`
        
        # stats to keep track of
        self.t = 0
        self.stats = Stats()
        self.stats.register('xs', float, plottable=True)
        self.stats.register('us', float, plottable=True)
        self.stats.register('fs', float, plottable=True)
        
    
    def reset(self, seed: int = None):
        """
        Start a new episode and return `self`.

        The seed is passed to the environment's own `reset`, and `self.state` is taken from
        the observation that reset returns, so the wrapper always agrees with the environment.
        `self.initial_state` holds that starting observation. `self.t` and `self.stats` are
        left untouched.
        """
        set_seed(seed)  # for reproducibility
        observation, _ = self.env.reset(seed=seed)
        self.initial_state = observation
        self.state = jnp.array(observation)
        self.done = False
        return self
    

    def interact(self, control: jnp.ndarray) -> Tuple[float, jnp.ndarray]:
        """
        given control, returns cost and an observation. The observation may be the true state, a function of the state, or simply `None`

        The control must have shape `(1,)`. It is applied for up to `self.repeat` environment
        steps; once the episode has terminated no further steps are taken and the cost is 0
        until `reset` is called.
        """
        assert control.shape == (1,)

        # the environment is stepped with a plain float32 numpy action
        action = np.asarray(control, dtype=np.float32)

        for _ in range(self.repeat):
            if self.done:
                cost = 0.
                break
            # NOTE: only the `terminated` flag ends the episode here; `truncated` is ignored
            observation, _, self.done, _, _ = self.env.step(action)
            self.state = jnp.asarray(observation)  # keep the same array type as after `reset`
            cost = self.cost_fn(self.state)
            if self.done:
                print('solved!')
        
        # update
        self.stats.update('xs', self.state[0].item(), t=self.t)
        self.stats.update('us', control.item(), t=self.t)
        # self.stats.update('ws', disturbance, t=self.t)
        self.stats.update('fs', cost, t=self.t)
        self.t += 1
        
        return cost, self.state
        
        

In [15]:
def plot_mountaincar_stats(all_system_stats, all_controller_stats=None):
    """
    Plot system and controller statistics of one or more trials on a grid of axes.

    Parameters
    ----------
    all_system_stats : dict mapping a method name to a stats object of the system;
        its 'xs', 'us' and 'fs' series are drawn via `stats.plot(ax, key, label=...)`.
    all_controller_stats : optional dict mapping a method name to the controller's stats
        object. Each entry contributes to the shared '||A||_op', '||B||_F', '||A-BK||_op',
        'lifter losses' and 'disturbances' panels and gets its own panel showing the
        'K @ state', 'M \\cdot w' and 'M0' series. Defaults to no controllers.

    Layout: four fixed rows of two panels, followed by one panel per controller
    (two per row).
    """
    if all_controller_stats is None:  # avoid a shared mutable default argument
        all_controller_stats = {}

    n = 4
    nrows = n + (len(all_controller_stats) + 1) // 2
    # `plt.subplots` creates its own figure, so there is no need to clear the current one first
    _, ax = plt.subplots(nrows, 2, figsize=(16, 4 * nrows))

    for method, ss in all_system_stats.items():
        ss.plot(ax[0, 0], 'xs', label=method)
#         ss.plot(ax[0, 1], 'ws', label=method)
        ss.plot(ax[1, 0], 'us', label=method)
        ss.plot(ax[1, 1], 'fs', label=method)
    
    for i, (method, cs) in enumerate(all_controller_stats.items()):
        cs.plot(ax[2, 0], '||A||_op', label=method)
        cs.plot(ax[2, 1], '||B||_F', label=method)
        cs.plot(ax[3, 0], '||A-BK||_op', label=method)
        cs.plot(ax[3, 1], 'lifter losses', label=method)
        i_ax = ax[n + i // 2, i % 2]
        cs.plot(ax[0, 1], 'disturbances', label=method)
        cs.plot(i_ax, 'K @ state', label='K @ state')
        # raw strings: same key/label text as before, without the invalid `\c` escape
        cs.plot(i_ax, r'M \cdot w', label=r'M \cdot w')
        cs.plot(i_ax, 'M0', label='M0')
        i_ax.set_title('u decomp for {}'.format(method))
        i_ax.legend()

    # titles and legends of the fixed panels
    panel_titles = {
        (0, 0): 'position',
        (0, 1): 'disturbances',
        (1, 0): 'controls',
        (1, 1): 'costs',
        (2, 0): '||A||_op',
        (2, 1): '||B||_F',
        (3, 0): '||A-BK||_op',
        (3, 1): 'lifter losses',
    }
    for (row, col), title in panel_titles.items():
        ax[row, col].set_title(title)
        ax[row, col].legend()

In [30]:
# Experiment configuration and driver: build the controller(s), run each one on a fresh
# MountainCar system with `run_trial`, and plot the collected statistics.

# Passed as `seed` to the lifter, the controller and MountainCar, and as `reset_seed` to run_trial.
# NOTE(review): `None` presumably leaves the run unseeded — confirm against `utils.set_seed`.
SEED = None
h = 20  # controller memory length (# of w's to use on inference)
hh = 20  # history length of the cost/control histories
lift_dim = 20  # dimension to lift to

T = 400  # number of interaction steps given to run_trial
T0 = 100  # forwarded to LiftedBPC as 'T0' (presumably a warm-up length — TODO confirm)
reset_every = 100  # run_trial resets the system every this many steps

# Zero-argument factories that build a new ADAM rescaler on each call; handed to LiftedBPC via 'rescalers'.
M_UPDATE_RESCALER = lambda : ADAM(0.01, betas=(0.9, 0.999))
M0_UPDATE_RESCALER = lambda : ADAM(0.01, betas=(0.9, 0.999))
K_UPDATE_RESCALER = lambda : ADAM(0.01, betas=(0.9, 0.999))

sysid_method = 'regression'  # only referenced by the commented-out controllers below
sysid_scale = 4  # passed as `scale` to LearnedLift (and to SysID in the commented-out controllers)

# keyword arguments for LearnedLift
learned_lift_args = {
    'lift_lr': 0.001,
    'sysid_lr': 0.001,
    'cost_lr': 0.001,
    'depth': 3,
    'buffer_maxlen': 200,
    'batch_size': 20,
    'seed': SEED
}

# keyword arguments shared by every LiftedBPC controller
lifted_bpc_args = {
    'h': h,
    'method': 'FKM',
    'initial_scales': (0.1, 0.1, 0.1),  # M, M0, K   (uses M0's scale for REINFORCE)
    'rescalers': (M_UPDATE_RESCALER, M0_UPDATE_RESCALER, K_UPDATE_RESCALER),
    'T0': T0,
    'bounds': (-1, 1),
    'initial_u': jnp.zeros(1),
    'step_every': 1,
    'decay_scales': False,
    'use_sigmoid': True,
    'K_every': 10000,
    'seed': SEED
}

# get controllers
# The same LearnedLift instance is used both as the lifter and as the sysid of the controller.
ll = LearnedLift(hh, 1, lift_dim, scale=sysid_scale, **learned_lift_args)
controllers = {
#     'No Lift': LiftedBPC(lifter=NoLift(hh, 1, SEED), 
#                          sysid=SysID(sysid_method, 1, hh, sysid_scale), 
#                          **lifted_bpc_args),
#     'Random Lift': LiftedBPC(lifter=RandomLift(hh, 1, lift_dim, learned_lift_args['depth'], SEED), 
#                              sysid=SysID(sysid_method, 1, lift_dim, sysid_scale, SEED), 
#                              **lifted_bpc_args),
    'Learned Lift': LiftedBPC(lifter=ll, 
                              sysid=ll, 
                              **lifted_bpc_args)
}

# run trials and plot
all_system_stats, all_controller_stats = {}, {}
for key, controller in controllers.items():
    print(key)
    # a fresh environment per controller; render=True selects the 'human' render mode and
    # repeat=5 applies every control for up to 5 environment steps
    system = MountainCar(seed=SEED, render=True, repeat=5)
    ss, cs = run_trial(controller, system, T, wordy=True, reset_every=reset_every, reset_seed=SEED)
    if ss is None: continue  # run_trial returns (None, None) when the trial diverged
    all_system_stats[key] = ss
    all_controller_stats[key] = cs

plot_mountaincar_stats(all_system_stats, all_controller_stats);

Learned Lift


  0%|                                                                                                                  | 0/400 [00:00<?, ?it/s]

reset!


 25%|█████████████████████████▉                                                                               | 99/400 [00:28<01:28,  3.42it/s]

copying the K from <lifters.LearnedLift object at 0x2cf96f850>
solving DARE with unconstrained Q
||A||_op = 0.9965633749961853     ||B||_F 0.49984073638916016         ||A-BK||_op = 0.9891906380653381





ValueError: cannot convert float NaN to integer