In [1]:
from abc import abstractmethod
import inspect
from typing import Tuple
from collections import deque
from copy import deepcopy
import tqdm
import scipy

import torch
import numpy as np
import jax.numpy as jnp
"""
At the moment the DL code is in PyTorch; once it's finalized I can convert it to JAX for efficiency :)
"""

from models import MLP
from rescalers import FIXED_RESCALE, EMA_RESCALE, ADAM, D_ADAM, DoWG, IDENTITY
from stats import Stats

def _sigmoid(t):
    return 1 / (1 + np.exp(-t))

def _inv_sigmoid(s):
    return np.log(s / (1 - s + 1e-8))

def _d_sigmoid(t):
    """Derivative of the logistic function at `t`, i.e. `sigmoid(t) * (1 - sigmoid(t))`."""
    sig = _sigmoid(t)
    complement = 1 - sig
    return sig * complement

def _rescale(t, bounds, use_sigmoid):
    """
    rescales from `[0, 1] -> [tmin, tmax]`
    """
    tmin, tmax = bounds
    if use_sigmoid: t = _sigmoid(t)
    return tmin + (tmax - tmin) * t

def _d_rescale(t, bounds, use_sigmoid):
    tmin, tmax = bounds
    d = tmax - tmin
    if use_sigmoid: d *= _d_sigmoid(t)
    return d

def _inv_rescale(s, bounds, use_sigmoid):
    tmin, tmax = bounds
    t = (s - tmin) / (tmax - tmin)
    if use_sigmoid: t = _inv_sigmoid(t)
    return t

def _generate_uniform(shape, norm=1.00):
    v = np.random.normal(size=shape)
    v = norm * v / np.linalg.norm(v)
    v = jnp.array(v)
    return v

def _exponential_linspace_int(start, end, num, divisible_by=1):
    """Exponentially increasing values of integers."""
    base = np.exp(np.log(end / start) / (num - 1))
    return [int(np.round(start * base**i / divisible_by) * divisible_by) for i in range(num)]

def _append(arr, val):
    """
    rightmost recent appending, i.e. arr = (val_{t-h}, ..., val_{t-1}, val_t)
    """
    if isinstance(arr, torch.Tensor):
        if not isinstance(val, torch.Tensor):
            val = torch.tensor(val, dtype=arr.dtype)
        arr = torch.roll(arr, -1, dims=(0,)).clone()  # have to do it this way for autograd, but its not slower for some reason
        arr[-1] = val
    elif isinstance(arr, jnp.ndarray):
        if not isinstance(val, jnp.ndarray):
            val = jnp.array(val, dtype=arr.dtype)
        arr = arr.at[0].set(val)
        arr = jnp.roll(arr, -1, axis=0)
    return arr

KeyboardInterrupt: 

In [None]:
class SysID:
    """
    Determine `A` and `B` matrices for a LDS `x_{t+1} = A @ x_t + B @ u_t`.

    Usage: call `perturb_control` once per time step to log the current state together with
    the (possibly perturbed) control that will be applied, then call `sysid` to estimate
    `(A, B)`. The first estimate is cached and returned by every later call.
    """
    def __init__(self,
                 method: str,  # must be one of ['HAZAN', 'REGRESSION']
                 control_dim: int,
                 state_dim: int,
                 scale: float):
        
        assert method in ['HAZAN', 'REGRESSION']
        self.method = method
        self.control_dim = control_dim
        self.state_dim = state_dim
        self.scale = scale  # magnitude of the exploration noise (decayed as t ** -0.25)
        
        # one entry per `perturb_control` call, so each list has `self.t - 1` elements
        self.control_history = []
        self.state_history = []
        self.eps_history = []
        
        self.t = 1
        self.A = self.B = None  # filled in (and cached) by the first `sysid` call
    
    
    def perturb_control(self,
                        state: jnp.ndarray,
                        control: jnp.ndarray=None):
        """
        Record `state` and return the control to apply at this step.

        if `control` is not None, we perturb around this control only when 'HAZAN' mode. otherwise, we don't perturb.
        if `control` is None, pure exploration noise `scale * eps / t ** 0.25` (Gaussian `eps`) is returned.
        """
        assert state.shape == (self.state_dim,)
        
        if self.method == 'HAZAN' or control is None:
            control = control if control is not None else jnp.zeros(self.control_dim)
            eps = np.random.randn(self.control_dim)
            control = control + self.scale * eps / self.t ** 0.25
        else:
            eps = np.zeros(self.control_dim)

        self.control_history.append(control)
        self.state_history.append(state)
        self.eps_history.append(eps)
        self.t += 1
        return control
    
    
    def sysid(self):
        """
        Estimate (and cache) the dynamics matrices from the recorded histories.

        Returns:
            `(A, B)` with shapes `(state_dim, state_dim)` and `(state_dim, control_dim)`.

        Raises:
            ValueError: in 'HAZAN' mode, when too few samples were recorded to form the
                state/noise correlations.
        """
        assert self.t > 1
        
        if self.A is not None: return self.A, self.B
        
        if self.method == 'HAZAN':
            states = jnp.array(self.state_history)
            eps = jnp.array(self.eps_history)

            num_samples = states.shape[0]  # == self.t - 1
            k = int(0.15 * self.t)

            # The largest lag j = k reads states[k + 1 : k + 1 + scan_len], which must fit in
            # `num_samples` rows. Deriving scan_len from `self.t` instead over-counts by one
            # and makes the last slice one row shorter than `eps[:scan_len]`.
            scan_len = num_samples - k - 1
            if scan_len < 1:
                raise ValueError('not enough samples for HAZAN sysid: have {}, need at least {}'.format(num_samples, k + 2))

            # N_j[j] correlates x_{i+j+1} with eps_i, averaged over i; N_j[0] is the estimate of B
            N_j = jnp.array([jnp.dot(states[j + 1: j + 1 + scan_len].T, eps[:scan_len]) for j in range(k + 1)]) / scan_len
            B = N_j[0]

            # retrieve A by regularized least squares between consecutive lags
            C_0, C_1 = N_j[:-1], N_j[1:]
            C_inv = jnp.linalg.inv(jnp.tensordot(C_0, C_0, axes=((0, 2), (0, 2))) + 1e-3 * np.identity(self.state_dim))
            A = jnp.tensordot(C_1, C_0, axes=((0, 2), (0, 2))) @ C_inv

        elif self.method == 'REGRESSION':
            # transform x and u into regular numpy arrays for least squares
            states = np.array(self.state_history)
            controls = np.array(self.control_history)

            # regression on A and B jointly: [x_t, u_t] @ [A; B]^T ~= x_{t+1}
            A_B = scipy.linalg.lstsq(np.hstack((states[:-1], controls[:-1])), states[1:])[0]
            A, B = jnp.array(A_B[:self.state_dim]).T, jnp.array(A_B[self.state_dim:]).T
        
        self.A, self.B = A, B
        return A, B
    
    
    def dynamics(self,
                 state: jnp.ndarray,
                 control: jnp.ndarray):
        """Predict the next state `A @ state + B @ control` under the estimated model."""
        assert state.shape == (self.state_dim,) and control.shape == (self.control_dim,)
        A, B = self.sysid()  # make sure we have an estimate first
        
        return A @ state + B @ control
        
    
    def get_lqr(self):
        """
        Compute the LQR gain `K` for the estimated system with `Q = I` and `R = 1e-8 * I`.

        Returns:
            `K` of shape `(control_dim, state_dim)`.
        """
        A, B = self.sysid()  # make sure we have an estimate first
            
        # compute stabilizing controller for squared costs
        Q = jnp.eye(self.state_dim)
        R = jnp.eye(self.control_dim) * 1e-8  # heuristic to weight state more
        
        X = scipy.linalg.solve_discrete_are(A, B, Q, R)  # solve the Riccati equation
        K = jnp.linalg.inv(B.T @ X @ B + R) @ (B.T @ X @ A)  # compute LQR gain
        return K

