In [None]:
from collections import deque
from typing import Tuple

import jax.numpy as jnp
import numpy as np
import torch

from extravaganza.lifters import Lifter
from extravaganza.models import MLP
from extravaganza.sysid import SysID
from extravaganza.utils import set_seed, append, sample, jkey, rescale, d_rescale, inv_rescale, opnorm, exponential_linspace_int

class LearnedLift(Lifter, SysID):
    """
    Jointly learns an invertible lift of (cost history, control history) into a state, and linear
    dynamics ``x_{t+1} ~= A x_t + B u_t`` on those lifted states.

    Transitions are collected with `update()`. The first call to `sysid()` trains the lift model and
    (A, B) on the whole buffer for `num_epochs` epochs and then returns the learned matrices.
    Because the lift is invertible, `get_cost()` reads a cost back out of a lifted state by inverting it.
    """
    def __init__(self,
                 hh: int,
                 control_dim: int,
                 state_dim: int,
                 depth: int,
                 scale: float = 0.1,
                 lift_lr: float = 0.001,
                 sysid_lr: float = 0.001,
                 cost_lr: float = 0.001,
                 buffer_maxlen: int = int(1e9),
                 batch_size: int = 64,
                 num_epochs: int = 20,  # number of epochs over the buffer to use when querying `sysid()` or `dynamics()` for first time
                 seed: int = None):
        """
        Args:
            hh: length of the cost / control histories that get lifted.
            control_dim: dimension of a single control.
            state_dim: dimension of the lifted state; must equal ``hh + control_dim * hh`` (the lift is invertible).
            depth: currently unused (kept for interface compatibility).
            scale: multiplier applied to the random exploratory controls returned by `perturb_control()`.
            lift_lr: Adam learning rate of the lift model.
            sysid_lr: Adam learning rate of (A, B).
            cost_lr: currently unused (kept for interface compatibility).
            buffer_maxlen: maximum number of transitions kept in the buffer.
            batch_size: minibatch size used by `train()`.
            num_epochs: number of epochs over the buffer used by `train()`.
            seed: seed passed to `set_seed` and to the parent constructors.
        """
        set_seed(seed)
        super().__init__(hh, control_dim, state_dim, seed)

        self.hh = hh
        self.control_dim = control_dim
        self.state_dim = state_dim
        self.scale = scale

        self.buffer = deque(maxlen=buffer_maxlen)
        self.num_epochs = num_epochs
        self.batch_size = batch_size

        # invertible model from the flattened (cost history, control history) to lifted states which hopefully
        # respond linearly to the controls. invertibility forces the lifted dimension to equal the input dimension
        flat_dim = hh + control_dim * hh
        assert flat_dim == self.state_dim, 'invertible lift needs state_dim == hh + control_dim * hh'
        from INN.CouplingModels.RealNVP.linear import NonlinearRealNVP
        self.lift_model = NonlinearRealNVP(flat_dim, k=50)
        print('using invertible lifting model!')

        self.lift_opt = torch.optim.Adam(self.lift_model.parameters(), lr=lift_lr, weight_decay=0.0001)

        # linear dynamics of the lifted states; A starts just inside the unit circle for stability purposes
        self.A = torch.nn.Parameter(0.99 * torch.eye(self.state_dim, dtype=torch.float32))
        self.B = torch.nn.Parameter(0.5 * torch.randn((self.state_dim, self.control_dim), dtype=torch.float32))
        self.sysid_opt = torch.optim.Adam([self.A, self.B], lr=sysid_lr, weight_decay=0.0001)

        self.t = 1
        self.trained = False

    def get_cost(self, state, control):
        """
        Reads the cost back out of lifted state(s) by inverting the lift model.

        `state` may be a single state (1-D) or a batch (2-D); one value per row is returned.
        Column ``hh - 1`` is the last entry of the cost-history part of the input layout built in `forward()`.
        `control` is unused.
        """
        if len(state.shape) == 1: state = state.unsqueeze(0)
        histories = self.lift_model.inverse(state)
        cost = histories[:, self.hh - 1]
        return cost

    def forward(self,
                cost_history: torch.Tensor,
                control_history: torch.Tensor) -> torch.Tensor:  # so that we don't need to return a jnp.ndarray
        """
        Lifts (batched) histories to states.

        The histories are flattened to rows of ``[cost_history (hh), control_history (hh * control_dim)]``
        and passed through the lift model. Gradients are tracked.
        """
        inp = torch.cat((cost_history.reshape(-1, self.hh), control_history.reshape(-1, self.control_dim * self.hh)), dim=-1)
        state = self.lift_model.forward(inp)[0]
        return state

    def map(self,
            cost_history: jnp.ndarray,
            control_history: jnp.ndarray) -> jnp.ndarray:
        """
        Maps a single pair of histories, shaped ``(hh,)`` and ``(hh, control_dim)``, to a lifted state.

        The computation runs in PyTorch without gradient tracking, and the result is converted back to jax.
        """
        assert cost_history.shape == (self.hh,)
        assert control_history.shape == (self.hh, self.control_dim)

        # convert to pytorch tensors and back rq  # TODO remove this eventually, making everything in jax
        with torch.no_grad():
            cost_history, control_history = map(lambda j_arr: torch.from_numpy(np.array(j_arr)).unsqueeze(0), [cost_history, control_history])
            state = jnp.array(self.forward(cost_history, control_history).squeeze(0).data.numpy())

        return state

    def perturb_control(self,
                        state: jnp.ndarray,
                        control: jnp.ndarray = None):
        """Returns a random exploratory control ``scale * eps``. `control` is ignored."""
        assert state.shape == (self.state_dim,)
        eps = sample(jkey(), (self.control_dim,))  # random direction
        control = self.scale * eps
        self.t += 1
        return control

    def train(self):
        """
        Fits the lift model and (A, B) on every buffered transition.

        Afterwards it reports, without updating anything, how well the inverted lift of the previous
        histories matches the next cost.

        Returns:
            The list of per-minibatch training losses.

        Raises:
            RuntimeError: if the buffer is empty (no transitions were added through `update()`).
        """
        from torch.utils.data import DataLoader, TensorDataset

        if len(self.buffer) == 0:
            raise RuntimeError('LearnedLift.train() needs at least one buffered transition; call `update()` first')

        # Each buffered transition is ((prev cost hist, prev control hist), (cost hist, control hist)).
        # The control applied during the transition is taken to be the newest entry of the new control history.
        columns = [[] for _ in range(5)]  # prev_cost_history, prev_control_history, cost_history, control_history, controls
        for prev_histories, histories in self.buffer:
            for column, value in zip(columns, (*prev_histories, *histories, histories[1][-1])):
                column.append(torch.from_numpy(np.array(value)))
        dataset = TensorDataset(*[torch.stack(column, dim=0) for column in columns])
        # with fewer transitions than a batch, use one partial batch instead of silently training on nothing
        dl = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=len(dataset) >= self.batch_size)

        # regularization weights. NOTE: the stability term calls a `dare_gain` helper that is not defined in this
        # file, so LAMBDA_STABILITY must stay 0 unless that helper is made available
        LAMBDA_STATE_NORM, LAMBDA_STABILITY, LAMBDA_B_NORM = 1e-5, 0, 0

        losses = []
        epoch_means = []
        print('training!')
        print_every = 25
        for t in range(self.num_epochs):
            epoch_start = len(losses)
            for prev_cost_history, prev_control_history, cost_history, control_history, controls in dl:

                # residual of the linear dynamics in lifted space: x_{t+1} - (A x_t + B u_t), row-wise over the batch
                prev_state = self.forward(prev_cost_history, prev_control_history)
                state = self.forward(cost_history, control_history)
                pred = prev_state @ self.A.T + controls @ self.B.T
                diff = state - pred

                self.lift_opt.zero_grad()
                self.sysid_opt.zero_grad()

                # penalizing 1 / ||x|| discourages the lift from collapsing states towards 0
                state_norm = 1 / (torch.norm(state) + 1e-8)
                stability = opnorm(self.A - self.B @ dare_gain(self.A, self.B, torch.eye(self.state_dim), torch.eye(self.control_dim))) if LAMBDA_STABILITY > 0 else 0.
                B_norm = 1 / (torch.norm(self.B) + 1e-8)
                loss = torch.mean(diff ** 2) + LAMBDA_STATE_NORM * state_norm + LAMBDA_STABILITY * stability + LAMBDA_B_NORM * B_norm
                loss.backward()
                self.lift_opt.step()
                self.sysid_opt.step()

                losses.append(loss.item())

            epoch_means.append(np.mean(losses[epoch_start:]))
            if t % print_every == 0 or t == self.num_epochs - 1:
                print('mean loss for past {} epochs was {}'.format(print_every, np.mean(epoch_means[-print_every:])))

        # evaluation only (no parameter updates), so skip building autograd graphs
        cost_epoch_means = []
        print_every = self.num_epochs
        with torch.no_grad():
            for t in range(self.num_epochs):
                epoch_cost_losses = []
                for prev_cost_history, prev_control_history, cost_history, control_history, controls in dl:
                    state = self.forward(prev_cost_history, prev_control_history).reshape(-1, self.state_dim)
                    fhat = self.get_cost(state, None)
                    f = cost_history[:, -1]
                    loss = torch.nn.functional.mse_loss(fhat.squeeze(), f.squeeze())
                    epoch_cost_losses.append(loss.item())
                cost_epoch_means.append(np.mean(epoch_cost_losses))
                if t % print_every == 0 or t == self.num_epochs - 1:
                    print('mean cost loss for past {} epochs was {}'.format(print_every, np.mean(cost_epoch_means[-print_every:])))

        self.trained = True
        return losses

    def sysid(self):
        """Returns the learned (A, B) as jax arrays, training on the buffer the first time it is called."""
        if not self.trained:
            self.losses = self.train()

        return jnp.array(self.A.data.numpy()), jnp.array(self.B.data.numpy())

    def update(self,
               prev_histories: Tuple[jnp.ndarray, jnp.ndarray],
               histories: Tuple[jnp.ndarray, jnp.ndarray]) -> float:
        """Buffers one transition between consecutive (cost history, control history) pairs; returns 0 (no loss)."""
        self.buffer.append((prev_histories, histories))  # add the transition
        return 0.