In [None]:
class Lifter:
    """
    Base class for mapping the past `hh` costs and controls to lifted "states".
    The given cost and control histories should be rightmost recent.

    Note: the `@abstractmethod` markers below are declarative only, since this class
    does not use an ABC metaclass; subclasses may inherit `update` as a no-op.
    """
    @abstractmethod
    def __init__(self,
                 hh: int,
                 control_dim: int,
                 state_dim: int):
        # hh: history length, control_dim: size of one control, state_dim: size of the lifted state
        self.hh, self.control_dim, self.state_dim = hh, control_dim, state_dim
    
    @abstractmethod
    def map(self,
            cost_history: jnp.ndarray,
            control_history: jnp.ndarray) -> jnp.ndarray:
        """
        Turns a history of costs and controls into a "state". The state may simply be the
        cost history, a lifted representation, or anything else a subclass chooses.
        """
        method_name = inspect.stack()[0][3]
        raise NotImplementedError('{} not implemented'.format(method_name))

    @abstractmethod
    def update(self,
               prev_state: jnp.ndarray,
               state: jnp.ndarray,
               control: jnp.ndarray,
               sysid: SysID) -> float:
        """
        Hook invoked with the previous state, the resulting state, the control that was
        applied in between, and the current estimate of the system dynamics.
        Subclasses may use it to train the lifting mechanism and return the loss;
        the default does nothing and returns None.
        """
        return None
    
    
# ---------------------------------------------------------------------------------------------------

class NoLift(Lifter):
    """
    Identity "lift": the state is simply the window of the last `hh` costs.
    """
    def __init__(self,
                 hh: int,
                 control_dim: int):
        # the state is the cost window itself, hence state_dim == hh
        super().__init__(hh, control_dim, state_dim=hh)
    
    def map(self,
            cost_history: jnp.ndarray,
            control_history: jnp.ndarray) -> jnp.ndarray:
        """
        Validates the history shapes and returns `cost_history` unchanged as the state.
        """
        expected_cost_shape = (self.hh,)
        expected_control_shape = (self.hh, self.control_dim)
        assert cost_history.shape == expected_cost_shape
        assert control_history.shape == expected_control_shape
        
        return cost_history
    
# ---------------------------------------------------------------------------------------------------

class RandomLift(Lifter):
    """
    Uses a random init NN to transform the cost+control histories into a system that hopefully has linear dynamics.

    The network is never trained: `update` is a no-op. Only the cost history is fed to the
    network at the moment; the control history is validated but otherwise unused.
    """
    def __init__(self,
                 hh: int,
                 control_dim: int,
                 state_dim: int,
                 depth: int):
        super().__init__(hh, control_dim, state_dim)  # stores hh, control_dim, state_dim
        
        # to compute lifted states which hopefully respond linearly to the controls
        flat_dim = hh  # TODO could add control history as an input as well!
        self.lift_model = MLP(layer_dims=_exponential_linspace_int(flat_dim, self.state_dim, depth), use_bias=False).train().float()
        for p in self.lift_model.parameters(): p.data.uniform_(-0.1, 0.1)
    
    
    def map(self,
            cost_history: jnp.ndarray,
            control_history: jnp.ndarray) -> jnp.ndarray:
        """
        maps histories to lifted states. it's in pytorch rn, but that can change

        Returns a jax array of shape `(state_dim,)`.
        """
        assert cost_history.shape == (self.hh,)
        assert control_history.shape == (self.hh, self.control_dim)
        
        # convert to a pytorch tensor, run the frozen network, and convert back
        with torch.no_grad():
            cost_input = torch.from_numpy(np.array(cost_history))
            # squeeze only the batch axis: a bare .squeeze() would also drop the state axis
            # when state_dim == 1 and yield a 0-d state
            state = self.lift_model(cost_input.unsqueeze(0)).squeeze(0)
        state = jnp.array(state.cpu().numpy())
        
        return state

    def update(self,
               prev_state: jnp.ndarray,
               state: jnp.ndarray,
               control: jnp.ndarray,
               sysid: SysID):
        """No-op: the random lift has nothing to learn."""
        pass

# -------------------------------------------------------------------------------------------------------------
    
class LearnedLift(Lifter, SysID):
    """
    Lifter whose network is trained online, jointly with linear dynamics `(A, B)`, so that
    consecutive lifted states satisfy `state ~= A @ prev_state + B @ control`.

    It also exposes the `SysID` interface (`perturb_control`, `sysid`, `dynamics`, `get_lqr`)
    backed by the learned `A` and `B`.

    Implementation note: `map` keeps the last two network *inputs* and `update` recomputes
    both lifted states from them. Autograd cannot backpropagate a second time through a
    graph it has already consumed, nor through one whose parameters were since stepped
    in place, so graph-carrying outputs from earlier `map` calls cannot be reused.
    NOTE(review): this assumes the MLP is deterministic in train mode (no dropout /
    batch-norm) so that recomputation reproduces what `map` returned -- confirm against `models.MLP`.
    """
    def __init__(self,
                 hh: int,
                 control_dim: int,
                 state_dim: int,
                 depth: int,
                 scale: float=0.1,
                 lift_lr: float=0.001,
                 sysid_lr: float=0.001,
                 learn_lift: bool=True,
                 learn_sysid: bool=True):
        
        super().__init__(hh, control_dim, state_dim)  # Lifter.__init__: stores hh, control_dim, state_dim
        
        self.scale = scale
        self.learn_lift = learn_lift
        self.learn_sysid = learn_sysid
        
        # last (at most two) lifted states as detached tensors; starts with a zero placeholder
        self.states = [torch.zeros(self.state_dim)]
        # last (at most two) network inputs, used by `update` to rebuild fresh autograd graphs
        self._inputs = []
        
        # to compute lifted states which hopefully respond linearly to the controls
        flat_dim = hh  # TODO could add control history as an input as well!
        self.lift_model = MLP(layer_dims=_exponential_linspace_int(flat_dim, self.state_dim, depth), use_bias=False).train().float()
        self.lift_opt = torch.optim.Adam(self.lift_model.parameters(), lr=lift_lr)
    
        # to estimate linear dynamics of lifted states
        self.A = torch.nn.Parameter(torch.randn((self.state_dim, self.state_dim), dtype=torch.float32))
        self.B = torch.nn.Parameter(torch.randn((self.state_dim, self.control_dim), dtype=torch.float32))
        self.sysid_opt = torch.optim.Adam([self.A, self.B], lr=sysid_lr)
        
        self.t = 1
    
    def map(self,
            cost_history: jnp.ndarray,
            control_history: jnp.ndarray) -> jnp.ndarray:
        """
        maps histories to lifted states. it's in pytorch rn, but that can change

        Side effect: remembers the input and the resulting state for the next `update` call.
        Returns a jax array of shape `(state_dim,)`.
        """
        assert cost_history.shape == (self.hh,)
        assert control_history.shape == (self.hh, self.control_dim)
        
        cost_input = torch.from_numpy(np.array(cost_history)).float()
        with torch.no_grad():
            # squeeze only the batch axis so that state_dim == 1 still yields shape (1,)
            state = self.lift_model(cost_input.unsqueeze(0)).squeeze(0)

        # keep only the two most recent entries so memory stays bounded
        self._inputs = (self._inputs + [cost_input])[-2:]
        self.states = (self.states + [state])[-2:]
        
        return jnp.array(state.numpy())

    def update(self,
               prev_state: jnp.ndarray,
               state: jnp.ndarray,
               control: jnp.ndarray,
               sysid: SysID) -> float:
        """
        Take one gradient step on the one-step prediction error
        `mean((state - (A @ prev_state + B @ control)) ** 2)`.

        The `prev_state` and `sysid` arguments are ignored; both lifted states are
        recomputed from the inputs recorded by `map` (the previous state is the zero vector
        before a second input has been seen). `state` is only used as a consistency check
        against the most recent `map` output.

        Returns:
            The scalar loss before the optimizer step.
        """
        assert len(self._inputs) >= 1, 'map() must be called before update()'
        assert jnp.allclose(state, jnp.array(self.states[-1].numpy()))
        
        if not isinstance(control, torch.Tensor):
            control = torch.as_tensor(np.array(control), dtype=torch.float32).reshape(self.control_dim)
        
        # rebuild both states with the current parameters so the graph is fresh
        cur = self.lift_model(self._inputs[-1].unsqueeze(0)).squeeze(0)
        if len(self._inputs) >= 2:
            prev = self.lift_model(self._inputs[-2].unsqueeze(0)).squeeze(0)
        else:
            prev = torch.zeros(self.state_dim)
        pred = self.A @ prev + self.B @ control
        diff = cur - pred
        
        self.lift_opt.zero_grad()
        self.sysid_opt.zero_grad()
        loss = torch.mean(diff ** 2)
        loss.backward()
        if self.learn_lift: self.lift_opt.step()
        if self.learn_sysid: self.sysid_opt.step()
        return loss.item()
    
    def perturb_control(self,
                        state: jnp.ndarray,
                        control: jnp.ndarray=None):
        """
        Return `control` unchanged if given; otherwise return Gaussian exploration noise
        `scale * eps / t ** 0.25`. Nothing is recorded; only the step counter advances.
        """
        assert state.shape == (self.state_dim,)
        
        if control is None:
            eps = np.random.randn(self.control_dim)
            control = self.scale * eps / self.t ** 0.25

        self.t += 1
        return control
    
    def sysid(self):
        """Return the current learned `(A, B)` as jax arrays (not cached; they keep training)."""
        return jnp.array(self.A.data.numpy()), jnp.array(self.B.data.numpy())
    
    def dynamics(self,
                 state: jnp.ndarray,
                 control: jnp.ndarray):
        """Predict the next lifted state `A @ state + B @ control` under the learned model."""
        assert state.shape == (self.state_dim,) and control.shape == (self.control_dim,)
        A, B = self.sysid()
        print(np.max(np.abs(np.linalg.eig(A)[0])))  # diagnostic: spectral radius of the learned A
        
        return A @ state + B @ control
        
    
    def get_lqr(self):
        """
        Compute the LQR gain `K` for the learned system with `Q = I` and `R = 1e-8 * I`.
        """
        A, B = self.sysid()
            
        # compute stabilizing controller for squared costs
        Q = jnp.eye(self.state_dim)
        R = jnp.eye(self.control_dim) * 1e-8  # heuristic to weight state more
        
        X = scipy.linalg.solve_discrete_are(A, B, Q, R)  # solve the Riccati equation
        K = jnp.linalg.inv(B.T @ X @ B + R) @ (B.T @ X @ A)  # compute LQR gain
        return K