In [1]:
from extravaganza.controllers import Controller
from extravaganza.stats import Stats

from typing import Tuple

class LiftedBPC(Controller):
    """
    Controller acting on lifted states.

    A `Lifter` maps the recent (cost, control) histories to a state, and a `SysID` supplies linear dynamics
    for that state. While ``t < T0`` the controller only plays exploratory controls and feeds transitions to
    the lifter. Afterwards it plays
        u = rescale_u(inv_rescale_u(-K x) + M0 + sum_i M[i] w[i])
    over the last `h` estimated disturbances w. The parameters (M, M0, K) are updated from gradient
    estimates chosen by `method` ('FKM', 'REINFORCE' or 'ROLLOUT').
    """
    def __init__(self,
                 h: int,
                 initial_u: jnp.ndarray,
                 rescalers,
                 initial_scales: Tuple[float, float, float],
                 T0: int,
                 bounds = None,
                 method = 'REINFORCE',
                 lifter: Lifter = None,
                 sysid: SysID = None,
                 K: jnp.ndarray = None,
                 step_every: int = 1,
                 use_sigmoid = True,
                 decay_scales = False,
                 use_K_from_sysid: bool = False,
                 seed: int = None,
                 stats: Stats = None):
        """
        Args:
            h: number of past disturbances the controller acts on.
            initial_u: initial control, of shape ``(control_dim,)``.
            rescalers: three zero-arg factories; the objects they build turn the (M, M0, K) gradients into updates.
            initial_scales: non-negative perturbation scales for (M, M0, K).
            T0: number of exploration / lifter-training steps.
            bounds: optional ``(2, control_dim)`` lower/upper control bounds.
            method: gradient estimator, one of 'FKM', 'REINFORCE', 'ROLLOUT' ('ROLLOUT' needs a `LearnedLift`).
            lifter, sysid: lifting and system-identification objects; one object may serve as both.
            K: optional initial ``(control_dim, state_dim)`` feedback gain.
            step_every: take a parameter step every `step_every` steps (must be < h).
            use_sigmoid: forwarded to the rescaling helpers when `bounds` is given.
            decay_scales: if True, perturbation scales decay like ``t ** -0.25``.
            use_K_from_sysid: if True, copy ``sysid.get_lqr()`` into K at step T0.
            seed: seed passed to `set_seed`.
            stats: `Stats` object to log into (a new one is made if None).
        """
        set_seed(seed)  # for reproducibility

        # check things make sense
        if lifter is None and isinstance(sysid, Lifter): lifter = sysid
        if sysid is None and isinstance(lifter, SysID): sysid = lifter
        assert lifter.state_dim == sysid.state_dim
        assert lifter.control_dim == sysid.control_dim and initial_u.shape[0] == lifter.control_dim
        self.control_dim = lifter.control_dim
        self.state_dim = lifter.state_dim
        assert method in ['FKM', 'REINFORCE', 'ROLLOUT']
        assert all(map(lambda i: i >= 0, initial_scales))
        if bounds is not None:
            bounds = jnp.array(bounds).reshape(2, -1)
            assert len(bounds[0]) == len(bounds[1]) and len(bounds[0]) == self.control_dim, 'improper bounds'
            assert all(map(lambda i: bounds[0, i] < bounds[1, i], range(self.control_dim))), 'improper bounds'
        if K is not None:
            assert K.shape == (self.control_dim, self.state_dim)
        assert step_every < h, 'need to update at least every `h` steps'

        # hyperparams
        self.h = h
        self.hh = lifter.hh
        self.lifter = lifter
        self.sysid = sysid
        self.T0 = T0
        self.method = method
        self.bounds = bounds
        self.decay_scales = decay_scales
        self.use_K_from_sysid = use_K_from_sysid
        self.initial_control = initial_u
        self.step_every = step_every

        # for rescaling u (identity when there are no bounds)
        self.rescale_u = lambda u: rescale(u, self.bounds, use_sigmoid=use_sigmoid) if self.bounds is not None else u
        self.inv_rescale_u = lambda ru: inv_rescale(ru, self.bounds, use_sigmoid=use_sigmoid) if self.bounds is not None else ru
        self.d_rescale_u = lambda u: d_rescale(u, self.bounds, use_sigmoid=use_sigmoid) if self.bounds is not None else jnp.ones_like(u)

        # controller params
        self.M = jnp.zeros((self.h, self.control_dim, self.state_dim))
        self.M0 = self.inv_rescale_u(initial_u)
        self.K = K if K is not None else jnp.zeros((self.control_dim, self.state_dim))
        self.M_scale, self.M0_scale, self.K_scale = initial_scales

        # histories are rightmost recent (increasing in time)
        self.prev_cost = 0.
        self.prev_control = jnp.zeros(self.control_dim)
        self.prev_state = jnp.zeros(self.state_dim)
        self.disturbance_history = jnp.zeros((2 * self.h, self.state_dim))  # past 2h disturbances, for controller
        self.cost_history = jnp.zeros(self.hh)  # for sysid/lifting
        self.control_history = jnp.zeros((self.hh, self.control_dim))  # for sysid/lifting
        self.t = 1

        # grad estimation stuff -- NOTE maybe `self.eps` should be divided by its variance?
        # every `grad_fn(f, d)` takes the cost difference `f` and the rescaling derivative `d` of the last control
        if self.method == 'FKM':
            self.eps_M = jnp.zeros((self.h, self.h, self.control_dim, self.state_dim))  # noise history of M perturbations
            self.eps_M0 = jnp.zeros((self.h, self.control_dim))  # noise history of M0 perturbations
            self.eps_K = jnp.zeros((self.h, self.control_dim, self.state_dim))  # noise history of K perturbations

            def grad_fn(f, d=1.):
                d = jnp.asarray(d)  # so the scalar default also supports `.reshape`
                grad_M = d.reshape(1, -1, 1) * f * jnp.sum(self.eps_M, axis=0)
                grad_M0 = d.reshape(-1) * f * jnp.sum(self.eps_M0, axis=0)
                grad_K = f * jnp.sum(self.eps_K, axis=0)
                return grad_M, grad_M0, grad_K

        elif self.method == 'REINFORCE':
            self.eps = jnp.zeros((self.h + 1, self.control_dim))  # noise history of u perturbations

            def grad_fn(f, d=1.):
                d = jnp.asarray(d)  # so the scalar default also supports `.reshape`
                val = sum([jnp.transpose(jnp.einsum('ij,k->ijk', self.disturbance_history[i: self.h + i], self.eps[i]), axes=(0, 2, 1)) for i in range(self.h)])
                grad_M = d.reshape(1, -1, 1) * f * val
                grad_M0 = d.reshape(-1) * f * self.eps[-1]
                val = self.eps[-1].reshape(self.control_dim, 1) @ self.prev_state.reshape(1, self.state_dim)
                grad_K = f * val
                return grad_M, grad_M0, grad_K

        elif self.method == 'ROLLOUT':
            if not isinstance(self.lifter, LearnedLift): raise Exception('{} can only be used with learned lifters'.format(self.method))

            def grad_fn(f, d=1.):
                # differentiates the cost of a short model rollout, started from the state recorded at T0, w.r.t.
                # (M, M0, K). the observed cost difference `f` is not used.
                # NOTE(review): `self.initial_state` is only set once `self.t == self.T0`, i.e. requires T0 >= 2
                d = jnp.asarray(d)
                state = torch.from_numpy(np.array(self.initial_state))
                M, M0, K = map(lambda arr: torch.tensor(np.array(arr), requires_grad=True), [self.M, self.M0, self.K])
                act = lambda state, ws: self.rescale_u(self.inv_rescale_u(-K @ state) + M0 + torch.tensordot(M, ws, dims=([0, 2], [0, 1]))).reshape(self.control_dim)
                W = torch.from_numpy(np.array(self.disturbance_history))
                # W holds 2h disturbances and step t reads W[t + h], so at most h rollout steps fit
                rollout_len = min(5, self.h)
                for t in range(rollout_len):
                    control = act(state, W[t: t + self.h])
                    state = (self.lifter.A @ state + self.lifter.B @ control + W[t + self.h]).reshape(self.state_dim)
                cost = self.lifter.get_cost(state, None).squeeze()
                cost.backward()
                grad_M = d.reshape(1, -1, 1) * M.grad.data.detach().cpu().numpy()
                grad_M0 = d.reshape(-1) * M0.grad.data.detach().cpu().numpy()
                grad_K = K.grad.data.detach().cpu().numpy()
                return grad_M, grad_M0, grad_K

        self.grads = deque([(jnp.zeros_like(self.M), jnp.zeros_like(self.M0), jnp.zeros_like(self.K))], maxlen=self.h)
        self.grad_fn = grad_fn

        self.M_update_rescaler = rescalers[0]()
        self.M0_update_rescaler = rescalers[1]()
        self.K_update_rescaler = rescalers[2]()

        # stats
        if stats is None:
            print('WARNING: no `Stats` object provided, so the controller will make a new one.')
            stats = Stats()
        self.stats = stats
        self.stats.register('||A-BK||_op', float, plottable=True)
        self.stats.register('||A||_op', float, plottable=True)
        self.stats.register('||B||_F', float, plottable=True)
        self.stats.register('disturbances', float, plottable=True)
        self.stats.register('lifter losses', float, plottable=True)
        if self.control_dim == 1:
            self.stats.register('K @ state', float, plottable=True)
            self.stats.register(r'M \cdot w', float, plottable=True)  # raw string: `\c` is not a valid escape
            self.stats.register('M0', float, plottable=True)