In [None]:
class Controller:
    """
    Bandit-feedback controller: it only observes a scalar cost per step, lifts the recent
    cost/control history to a state via `lifter`, and plays

        u = rescale(-K @ state + M0 + sum_i M[i] @ w[i])

    where `w` are the last `h` disturbances (prediction residuals of the `sysid` model).
    For the first `T0` steps it only explores through `sysid.perturb_control`; afterwards
    `M`, `M0` and `K` are updated from perturbation-based gradient estimates
    ('FKM' perturbs the parameters, 'REINFORCE' perturbs the output via `M0`).

    Relies on module-level globals being defined before instantiation:
    `M_UPDATE_RESCALER`, `M0_UPDATE_RESCALER`, `K_UPDATE_RESCALER`, `W_RESCALER`, and
    (at call time) `USE_K`.
    """
    def __init__(self,
                 h: int,
                 initial_u: jnp.ndarray,
                 initial_scales: Tuple[float, float, float],
                 lifter: Lifter,
                 sysid: SysID,
                 T0: int,
                 bounds=None,
                 method='FKM',
                 K: jnp.ndarray=None,
                 use_sigmoid=True,
                 decay_scales=False):
        """
        Args:
            h: controller memory (number of past disturbances used).
            initial_u: initial control, shape `(control_dim,)`; stored (inverse-rescaled) as `M0`.
            initial_scales: non-negative perturbation scales for `(M, M0, K)`.
            lifter: maps cost/control histories to states.
            sysid: system identifier providing `perturb_control`, `dynamics`, `sysid`, `get_lqr`.
            T0: number of initial exploration steps.
            bounds: optional `(lower, upper)` control bounds, reshaped to `(2, control_dim)`.
            method: 'FKM' or 'REINFORCE'.
            K: optional initial feedback gain, shape `(control_dim, state_dim)`.
            use_sigmoid: squash controls through a sigmoid when `bounds` is given.
            decay_scales: divide the perturbation scales by `t ** 0.25`.
        """
        # check things make sense
        assert lifter.state_dim == sysid.state_dim
        assert lifter.control_dim == sysid.control_dim and initial_u.shape[0] == lifter.control_dim
        self.control_dim = lifter.control_dim
        self.state_dim = lifter.state_dim
        assert method in ['FKM', 'REINFORCE']
        assert all(map(lambda i: i >= 0, initial_scales))
        if bounds is not None:
            bounds = jnp.array(bounds).reshape(2, -1)
            assert len(bounds[0]) == len(bounds[1]) and len(bounds[0]) == self.control_dim, 'improper bounds'
            assert all(map(lambda i: bounds[0, i] < bounds[1, i], range(self.control_dim))), 'improper bounds'
        if K is not None:
            assert K.shape == (self.control_dim, self.state_dim)

        # hyperparams
        self.h = h
        self.hh = lifter.hh
        self.lifter = lifter
        self.sysid = sysid
        self.T0 = T0
        self.method = method
        self.bounds = bounds
        self.decay_scales = decay_scales

        # for rescaling u (identity maps when no bounds are given)
        self.rescale_u = lambda u: _rescale(u, self.bounds, use_sigmoid=use_sigmoid) if self.bounds is not None else u
        self.inv_rescale_u = lambda ru: _inv_rescale(ru, self.bounds, use_sigmoid=use_sigmoid) if self.bounds is not None else ru
        self.d_rescale_u = lambda u: _d_rescale(u, self.bounds, use_sigmoid=use_sigmoid) if self.bounds is not None else jnp.ones_like(u)
        
        # dynamic parameters of the controller
        self.M = jnp.zeros((self.h, self.control_dim, self.state_dim))
        self.M0 = self.inv_rescale_u(initial_u)
        self.K = K if K is not None else jnp.zeros((self.control_dim, self.state_dim))
        
        self.M_scale, self.M0_scale, self.K_scale = initial_scales
        self.prev_cost = 0.
        self.prev_control = jnp.zeros(self.control_dim)
        self.prev_state = jnp.zeros(self.state_dim)
        self.cost_history = jnp.zeros(self.hh)   # histories are rightmost recent (increasing in time)
        self.control_history = jnp.zeros((self.hh, self.control_dim))
        self.state_history = jnp.zeros((self.h, self.state_dim))
        self.disturbance_history = jnp.zeros((2 * self.h, self.state_dim))  # past 2h disturbances
        self.t = 1

        # grad estimation stuff -- `self.eps` should be divided by its variance!!
        if self.method == 'FKM':
            self.eps_M = jnp.zeros((self.h, self.h, self.control_dim, self.state_dim))  # noise history of M perturbations
            self.eps_M0 = jnp.zeros((self.h, self.control_dim))  # noise history of M0 perturbations
            self.eps_K = jnp.zeros((self.h, self.control_dim, self.state_dim))  # noise history of K perturbations
            
            def grad_M(diff):
                return diff * jnp.sum(self.eps_M, axis=0) * self.control_dim * self.state_dim * self.h
            def grad_M0(diff):
                return diff * jnp.sum(self.eps_M0, axis=0) * self.control_dim * self.state_dim * self.h
            def grad_K(diff):
                return diff * jnp.sum(self.eps_K, axis=0) * self.control_dim * self.state_dim * self.h
            
        elif self.method == 'REINFORCE':
            self.eps = jnp.zeros((self.h + 1, self.control_dim))  # noise history of u perturbations
            
            def grad_M(diff):
                # sum over the last h output perturbations of eps_i (outer) the h disturbances preceding it
                val = 0.
                for i in range(self.h):
                    val += self.eps[i].reshape(self.control_dim, 1) @ self.disturbance_history[-(self.h + i + 1): -(i + 1)].reshape(1, self.h * self.state_dim)
                val = jnp.transpose(val.reshape(self.control_dim, self.h, self.state_dim), (1, 0, 2))
                return diff * val * self.control_dim * self.state_dim * self.h
            def grad_M0(diff):
                return diff * self.eps[-1] * self.control_dim * self.state_dim * self.h
            def grad_K(diff):
                val = 0.
                for i in range(self.h):
                    val += self.eps[i].reshape(self.control_dim, 1) @ self.state_history[i].reshape(1, self.state_dim)
                return diff * val * self.control_dim * self.state_dim * self.h
            # NOTE(author): the REINFORCE estimator has not been verified
            
        # gradients are applied with a delay of h steps, hence the bounded queue
        self.grads = deque([(jnp.zeros_like(self.M), jnp.zeros_like(self.M0), jnp.zeros_like(self.K))], maxlen=self.h)
        self.grad_M = grad_M
        self.grad_M0 = grad_M0
        self.grad_K = grad_K
        
        self.M_update_rescaler = M_UPDATE_RESCALER()
        self.M0_update_rescaler = M0_UPDATE_RESCALER()
        self.K_update_rescaler = K_UPDATE_RESCALER()
    
        self.disturbance_rescaler = W_RESCALER()
        
        # stats
        self.stats = Stats()
        self.stats.register('rho(A)', float, plottable=True)  # spectral radius of A
        if self.control_dim == 1:
            if self.state_dim == 1:
                self.stats.register('A', float, plottable=True)
                self.stats.register('B', float, plottable=True)
            self.stats.register('disturbances', float, plottable=True)
            self.stats.register('K @ state', float, plottable=True)
            self.stats.register(r'M \cdot w', float, plottable=True)
            self.stats.register('M0', float, plottable=True)

# ------------------------------------------------------------------------------------------------------------
    
    def __call__(self, cost: float) -> jnp.ndarray:
        """
        Returns the control based on current cost and internal parameters.

        Each call: records `cost` and the previously played control, lifts the histories to
        a state, lets the lifter update itself, and then either explores (while `t < T0`)
        or performs one controller update and returns the newly perturbed control.
        """
        
        # observe next state and update histories
        self.cost_history = _append(self.cost_history, cost)
        self.control_history = _append(self.control_history, self.prev_control)
        self.t += 1
        state = self.lifter.map(self.cost_history, self.control_history)  # xhat_{t+1}
        self.lifter.update(self.prev_state, state, self.prev_control, self.sysid)  # update lifter, if needed
        self.state_history = _append(self.state_history, state)
        
        # explore for sysid, and then get stabilizing controller
        if self.t < self.T0:
            control = self.sysid.perturb_control(state)
            self.prev_state = state
            self.prev_control = control
            return control
        elif self.t == self.T0 and USE_K:  # get stabilizing controller
            print('copying the K from {}'.format(self.sysid))
            self.K = self.sysid.get_lqr()
        
        # compute disturbance
        pred_state = self.sysid.dynamics(self.prev_state, self.prev_control)  # A @ xhat_t + B @ u_t
        disturbance = state - pred_state  # xhat_{t+1} - (A @ xhat_t + B @ u_t)
        disturbance = self.disturbance_rescaler.step(disturbance)
        self.disturbance_history = _append(self.disturbance_history, disturbance)
        
        # compute change in cost, as well as the new scale
        cost_diff = cost - self.prev_cost
        M_scale, M0_scale, K_scale = map(lambda s: s / (self.t ** 0.25) if self.decay_scales else s, [self.M_scale, self.M0_scale, self.K_scale])

        # update controller
        # `d` has one entry per control coordinate. It must multiply along the control axis of
        # each gradient; plain broadcasting would align it with the trailing state axis of
        # grad_M / grad_K, which is only harmless when control_dim == 1.
        d = jnp.reshape(self.d_rescale_u(self.prev_control), (self.control_dim,))
        grad_M = self.grad_M(cost_diff) * d[None, :, None]
        grad_M0 = self.grad_M0(cost_diff) * d
        grad_K = self.grad_K(cost_diff) * d[:, None]
        self.grads.append((grad_M, grad_M0, grad_K))
        if len(self.grads) == self.grads.maxlen:
            grad_M, grad_M0, grad_K = self.grads[0]  # use update from h steps ago
            self.M = self.M - self.M_update_rescaler.step(grad_M, iterate=self.M)
            self.M0 = self.M0 - self.M0_update_rescaler.step(grad_M0, iterate=self.M0)
            self.K = self.K - self.K_update_rescaler.step(grad_K, iterate=self.K)
        
        # compute newest perturbed control
        M_tilde, M0_tilde, K_tilde = self.M, self.M0, self.K
        
        if self.method == 'FKM':  # perturb em all
            eps_M = _generate_uniform((self.h, self.control_dim, self.state_dim))
            eps_M0 = _generate_uniform(self.control_dim)
            eps_K = _generate_uniform((self.control_dim, self.state_dim))
            M_tilde = M_tilde + M_scale * eps_M
            M0_tilde = M0_tilde + M0_scale * eps_M0
            K_tilde = K_tilde + K_scale * eps_K
            if M_scale > 0: self.eps_M = _append(self.eps_M, eps_M)
            if M0_scale > 0: self.eps_M0 = _append(self.eps_M0, eps_M0)
            if K_scale > 0: self.eps_K = _append(self.eps_K, eps_K)
            
        elif self.method == 'REINFORCE':  # perturb output only
            eps = _generate_uniform(self.control_dim)
            M0_tilde = M0_tilde + M0_scale * eps
            if M0_scale > 0: self.eps = _append(self.eps, eps / M0_scale)
            
        control = -K_tilde @ state + M0_tilde + jnp.tensordot(M_tilde, self.disturbance_history[-self.h:], axes=([0, 2], [0, 1]))
        control = self.rescale_u(control)

        # cache it
        self.prev_cost = cost
        self.prev_state = state
        self.prev_control = control 
        
        # update stats
        A, B = self.sysid.sysid()
        # eig returns (eigenvalues, eigenvectors); only the eigenvalues define the spectral radius
        self.stats.update('rho(A)', jnp.max(jnp.abs(jnp.linalg.eig(A)[0])).item(), t=self.t)
        # NOTE(review): 'A' and 'B' are registered only when control_dim == state_dim == 1 but are
        # updated unconditionally here -- confirm that Stats.update tolerates unregistered keys.
        self.stats.update('A', jnp.linalg.norm(A).item(), t=self.t)
        self.stats.update('B', jnp.linalg.norm(B).item(), t=self.t)
        if self.control_dim == 1:
            self.stats.update('disturbances', disturbance.item(), t=self.t)
            self.stats.update('K @ state', (-self.K @ state).item(), t=self.t)
            self.stats.update(r'M \cdot w', (jnp.tensordot(self.M, self.disturbance_history[-self.h:], axes=([0, 2], [0, 1]))).item(), t=self.t)
            self.stats.update('M0', self.M0.item(), t=self.t)
            
        return control

# LDS

In [None]:
class System:
    """
    LDS

    Random linear dynamical system `x_{t+1} = A @ x_t + B @ u_t`. Both matrices are drawn
    by rejection sampling from standard normal entries until they are "small":
    `A` must have spectral radius below 0.9; `B` is held to the same threshold using its
    eigenvalues when it is square and its singular values otherwise (eigenvalues are
    undefined for a non-square `B`, i.e. whenever `state_dim != control_dim`).
    """
    def __init__(self, state_dim, control_dim):
        self.state_dim = state_dim
        self.control_dim = control_dim
        
        # draw A first, then B, so the random stream is consumed in a fixed order
        self.A = self._sample_small((state_dim, state_dim))
        self.B = self._sample_small((state_dim, control_dim))

    @staticmethod
    def _sample_small(shape, margin=0.1):
        """Redraw a standard normal matrix of `shape` until its largest magnitude is below `1 - margin`."""
        while True:
            mat = torch.randn(shape)
            if shape[0] == shape[1]:
                magnitudes = torch.abs(torch.linalg.eig(mat)[0])
            else:
                magnitudes = torch.linalg.svdvals(mat)
            if torch.max(magnitudes) < 1 - margin:
                return mat
        
    def step(self, state, control):
        """
        Advance one step and return `A @ state + B @ control` (shape `(state_dim,)`).

        `control` may be a torch tensor or anything `np.array` accepts; a trailing
        singleton axis (column vector) is squeezed away.
        """
        if not isinstance(control, torch.Tensor):
            control = torch.from_numpy(np.array(control))
        if len(control.shape) > 1: control = control.squeeze(-1)
        s = self.A @ state + self.B @ control
        assert s.shape == (self.state_dim,), '{}  {}'.format(s.shape, self.state_dim)
        return s
  
def cost_fn(x, u):
    """
    Quadratic state cost: the sum of squared entries of `x`.

    `x` may be a python float, a torch tensor (converted to a jax array first), or any
    array supporting `** 2` and `.sum()`. The control `u` is accepted but not used.
    """
    if isinstance(x, torch.Tensor):
        x = jnp.array(x.detach().cpu().data)
    elif isinstance(x, float):
        x = jnp.array(x)
    squared = x ** 2
    return squared.sum()

def disturbance(t, dim):
    """
    Disturbance applied at time `t`: a sinusoid completing 4 periods over the horizon `T`
    (module-level global), broadcast to a float32 vector of length `dim`.
    """
    # other profiles that have been tried here: linear ramp, constant, gaussian noise, zero
    phase = 4 * 2 * np.pi * t / T
    amplitude = np.sin(phase)
    return amplitude * np.ones(dim, dtype=np.float32)