# ------------------------------------------------------------------------------------------------------------

    def __call__(self, cost: float) -> jnp.ndarray:
        """
        Returns the next control, given the cost observed after the previous control.

        While ``t < T0`` it returns exploratory controls from `sysid.perturb_control` (clipped to `bounds`)
        and feeds transitions to the lifter. Afterwards it estimates the disturbance, updates (M, M0, K),
        and returns the (perturbed) controller output.
        """
        # 1. observe next state and update histories
        prev_histories = (self.cost_history, self.control_history)
        self.cost_history = append(self.cost_history, cost)
        self.control_history = append(self.control_history, self.prev_control)
        self.t += 1
        histories = (self.cost_history, self.control_history)
        state = self.lifter.map(*histories)  # xhat_{t+1}
        if self.t < self.T0:
            lifter_loss = self.lifter.update(prev_histories, histories)  # update lifter, if needed
        else:
            if self.t == self.T0:
                print('WARNING: note that we are only updating lifter during sysid phase')
                self.initial_state = state.copy()
            lifter_loss = 0.

        # 2. explore for sysid, and then get stabilizing controller
        if self.t < self.T0:
            control = self.sysid.perturb_control(state)
            if self.bounds is not None: control = control.clip(*self.bounds)
            self.prev_cost = cost  # keep the baseline current so the first post-exploration cost difference is meaningful
            self.prev_state = state
            self.prev_control = control
            return control
        elif self.use_K_from_sysid and self.t == self.T0:  # get the K from sysid once exploration ends
            print('copying the K from {}'.format(self.sysid))
            self.K = self.sysid.get_lqr()

        # 3. compute disturbance
        pred_state = self.sysid.dynamics(self.prev_state, self.prev_control)  # A @ xhat_t + B @ u_t
        disturbance = state - pred_state  # xhat_{t+1} - (A @ xhat_t + B @ u_t)
        self.disturbance_history = append(self.disturbance_history, disturbance)

        # compute change in cost, as well as the new scale
        cost_diff = cost - self.prev_cost
        M_scale, M0_scale, K_scale = map(lambda s: s / (self.t ** 0.25) if self.decay_scales else s, [self.M_scale, self.M0_scale, self.K_scale])

        # 4. update controller
        d = self.d_rescale_u(self.prev_control)
        self.grads.append(self.grad_fn(cost_diff, d))
        if len(self.grads) == self.grads.maxlen and self.t % self.step_every == 0:
            grads = list(self.grads)[:self.step_every] if self.method != 'ROLLOUT' else list(self.grads)[-1:]  # use updates starting from h steps ago
            self.M = self.M - self.M_update_rescaler.step(sum([g[0] for g in grads]), iterate=self.M)
            self.M0 = self.M0 - self.M0_update_rescaler.step(sum([g[1] for g in grads]), iterate=self.M0)
            self.K = self.K - self.K_update_rescaler.step(sum([g[2] for g in grads]), iterate=self.K)

        # 5. compute newest perturbed control
        M_tilde, M0_tilde, K_tilde = self.M, self.M0, self.K

        if self.method == 'FKM':  # perturb em all
            eps_M = sample(jkey(), (self.h, self.control_dim, self.state_dim))
            eps_M0 = sample(jkey(), (self.control_dim,))
            eps_K = sample(jkey(), (self.control_dim, self.state_dim))
            M_tilde = M_tilde + M_scale * eps_M
            M0_tilde = M0_tilde + M0_scale * eps_M0
            K_tilde = K_tilde + K_scale * eps_K
            if M_scale > 0: self.eps_M = append(self.eps_M, eps_M)
            if M0_scale > 0: self.eps_M0 = append(self.eps_M0, eps_M0)
            if K_scale > 0: self.eps_K = append(self.eps_K, eps_K)

        elif self.method == 'REINFORCE':  # perturb output only
            eps = sample(jkey(), (self.control_dim,))
            M0_tilde = M0_tilde + M0_scale * eps
            if M0_scale > 0: self.eps = append(self.eps, eps / M0_scale)

        elif self.method == 'ROLLOUT':  # don't perturb at all
            pass

        # TODO this might not be the right thing to do with `K` when rescaling!!!!
        control = self.inv_rescale_u(-K_tilde @ state) + M0_tilde + jnp.tensordot(M_tilde, self.disturbance_history[-self.h:], axes=([0, 2], [0, 1]))
        control = self.rescale_u(control)

        # cache it
        self.prev_cost = cost
        self.prev_state = state
        self.prev_control = control

        # update stats
        A, B = self.sysid.sysid()
        self.stats.update('||A-BK||_op', opnorm(A - B @ self.K), t=self.t)
        self.stats.update('||A||_op', opnorm(A), t=self.t)
        self.stats.update('||B||_F', jnp.linalg.norm(B, 'fro').item(), t=self.t)
        self.stats.update('disturbances', jnp.linalg.norm(disturbance).item(), t=self.t)
        self.stats.update('lifter losses', lifter_loss, t=self.t)
        if self.control_dim == 1:
            self.stats.update('K @ state', (-self.K @ state).item(), t=self.t)
            self.stats.update(r'M \cdot w', (jnp.tensordot(self.M, self.disturbance_history[-self.h:], axes=([0, 2], [0, 1]))).item(), t=self.t)
            self.stats.update('M0', self.M0.item(), t=self.t)

        return control

    def get_control(self, cost: float, state: jnp.ndarray) -> jnp.ndarray:
        """`Controller` interface; `state` is ignored because the state comes from the lifter."""
        return self(cost)

NameError: name 'jnp' is not defined