In [None]:
from deluca.agents._gpc import GPC
from deluca.agents._lqr import LQR
from deluca.agents._bpc import BPC

# Experiment configuration. The *_RESCALER factories, T, T0 and USE_K are module-level
# globals that `Controller` and `disturbance` read at run time, so they must be defined
# before any controller is constructed or stepped.

LIFT_LR = 0  # not referenced anywhere else in the visible code

# factories producing a fresh update-rescaler per controller parameter
# (presumably the first positional argument of ADAM is the learning rate -- confirm in `rescalers`)
M_UPDATE_RESCALER = lambda : ADAM(0.00, betas=(0.9, 0.999))
M0_UPDATE_RESCALER = lambda : ADAM(0.00, betas=(0.9, 0.999))
K_UPDATE_RESCALER = lambda : ADAM(0.0, betas=(0.9, 0.999))

# factory for the rescaler applied to each disturbance before it enters the history
W_RESCALER = lambda : IDENTITY()
# W_RESCALER = lambda : EMA_RESCALE(beta=0.99)
# W_RESCALER = lambda : ADAM(betas=(0.9, 0.9))

T = 500  # number of simulated steps per experiment (also the period base of `disturbance`)
T0 = 300  # number of initial exploration steps used for system identification
USE_K = True  # whether the controller adopts the LQR gain of the identified system at t == T0

du = 1  # control dim
ds = 1  # state dim of the true system
h = 10  # controller memory length (# of w's to use on inference)
hh = 20  # history length of the cost/control histories
lift_dim = 40  # dimension to lift to

# shared keyword arguments for SysID (state_dim is supplied per lifter)
sysid_args = {
    'method': 'HAZAN',
    'scale': 0.1,
    'control_dim': du
}

# shared keyword arguments for Controller (lifter and sysid are supplied per experiment)
controller_args = {
    'h': h,
    'method': 'FKM',
    'initial_scales': (0.0, 0.0, 0),  # M, M0, K   (uses M0's scale for REINFORCE)
    'T0': T0,
    'bounds': None,
    'initial_u': jnp.zeros(du),
    'decay_scales': False,
    'use_sigmoid': True
}

# one random system and one initial state, copied for every experiment so all methods
# are compared on identical dynamics
# torch.manual_seed(SEED)
sys_init = System(state_dim=ds, control_dim=du)
x_init = 1 * torch.randn(ds, dtype=torch.float32)  # initial state history

# --------------- LQR ------------------------------------------------------
# Baseline: LQR built from the TRUE system matrices, run for T steps under the disturbance.

Q = np.identity(ds)
R = np.identity(du) * 1e-8

sys = deepcopy(sys_init)
controller = LQR(A=sys.A.numpy(), B=sys.B.numpy(), Q=Q, R=R)
x = x_init.clone()

# per-step logs: states, costs, and (for scalar controls) the control values
xs_lqr = []
costs_lqr = []
us_lqr = []
print('LQR')
for t in tqdm.trange(T):
    control = controller(x.numpy())
    x = sys.step(x, control) + disturbance(t, ds) 
    
    xs_lqr.append(x.numpy())
    if du == 1: us_lqr.append(control.item())
    costs_lqr.append(cost_fn(x, control).item())  # cost of the state reached after the step
    
# ----------------- GPC -----------------------------------------------------------
# Baseline: deluca's GPC agent, given the TRUE system matrices and the cost function.

sys = deepcopy(sys_init)
controller = GPC(A=sys.A.numpy(), B=sys.B.numpy(), Q=Q, R=R, H=h, cost_fn=cost_fn, lr_scale=0.01, decay=False)
x = x_init.clone()

# per-step logs (grads_gpc is declared but never filled in this section)
xs_gpc = []
costs_gpc = []
ws_gpc = []
us_gpc = []
grads_gpc = []
u_decomp_gpc = []
print('GPC')
for t in tqdm.trange(T):
    control = controller(x.numpy().reshape(-1, 1))  # the agent is fed the state as a column vector
    x = sys.step(x, control) + disturbance(t, ds) 
    
    xs_gpc.append(x.numpy())
    costs_gpc.append(cost_fn(x, control).item())
    ws_gpc.append(jnp.linalg.norm(controller.noise_history[-1]))  # norm of the latest disturbance estimate
    if du == 1: us_gpc.append(control.item())
    # decomposition of the control into (feedback term, disturbance-action term, bias term = 0)
    u_decomp_gpc.append((-controller.K @ controller.state,
                        jnp.tensordot(controller.M, controller.last_h_noises(), axes=([0, 2], [0, 1])),
                        jnp.zeros((du, 1))))
    
    
# ---------------- BPC ------------------------------------------------------------
# Baseline: deluca's BPC agent (bandit feedback), given the TRUE system matrices.

sys = deepcopy(sys_init)
controller = BPC(A=sys.A.numpy(), B=sys.B.numpy(), Q=Q, R=R, H=h, lr_scale=0.005, delta=0.01)
x = x_init.clone()

# per-step logs (grads_bpc is declared but never filled in this section)
xs_bpc = []
costs_bpc = []
ws_bpc = []
us_bpc = []
grads_bpc = []
u_decomp_bpc = []
print('BPC')
for t in tqdm.trange(T):
    # cost of the CURRENT state, computed before stepping; on the first iteration `control`
    # is left over from the previous experiment, which is harmless because cost_fn ignores it
    cost = cost_fn(x, control).item()
    control = controller(x.numpy().reshape(-1, 1), cost)
    x = sys.step(x, control) + disturbance(t, ds) 
    
    xs_bpc.append(x.numpy())
    costs_bpc.append(cost)
    ws_bpc.append(jnp.linalg.norm(controller.noise_history[-1]))  # norm of the latest disturbance estimate
    if du == 1: us_bpc.append(control.item())
    # decomposition of the control into (feedback term, disturbance-action term, bias term = 0)
    u_decomp_bpc.append((-controller.K @ controller.state,
                    jnp.tensordot(controller.M, controller.noise_history, axes=([0, 2], [0, 1])),
                    jnp.zeros((du, 1))))
    
# --------------- BPC NOLIFT ------------------------------------------------------
# Our Controller with the identity lift: the "state" is the window of the last hh costs,
# and the dynamics are identified online by SysID (the true A, B are never shown to it).

lifter = NoLift(hh, du)
sysid = SysID(state_dim=lifter.state_dim, **sysid_args)
controller = Controller(lifter=lifter, sysid=sysid, **controller_args)
sys = deepcopy(sys_init)
x = x_init.clone()

xs_bpc_nolift = []
costs_bpc_nolift = []
ws_bpc_nolift = []
us_bpc_nolift = []
grads_bpc_nolift = []
u_decomp_bpc_nolift = []
print('BPC NOLIFT')
for t in tqdm.trange(T):
    # cost of the CURRENT state, computed before stepping; on the first iteration `control`
    # is left over from the previous experiment, which is harmless because cost_fn ignores it
    cost = cost_fn(x, control).item()
    control = controller(cost)  # the controller only ever sees the scalar cost
    x = sys.step(x, control) + disturbance(t, ds) 
    
    xs_bpc_nolift.append(x.numpy())
    costs_bpc_nolift.append(cost)
    ws_bpc_nolift.append(jnp.linalg.norm(controller.disturbance_history[-1]))
    if du == 1: us_bpc_nolift.append(control.item())
    grads_bpc_nolift.append(jnp.linalg.norm(controller.grads[0][1]))  # norm of the oldest queued M0 gradient
    # decomposition of the control into (feedback term, disturbance-action term, bias M0);
    # note this calls lifter.map a second time for the current histories
    u_decomp_bpc_nolift.append((-controller.K @ controller.lifter.map(controller.cost_history, controller.control_history),
                               jnp.tensordot(controller.M, controller.disturbance_history[-controller.h:], axes=([0, 2], [0, 1])),
                               controller.M0))
    
# --------------- BPC LIFT ------------------------------------------------------
# Our Controller with a fixed random-network lift of the cost window to lift_dim dimensions;
# the dynamics of the lifted state are identified online by SysID.

lifter = RandomLift(hh, du, lift_dim, depth=2)
sysid = SysID(state_dim=lifter.state_dim, **sysid_args)
controller = Controller(lifter=lifter, sysid=sysid, **controller_args)
sys = deepcopy(sys_init)
x = x_init.clone()

xs_bpc_lift = []
costs_bpc_lift = []
ws_bpc_lift = []
us_bpc_lift = []
grads_bpc_lift = []
u_decomp_bpc_lift = []
print('BPC LIFT')
for t in tqdm.trange(T):
    # cost of the CURRENT state, computed before stepping; on the first iteration `control`
    # is left over from the previous experiment, which is harmless because cost_fn ignores it
    cost = cost_fn(x, control).item()
    control = controller(cost)  # the controller only ever sees the scalar cost
    x = sys.step(x, control) + disturbance(t, ds)
    
    xs_bpc_lift.append(x.numpy())
    costs_bpc_lift.append(cost)
    ws_bpc_lift.append(jnp.linalg.norm(controller.disturbance_history[-1]))
    if du == 1: us_bpc_lift.append(control.item())
    grads_bpc_lift.append(jnp.linalg.norm(controller.grads[0][1]))  # norm of the oldest queued M0 gradient
    # decomposition of the control into (feedback term, disturbance-action term, bias M0);
    # note this calls lifter.map a second time for the current histories
    u_decomp_bpc_lift.append((-controller.K @ controller.lifter.map(controller.cost_history, controller.control_history),
                           jnp.tensordot(controller.M, controller.disturbance_history[-controller.h:], axes=([0, 2], [0, 1])),
                           controller.M0))

    
import matplotlib.pyplot as plt


def _plot_rollout(ax, num_steps, label, xs, costs, us, ws=None, grads=None, u_decomp=None, decomp_ax=None):
    """Add one controller's logged series to the shared comparison figure.

    Args:
        ax: the 5x2 grid of axes created by `plt.subplots` below.
        num_steps: every series is plotted against `range(num_steps)`.
        label: legend label; its upper-cased form names the decomposition panel.
        xs: series for the 'position' panel (`ax[0, 0]`).
        costs: series for the 'cost' panel (`ax[0, 1]`).
        us: series for the 'controls' panel (`ax[4, 0]`).
        ws: series for the 'disturbances' panel (`ax[1, 0]`); None skips it.
        grads: series for the 'M0 grads' panel (`ax[1, 1]`); None skips it.
        u_decomp: array whose columns 0, 1, 2 are drawn as 'K @ state',
            'M . w' and 'M0'; None skips the decomposition panel.
        decomp_ax: axis receiving the decomposition; required with `u_decomp`.
    """
    steps = range(num_steps)
    ax[0, 0].plot(steps, xs, label=label)
    ax[0, 1].plot(steps, costs, label=label)
    if ws is not None:
        ax[1, 0].plot(steps, ws, label=label)
    if grads is not None:
        ax[1, 1].plot(steps, grads, label=label)
    if u_decomp is not None:
        # Raw strings avoid the invalid '\c' escape; the `$...$` delimiters are
        # what make matplotlib render the LaTeX instead of printing it verbatim.
        for column, term in enumerate(('K @ state', r'$M \cdot w$', 'M0')):
            decomp_ax.plot(steps, u_decomp[:, column], label=term)
        decomp_ax.set_title(r'K @ state, $M \cdot w$, and $M_0$ for ' + label.upper())
        decomp_ax.legend()
    ax[4, 0].plot(steps, us, label=label)


fig, ax = plt.subplots(5, 2, figsize=(16, 36))

# Turn the per-step logs into arrays with singleton dimensions removed.
# The names are rebound at module level, exactly as before.
xs_lqr = np.array(xs_lqr).squeeze()

xs_gpc = np.array(xs_gpc).squeeze()
ws_gpc = np.array(ws_gpc).squeeze()
u_decomp_gpc = np.array(u_decomp_gpc).squeeze()

xs_bpc = np.array(xs_bpc).squeeze()
ws_bpc = np.array(ws_bpc).squeeze()
u_decomp_bpc = np.array(u_decomp_bpc).squeeze()

xs_bpc_nolift = np.array(xs_bpc_nolift).squeeze()
ws_bpc_nolift = np.array(ws_bpc_nolift).squeeze()
u_decomp_bpc_nolift = np.array(u_decomp_bpc_nolift).squeeze()

xs_bpc_lift = np.array(xs_bpc_lift).squeeze()
ws_bpc_lift = np.array(ws_bpc_lift).squeeze()
u_decomp_bpc_lift = np.array(u_decomp_bpc_lift).squeeze()

# The call order fixes the colour each controller gets on the shared panels.
# lqr has no disturbance/decomposition panels; grads are not plotted for gpc/bpc
# (those plot calls were commented out in the original).
_plot_rollout(ax, T, 'lqr', xs_lqr, costs_lqr, us_lqr)
_plot_rollout(ax, T, 'gpc', xs_gpc, costs_gpc, us_gpc,
              ws=ws_gpc, u_decomp=u_decomp_gpc, decomp_ax=ax[2, 0])
_plot_rollout(ax, T, 'bpc', xs_bpc, costs_bpc, us_bpc,
              ws=ws_bpc, u_decomp=u_decomp_bpc, decomp_ax=ax[2, 1])
_plot_rollout(ax, T, 'bpc nolift', xs_bpc_nolift, costs_bpc_nolift, us_bpc_nolift,
              ws=ws_bpc_nolift, grads=grads_bpc_nolift,
              u_decomp=u_decomp_bpc_nolift, decomp_ax=ax[3, 0])
_plot_rollout(ax, T, 'bpc lift', xs_bpc_lift, costs_bpc_lift, us_bpc_lift,
              ws=ws_bpc_lift, grads=grads_bpc_lift,
              u_decomp=u_decomp_bpc_lift, decomp_ax=ax[3, 1])

# Zero reference line labelled 'opt', plus a marker at the initial state.
ax[0, 0].plot(range(T), [0.] * T, label='opt')
ax[0, 0].set_title('position')
ax[0, 0].legend()
ax[0, 0].scatter([0], [x_init.item()], marker=(5, 1))
# ax[0, 0].set_ylim(-2, 2)

ax[0, 1].set_title('cost')
ax[0, 1].legend()
# ax[0, 1].set_ylim(-2, 2)

ax[1, 0].set_title('disturbances')
ax[1, 0].legend()
# ax[1, 0].set_ylim(0, 2)

ax[1, 1].set_title('M0 grads')
ax[1, 1].legend()

# ax[3, 0].set_ylim(-2, 2)
# ax[3, 1].set_ylim(-2, 2)

ax[4, 0].set_title('controls')
ax[4, 0].legend()
# ax[4, 0].set_ylim(-2, 2)

plt.show()

# COCO BBOB

In [None]:
from dynamical_systems import COCO

# --- optimiser settings ----------------------------------------------------------
# NOTE(review): LIFT_LR, the *_RESCALER factories and USE_K are not referenced
# again in this cell; presumably they are read as globals by code defined
# earlier in the file -- confirm.
LIFT_LR = 0

# Zero-argument factories: every call re-invokes the rescaler constructor.
M_UPDATE_RESCALER = lambda: ADAM(0.0, betas=(0.9, 0.999))
M0_UPDATE_RESCALER = lambda: ADAM(0.0, betas=(0.9, 0.999))
K_UPDATE_RESCALER = lambda: ADAM(0.0, betas=(0.9, 0.999))

W_RESCALER = lambda: IDENTITY()
# alternatives:
# W_RESCALER = lambda : EMA_RESCALE(beta=0.99)
# W_RESCALER = lambda : ADAM(betas=(0.9, 0.9))

# --- horizon ---------------------------------------------------------------------
T = 1000
T0 = 500
USE_K = True

# --- dimensions ------------------------------------------------------------------
du = 1  # control dim
ds = 1
h = 5  # controller memory length (# of w's to use on inference)
hh = 40  # history length of the cost/control histories
lift_dim = 128  # dimension to lift to

# --- keyword arguments forwarded to SysID(...) and Controller(...) -----------------
sysid_args = dict(method='REGRESSION', scale=0.01, control_dim=du)

controller_args = dict(
    h=h,
    method='FKM',
    initial_scales=(0.0, 0.0, 0.0),  # M, M0, K   (uses M0's scale for REINFORCE)
    T0=T0,
    bounds=(-1, 1),
    initial_u=jnp.zeros(du),
    decay_scales=False,
    use_sigmoid=True,
)

# --- problem selection -----------------------------------------------------------
method = 'NOLIFT'  # 'NOLIFT', 'LIFT'
problem_number = 684  # np.random.randint(2160)
u_index = 0
predict_differences = True

# ------------------------------------------------------------------------------------------------------------------

# Pick the lifter named by `method`, then stack system-id and controller on top of it.
if method == 'NOLIFT':
    lifter = NoLift(hh, du)
else:
    lifter = LinearLift(hh, du, lift_dim, depth=2)
sysid = SysID(state_dim=lifter.state_dim, **sysid_args)
controller = Controller(lifter=lifter, sysid=sysid, **controller_args)
sys = COCO(index=problem_number, u_index=u_index, predict_differences=predict_differences)
cost = sys.interact(0.)  # first cost comes from interacting with 0.

# Per-step logs consumed by the plotting code that follows.
xs, costs, ws, grads, u_decomp = [], [], [], [], []
for t in tqdm.trange(T):
    # Feed the latest cost to the controller, apply its control, observe the new cost.
    control = controller(cost)
    cost = sys.interact(control)

    xs.append(sys.x[sys.u_index])
    costs.append(cost)
    ws.append(jnp.linalg.norm(controller.disturbance_history[-1]))
    grads.append(jnp.linalg.norm(controller.grads[0][1]))

    # Terms later plotted as 'K @ state', 'M \cdot w' and 'M0'.
    feedback_term = -controller.K @ controller.lifter.map(controller.cost_history, controller.control_history)
    disturbance_term = jnp.tensordot(
        controller.M,
        controller.disturbance_history[-controller.h:],
        axes=([0, 2], [0, 1]),
    )
    u_decomp.append((feedback_term, disturbance_term, controller.M0))
    

import matplotlib.pyplot as plt

# Summary figure for the COCO run: 3x2 panels fed by the logs collected above.
fig, ax = plt.subplots(3, 2, figsize=(16, 10))
u_decomp = np.array(u_decomp)

# Rows 0-1: logged position, cost, disturbance norm and gradient norm.
ax[0, 0].plot(range(T), xs, label=method)
ax[0, 1].plot(range(T), costs, label=method)
ax[1, 0].plot(range(T), ws, label=method)
ax[1, 1].plot(range(T), grads, label=method)

# Horizontal reference at the value stored under stats['optimal_control'].
ax[0, 0].plot(range(T), [sys.stats['optimal_control']['value']] * T, label='opt')
ax[0, 0].set_title('position')
ax[0, 0].legend()
ax[0, 1].set_title('cost')
ax[0, 1].legend()
ax[1, 0].set_title('disturbances')
ax[1, 0].legend()
ax[1, 1].set_title('M0 grads')
ax[1, 1].legend()

# Row 2, left: the three logged control terms.
# Raw strings avoid the invalid '\c' escape sequence; the `$...$` delimiters
# make matplotlib render the LaTeX instead of printing it verbatim.
ax[2, 0].plot(range(T), u_decomp[:, 0], label='K @ state')
ax[2, 0].plot(range(T), u_decomp[:, 1], label=r'$M \cdot w$')
ax[2, 0].plot(range(T), u_decomp[:, 2], label='M0')
ax[2, 0].set_title(r'K @ state, $M \cdot w$, and $M_0$')
ax[2, 0].legend()

# Row 2, right: stats['gt_values'] plotted against stats['gt_controls'].
us, fs = sys.stats['gt_controls']['value'], sys.stats['gt_values']['value']
ax[2, 1].plot(us, fs)
ax[2, 1].set_title('objective')

plt.show()