# Definitions
As per the remark at the end of the writeup (I think it's Remark 10), we can enforce that our function $f: \mathcal{H}_1 \rightarrow \mathcal{H}_2$ is both $L$-subhomogeneous and rotationally symmetric (w.r.t. the norm on $\mathcal{H}_2$) by constructing it according to the following decomposition:
$$f(v) = \varphi(||v||) \cdot \phi(v),$$
where $\varphi: \mathbb{R}_+ \rightarrow \mathbb{R}$ is an $L$-subhomogeneous one-dimensional function and $\phi: \mathcal{H}_1 \rightarrow \mathbb{S}_{\mathcal{H}_2}$ is any function whose image is contained in the unit sphere of $\mathcal{H}_2$. In particular, we can represent $\varphi$ with a ReLU neural network with no biases! Importantly, $\phi$ can be as complex as we want!

In [1]:
import tqdm

import torch
import torch.nn as nn

import numpy as np

from extravaganza.models import MLP
from extravaganza.utils import set_seed

def exponential_linspace_int(start, end, num, divisible_by=1):
    """Return `num` integers spaced geometrically from `start` to `end`.

    Every entry is rounded to the nearest multiple of `divisible_by`.
    """
    # common ratio between consecutive (unrounded) entries
    ratio = np.exp(np.log(end / start) / (num - 1))
    values = []
    for step in range(num):
        n_multiples = np.round(start * ratio**step / divisible_by)
        values.append(int(n_multiples * divisible_by))
    return values

def uhoh(message):
    """Print a loud four-line banner announcing that a sanity check failed."""
    border = '!!!!!!!!!!!!!!!!!!!!!!!!!!!!'
    banner = [
        border,
        '!!!!!!!!! UH OH !!!!!!!!!!!!',
        '!!!!!!!!!! {} !!!!!!!!!!'.format(message),
        border,
    ]
    print('\n'.join(banner))

class NM(nn.Module):
    """Norm-monotonic network built as f(v) = varphi(||v||) * phi(v).

    `normnet` (varphi in the writeup) is a small bias-free MLP applied to the
    input norm; `dirnet` (phi in the writeup) is an MLP whose output is
    rescaled to unit norm, so the norm of the result is carried by `normnet`.

    Args:
        in_dim: dimension of the input vectors.
        out_dim: dimension of the output vectors.
        depth: number of layer widths for `dirnet`, spaced geometrically from
            `in_dim` to `out_dim`.
        seed: forwarded to `set_seed` before any parameter is created.
        activation: activation class handed to both MLPs.
    """

    # Lower bound on the norm we divide by. Without it, a direction net that
    # outputs the zero vector would turn the whole output into NaN/inf.
    _EPS = 1e-12

    def __init__(self, in_dim: int, out_dim: int, depth: int, seed: int = None, activation = nn.ReLU):
        set_seed(seed)
        super().__init__()

        self.normnet = MLP(layer_dims=[1, 10, 10, 1],   # \varphi in the writeup
                           activation=activation,
                           use_bias=False)  # for homogeneity
        layer_dims = exponential_linspace_int(in_dim, out_dim, depth)
        self.dirnet = MLP(layer_dims=layer_dims,    # \phi in the writeup
                          activation=activation)

    def forward(self, x: torch.Tensor):
        """Map a 2-D batch `x` to `varphi(||x||) * phi(x) / ||phi(x)||`.

        Rows for which `dirnet` returns (numerically) zero are mapped to zero
        instead of NaN.
        """
        assert x.ndim == 2, x.shape
        x_norm = torch.norm(x, dim=-1, keepdim=True).type(x.dtype)
        varphi = self.normnet(x_norm)
        phi = self.dirnet(x)
        # project phi onto the unit sphere; the clamp only matters when ||phi|| < _EPS
        phi_norm = torch.norm(phi, dim=-1, keepdim=True).type(x.dtype).clamp_min(self._EPS)
        return varphi * (phi / phi_norm)


# make sure shapes appear ok
# smoke test: a batch of inputs must come out as a batch of OUT_DIM-dimensional outputs
IN_DIM = 5
OUT_DIM = 128
DEPTH = 3
BATCH_SIZE = 17
nm = NM(IN_DIM, OUT_DIM, DEPTH)
x = torch.randn(BATCH_SIZE, IN_DIM)

if nm(x).shape != (BATCH_SIZE, OUT_DIM):
    uhoh('incorrect shape')
else:
    print('seems ok')

seems ok


# Testing for *L*-(sub)homogeneity
The definition we use is as follows:
### Definition
#### (Positive) Homogeneity and Subhomogeneity
We say that $f: \mathcal{H}_1 \rightarrow \mathcal{H}_2$ is (positive) **$L$-homogeneous** or **homogeneous of order $L$** if for all $\gamma > 0$ and all $v \in \mathcal{H}_1$, we know that
    $$||f(\gamma v)||_{\mathcal{H}_2} = \gamma^L ||f(v)||_{\mathcal{H}_2}$$
    If instead all we know is that 
$$||f(\gamma v)||_{\mathcal{H}_2} \leq \gamma^L ||f(v)||_{\mathcal{H}_2},$$
we refer to $f$ as (positive) **$L$-subhomogeneous**. Note that this definition is weaker than the usual definition of positive homogeneity in that it only needs to hold w.r.t. the norm, allowing for arbitrary rotation and still allowing rich expressivity.

We expect our neural network above to have the above property with $L = 1$

In [None]:
# Empirically estimate the homogeneity order L of freshly initialised NM networks:
# for random v and gamma > 0,  L ~= log_gamma(||f(gamma v)|| / ||f(v)||).
num_trials = 10000
L = 1.  # true value for L

# run trials
vals = []
with torch.no_grad():  # evaluation only, so skip building autograd graphs
    for _ in tqdm.trange(num_trials):
        nm = NM(IN_DIM, OUT_DIM, DEPTH)

        # sample a random vector v
        x = torch.randn(1, IN_DIM)

        # sample a random positive scalar \gamma
        gamma = torch.rand(1) * torch.randint(50, size=(1,)) + 1e-8

        lhs = torch.norm(nm(gamma * x), dim=-1)
        norm_fv = torch.norm(nm(x), dim=-1)

        # change of base formula -- log_gamma(a) = ln(a) / ln(gamma)
        a = (lhs / norm_fv).squeeze()
        L_estimate = torch.log(a) / torch.log(gamma)
        vals.append(L_estimate.item())

# confirm
vals = np.array(vals)
# Drop NaN *and* +-inf: a zero norm on either side gives log(0) or a division by
# zero, and a single infinite estimate would otherwise wreck the mean and std.
vals = vals[np.isfinite(vals)]
if len(vals) == 0:
    uhoh('no finite estimates for L')
else:
    mean, std = np.mean(vals), np.std(vals)
    if abs(mean - L) < 1e-4 and std < 1e-2:
        print('seems ok')
    elif std >= 1e-2:
        uhoh('fluctuating estimates for L, std={}'.format(std))
    else:
        uhoh('we were pretty sure L={}'.format(mean))

# Testing for Rotational Symmetry
The definition we follow is as follows:
### Definition
#### Rotational Symmetry
We say that $f$ is rotationally symmetric in the sense that $||f \circ U||_{\mathcal{H}_2} \equiv ||f||_{\mathcal{H}_2}$ everywhere for all unitary transformations $U : \mathcal{H}_1 \rightarrow \mathcal{H}_1$. Note that this definition is weaker than the usual definition of rotational symmetry in that it doesn't require commuting with unitary operators, but instead it basically requires mapping spheres to spheres with arbitrary deformity, allowing rich expressivity.

In [None]:
from scipy.stats import ortho_group

# Check ||f(Ux)|| == ||f(x)|| for random networks, vectors and orthogonal matrices.
num_trials = 10000

# run trials
vals = []
for _ in tqdm.trange(num_trials):
    nm = NM(IN_DIM, OUT_DIM, DEPTH)

    # random input vector
    x = torch.randn(IN_DIM)

    # random orthogonal matrix, cast to the network's float32
    U = torch.tensor(ortho_group.rvs(IN_DIM)).type(torch.float32)

    # output norms for the rotated and the original input (in that order)
    lhs, rhs = (torch.norm(nm(v.unsqueeze(0)), dim=-1) for v in (U @ x, x))

    vals.append(torch.abs(lhs - rhs).item())

# confirm: both the average gap and its spread should be negligible
mean, std = np.mean(vals), np.std(vals)
if std >= 1e-4:
    uhoh('fluctuating stuff, std={}'.format(std))
elif std < 1e-4 and mean < 1e-4:
    print('seems ok')
else:
    uhoh('we were pretty sure LHS-RHS={}'.format(mean))

# Testing Norm Monotonicity

The following proposition is the reason for this notebook:

### Proposition: 
#### *Suppose that $f$ is $L$-subhomogeneous for some $L > 0$ and rotationally symmetric in the way described above. Then, $f$ is norm-monotonic.*

#### Proof
Let $x, y \in \mathcal{H}_1$ be arbitrary. Note that we can certainly find some unitary operator $U$ for which
        $$y = \left(\frac{||y||_{\mathcal{H}_1}}{||x||_{\mathcal{H}_1}}\right)U(x)$$
Let $\gamma :=\frac{||y||_{\mathcal{H}_1}}{||x||_{\mathcal{H}_1}} > 0$. Then, $y = \gamma  U(x)$. Since $f$ is $L$-subhomogeneous, we can readily see that
$$\frac{||f(y)||_{\mathcal{H}_2}}{||f(x)||_{\mathcal{H}_2}} = \frac{||f(\gamma U(x))||_{\mathcal{H}_2}}{||f(x)||_{\mathcal{H}_2}} \leq \frac{\gamma^L ||f(U(x))||_{\mathcal{H}_2}}{||f(x)||_{\mathcal{H}_2}}$$
By the rotational symmetry of $f$ we know that $||f(U(x))||_{\mathcal{H}_2} = ||f(x)||_{\mathcal{H}_2}$, from which we find
    $$\frac{||f(y)||_{\mathcal{H}_2}}{||f(x)||_{\mathcal{H}_2}} \leq \gamma^L = \left(\frac{||y||_{\mathcal{H}_1}}{||x||_{\mathcal{H}_1}}\right)^L$$
This means that if $||f(y)||_{\mathcal{H}_2} > ||f(x)||_{\mathcal{H}_2}$, then $\gamma^L > 1$; since $L > 0$, we immediately see that $\gamma > 1$ and therefore that $||y||_{\mathcal{H}_1} > ||x||_{\mathcal{H}_1}$. This is the condition for strict norm-monotonicity. $\blacksquare$


Since our function $f$ in this notebook satisfies the two above properties ($L$-subhomogeneity and rotational symmetry) w.r.t. the $||\cdot||_{\mathcal{H}_2}$ norm, the proposition should guarantee strict norm monotonicity. This is what we will test for below.

In [6]:
# Check strict norm-monotonicity: the ordering of ||f(x)|| and ||f(y)|| must match
# the ordering of ||x|| and ||y|| for random networks and random input pairs.
num_trials = 10000

# run trials
for _ in tqdm.trange(num_trials):
    nm = NM(IN_DIM, OUT_DIM, DEPTH)

    # two random inputs and their images
    x = torch.randn(1, IN_DIM)
    y = torch.randn(1, IN_DIM)
    fx = nm(x).squeeze(0)
    fy = nm(y).squeeze(0)

    n_x, n_y, n_fx, n_fy = (torch.norm(t).item() for t in (x, y, fx, fy))
    if n_fx < n_fy:
        assert n_x < n_y, (n_x, n_y, n_fx, n_fy)
    elif n_fx > n_fy:
        assert n_x > n_y, (n_x, n_y, n_fx, n_fy)
    elif n_fx > 0 or n_fy > 0:
        # equal, non-zero output norms should come from (almost) equal input norms
        assert abs(n_x - n_y) < 1e-4, (n_x, n_y, n_fx, n_fy)

print('seems ok')

100%|████████████████████████████████████████████████████████████████████████████████████████████████████████| 10000/10000 [00:09<00:00, 1051.21it/s]

seems ok





# But is it a good NN?
Ok, so we have crafted a general and strictly norm-monotonic neural network architecture. But is it expressive enough to learn?

To answer this question we will use it to do some good ol' contrastive learning on MNIST and compare it with an equivalently-sized regular MLP.

In [None]:
import matplotlib.pyplot as plt

from torch.utils.data import DataLoader
import torchvision
import torchvision.transforms as T
from pytorch_metric_learning.losses import SupConLoss

# Data pipeline for the contrastive-learning comparison on MNIST.
BATCH_SIZE = 64
loss_fn = SupConLoss()

# each image becomes a flat vector (28 * 28 entries, matching the models' `in_dim`)
transform = T.Compose([
    T.ToTensor(),
    T.Lambda(lambda t: torch.flatten(t, start_dim=0))
])
# Train on the training split. Previously both datasets were built with
# `train=False`, so the models were fit on the very data used for validation.
train_dataset = torchvision.datasets.MNIST('../../data', train=True, transform=transform)
val_dataset = torchvision.datasets.MNIST('../../data', train=False, transform=transform)
train_dl = DataLoader(train_dataset, batch_size=BATCH_SIZE, shuffle=True, drop_last=True)
# NOTE(review): with batch_size=1 a validation batch contains no positive pairs, so
# the supervised-contrastive validation loss is presumably degenerate -- confirm.
# The batch size is kept at 1 because the t-SNE cell relies on single-sample batches.
val_dl = DataLoader(val_dataset, batch_size=1, shuffle=False)

def train(model, opt, num_epochs: int):
    """Fit `model` with `opt` for `num_epochs` passes over `train_dl`.

    After every epoch the model is also scored on `val_dl` without gradients.
    Uses the notebook-level `train_dl`, `val_dl` and `loss_fn`.

    Returns:
        (train_losses, val_losses): per-epoch means of the batch losses.
    """
    train_losses, val_losses = [], []
    for _ in tqdm.trange(num_epochs):
        # optimisation pass
        batch_losses = []
        for inputs, targets in train_dl:
            opt.zero_grad()
            batch_loss = loss_fn(model(inputs), targets)
            batch_loss.backward()
            opt.step()
            batch_losses.append(batch_loss.item())
        train_losses.append(np.mean(batch_losses))

        # evaluation pass
        with torch.no_grad():
            held_out_losses = [loss_fn(model(inputs), targets).item()
                               for inputs, targets in val_dl]
        val_losses.append(np.mean(held_out_losses))

    return train_losses, val_losses
         
    
# Train the norm-monotonic nets and plain MLPs of matching layer widths, then plot losses.
seed = None
dim = 128
depth = 8
num_epochs = 30
activation = nn.ReLU
set_seed(seed)

# make models
models = {
    'ours ReLU': NM(in_dim=28 * 28, out_dim=dim, depth=depth, activation=nn.ReLU),
    'ours Leaky': NM(in_dim=28 * 28, out_dim=dim, depth=depth, activation=nn.LeakyReLU),
    'default ReLU': MLP(layer_dims=exponential_linspace_int(28 * 28, dim, depth), activation=nn.ReLU),
    'default Leaky': MLP(layer_dims=exponential_linspace_int(28 * 28, dim, depth), activation=nn.LeakyReLU)
}

# run trials: one Adam optimiser per model, loss curves stored by model name
results = {'train_losses': {}, 'val_losses': {}}
for k, model in models.items():
    print(k)
    opt = torch.optim.Adam(model.parameters(), lr=0.004)
    results['train_losses'][k], results['val_losses'][k] = train(model, opt, num_epochs)

# plot: train curves on the left axis, validation curves on the right
fig, ax = plt.subplots(1, 2)
for k in models:
    for _ax, split in zip(ax, ('train_losses', 'val_losses')):
        curve = results[split][k]
        _ax.plot(range(len(curve)), curve, label=k)

for _ax, title in zip(ax, ('train losses', 'val losses')):
    _ax.legend()
    _ax.set_xlabel('epoch')
    _ax.set_ylabel('loss')
    _ax.set_title(title)

In [None]:
import pandas as pd
import seaborn as sns
from sklearn.manifold import TSNE

# Visualise each model's embedding space: embed the validation set, keep a random
# subset and project it to 2-D with t-SNE.
fig, ax = plt.subplots(len(models), 1, figsize=(8, 8 * len(models)))
N = 1000  # number of points drawn per figure

for _ax, (k, model) in zip(ax, models.items()):
    print(k)

    # embed every validation sample (the loader yields one sample per batch)
    embs, labels = [], []
    with torch.no_grad():
        for x, y in val_dl:
            embs.append(model(x).squeeze(0).data.numpy())
            labels.append(y.item())

    # random subset of at most N points
    X = np.stack(embs, axis=0)
    idxs = np.random.permutation(len(X))[:N]
    X = X[idxs]
    labels = [labels[i] for i in idxs]

    tsne = TSNE(n_components=2, learning_rate='auto', init='random').fit_transform(X)

    df = pd.DataFrame({
        'tsne-one': tsne[:, 0],
        'tsne-two': tsne[:, 1],
        'label': labels,
    })
    sns.scatterplot(
        x="tsne-one",
        y="tsne-two",
        hue="label",
        palette=sns.color_palette("hls", 10),  # ten colours for the palette
        data=df,
        legend="full",
        alpha=0.3,
        ax=_ax,
    )
    _ax.set_title('{} Embedding Space for MNIST Val Dataset'.format(k))

plt.show()

# Testing LQR and H_inf Controllers with Norm-Monotonic Lifter
The entire reason we were interested in neural nets with this property is to ensure that minimizing the norm of the lifted state corresponds to minimizing the norm of the inputs! Lifters with this property technically form an LDS with quadratic costs, lending themselves to provable control via LQR (optimal control), $H_{\infty}$ (robust control), and perhaps even GPC!

In [2]:
import logging
from abc import abstractmethod
import inspect
from typing import Tuple
from collections import deque

import numpy as np
import jax.numpy as jnp

from extravaganza.models import MLP
from extravaganza.lifters import Lifter
from extravaganza.sysid import SysID
from extravaganza.utils import exponential_linspace_int, sample, set_seed, jkey, opnorm, dare_gain, get_classname, least_squares

class NMLift(Lifter, SysID):
    """Lifter and system identifier built on the norm-monotonic network `NM`.

    Cost and control histories are flattened and passed through an `NM`
    network to obtain a lifted state. Matrices `A` and `B` are trained jointly
    with the network so that `state ~= A @ prev_state + B @ control` on the
    transitions collected through `update()`.

    Args:
        hh: length of the cost/control histories.
        control_dim: dimension of a single control.
        state_dim: dimension of the lifted state (output dimension of `NM`).
        depth: `depth` argument forwarded to `NM`.
        scale: multiplier applied to the random exploration controls.
        lift_lr: Adam learning rate for the lifting network.
        sysid_lr: Adam learning rate for `A` and `B`.
        buffer_maxlen: maximum number of transitions kept in the buffer.
        batch_size: minibatch size used in `train()`.
        num_epochs: number of epochs over the buffer to use when querying
            `sysid()` for the first time.
        seed: forwarded to `set_seed`, the parent constructor and `NM`.
    """

    def __init__(self,
                 hh: int,
                 control_dim: int,
                 state_dim: int,
                 depth: int,
                 scale: float = 0.1,
                 lift_lr: float = 0.001,
                 sysid_lr: float = 0.001,
                 buffer_maxlen: int = int(1e9),
                 batch_size: int = 64,
                 num_epochs: int = 20,  # number of epochs over the buffer to use when querying `sysid()` or `dynamics()` for first time
                 seed: int = None):

        set_seed(seed)
        super().__init__(hh, control_dim, state_dim, seed)

        self.hh = hh
        self.control_dim = control_dim
        self.state_dim = state_dim
        self.scale = scale

        self.buffer = deque(maxlen=buffer_maxlen)
        self.num_epochs = num_epochs
        self.batch_size = batch_size

        # to compute lifted states which hopefully respond linearly to the controls
        flat_dim = hh + control_dim * hh  # flattened cost history + flattened control history
        self.lift_model = NM(flat_dim, self.state_dim, depth, seed=seed).train().float() # TODO add layernorm to NM
        self.lift_opt = torch.optim.Adam(self.lift_model.parameters(), lr=lift_lr)

        # to estimate linear dynamics of lifted states
        self.A = torch.nn.Parameter(0.99 * torch.eye(self.state_dim, dtype=torch.float32))  # for stability purposes :)
        self.B = torch.nn.Parameter(torch.from_numpy(np.array(sample(jkey(), shape=(self.state_dim, self.control_dim), sampling_method='sphere'))))
        self.sysid_opt = torch.optim.Adam([self.A, self.B], lr=sysid_lr)

        self.t = 1
        self.trained = False

    def forward(self,
                cost_history: torch.Tensor,
                control_history: torch.Tensor) -> torch.Tensor:  # so that we don't need to return a jnp.ndarray
        """Lift batched histories: flatten, concatenate, apply the lifting network."""
        inp = torch.cat((cost_history.reshape(-1, self.hh), control_history.reshape(-1, self.control_dim * self.hh)), dim=-1)
        state = self.lift_model(inp)
        return state

    def map(self,
            cost_history: jnp.ndarray,
            control_history: jnp.ndarray) -> jnp.ndarray:
        """
        maps histories to lifted states. it's in pytorch rn, but that can change

        Expects `cost_history` of shape `(hh,)` and `control_history` of shape
        `(hh, control_dim)`; returns the lifted state as a jax array.
        """
        assert cost_history.shape == (self.hh,)
        assert control_history.shape == (self.hh, self.control_dim)

        # convert to pytorch tensors and back rq  # TODO remove this eventually, making everything in jax
        with torch.no_grad():
            cost_history, control_history = (torch.from_numpy(np.array(j_arr)).unsqueeze(0)
                                             for j_arr in (cost_history, control_history))
            state = jnp.array(self.forward(cost_history, control_history).squeeze(0).detach().numpy())

        return state

    def perturb_control(self,
                        state: jnp.ndarray,
                        control: jnp.ndarray = None):
        """Return a random exploration control scaled by `self.scale`.

        Both arguments are ignored; the step counter `self.t` is advanced.
        """
        eps = sample(jkey(), (self.control_dim,))  # random direction
        control = self.scale * eps
        self.t += 1
        return control

    def train(self):
        """Fit the lifting network together with `A` and `B` on the buffer.

        After training, a least-squares estimate of `(A, B)` is computed and
        printed next to the learned one for comparison only; the learned
        parameters stay in use.

        NOTE(review): if `Lifter`/`SysID` derive from `nn.Module`, this
        zero-argument method shadows `nn.Module.train(mode)` -- confirm.

        Returns:
            List of per-batch training losses.

        Raises:
            ValueError: if the buffer holds fewer than `batch_size` transitions,
                so that no full minibatch can be formed.
        """
        logging.info('({}): training!'.format(get_classname(self)))

        # prepare dataloader
        from torch.utils.data import DataLoader, TensorDataset
        controls = []
        prev_cost_history = []
        prev_control_history = []
        cost_history = []
        control_history = []
        for prev_histories, histories in self.buffer:  # append em all
            lists = [controls, prev_cost_history, prev_control_history, cost_history, control_history]
            vals = [histories[1][-1], *prev_histories, *histories]  # most recent control first
            for l, v in zip(lists, vals): l.append(torch.from_numpy(np.array(v)))
        dataset = TensorDataset(*map(lambda l: torch.stack(l, dim=0),
                                     [prev_cost_history, prev_control_history, cost_history, control_history, controls]))
        # drop_last=True: the `.expand(self.batch_size, ...)` calls below need full batches
        dl = DataLoader(dataset, batch_size=self.batch_size, shuffle=True, drop_last=True)
        if len(dl) == 0:
            # previously this surfaced much later as a ZeroDivisionError in `calc_loss`
            raise ValueError('need at least batch_size={} transitions to train, but the buffer holds {}'.format(
                self.batch_size, len(self.buffer)))

        # learn lifter
        losses = []
        for t in tqdm.trange(self.num_epochs):
            for prev_cost_history, prev_control_history, cost_history, control_history, controls in dl:

                # compute disturbance
                prev_state = self.forward(prev_cost_history, prev_control_history)
                state = self.forward(cost_history, control_history)
                pred = self.A.expand(self.batch_size, self.state_dim, self.state_dim) @ prev_state.unsqueeze(-1) + \
                       self.B.expand(self.batch_size, self.state_dim, self.control_dim) @ controls.unsqueeze(-1)
                diff = state - pred.squeeze(-1)

                # update
                self.lift_opt.zero_grad()
                self.sysid_opt.zero_grad()

                # compute loss: prediction error plus small regularisers
                LAMBDA_STATE_NORM, LAMBDA_STABILITY, LAMBDA_B_NORM = 1e-5, 1e-5, 1e-5
                norm = torch.norm(state)
                state_norm = (1 / (norm + 1e-8)) + 1e-2 * norm if LAMBDA_STATE_NORM > 0 else 0.
                stability = opnorm(self.A - self.B @ dare_gain(self.A, self.B, torch.eye(self.state_dim), torch.eye(self.control_dim))) if LAMBDA_STABILITY > 0 else 0.
                B_norm = 1 / (torch.norm(self.B) + 1e-8) if LAMBDA_B_NORM > 0 else 0.
                loss = torch.norm(diff) + LAMBDA_STATE_NORM * state_norm + LAMBDA_STABILITY * stability + LAMBDA_B_NORM * B_norm
                loss.backward()
                self.lift_opt.step()
                self.sysid_opt.step()

                losses.append(loss.item())

            print_every = 25
            if t % print_every == 0 or t == self.num_epochs - 1:
                # `losses` holds one entry per batch, so this averages the last batches (not epochs)
                logging.info('({}) \tmean loss for past {} batches was {}'.format(get_classname(self), print_every, np.mean(losses[-print_every:])))

        # sysid: least-squares estimate on the lifted states, for comparison
        states = []
        controls = []
        for prev_histories, histories in self.buffer:
            states.append(self.map(*prev_histories))
            controls.append(histories[1][-1])
        states, controls = np.array(states), np.array(controls)

        A, B = least_squares(states, controls)
        A, B = torch.from_numpy(np.array(A)), torch.from_numpy(np.array(B))

        def calc_loss(A, B):
            """Mean per-batch prediction-error norm of `(A, B)` over `dl`."""
            with torch.no_grad():
                l = 0.
                for prev_cost_history, prev_control_history, cost_history, control_history, controls in dl:
                    # compute disturbance
                    prev_state = self.forward(prev_cost_history, prev_control_history)
                    state = self.forward(cost_history, control_history)
                    pred = A.expand(self.batch_size, self.state_dim, self.state_dim) @ prev_state.unsqueeze(-1) + \
                        B.expand(self.batch_size, self.state_dim, self.control_dim) @ controls.unsqueeze(-1)
                    diff = state - pred.squeeze(-1)

                    # compute loss
                    loss = torch.norm(diff)
                    l += loss.item()
                l /= len(dl)
                return l

        # learned value first, least-squares value second
        # (the '|A|' line used to sit after `return` inside `calc_loss` and never ran)
        print('|A|', opnorm(self.A),
              opnorm(A))
        print('|A-BK|', opnorm(self.A - self.B @ dare_gain(self.A, self.B, torch.eye(self.state_dim), torch.eye(self.control_dim))),
              opnorm(A - B @ dare_gain(A, B, torch.eye(self.state_dim), torch.eye(self.control_dim))))
        print('|B|', torch.norm(self.B), torch.norm(B))
        print('old', calc_loss(self.A, self.B), 'new', calc_loss(A, B))
#         self.A, self.B = A, B
        logging.warning('(LIFTER): note that right now we are USING THE OLD SYSID FOR A, B!!')

        self.trained = True
        return losses

    def sysid(self):
        """Return the learned `(A, B)` as jax arrays, training first if needed."""
        first_call = not self.trained
        if first_call:
            self.losses = self.train()

        A, B = jnp.array(self.A.detach().numpy()), jnp.array(self.B.detach().numpy())
        if first_call:
            logging.info('({}) ||A||_op = {}     ||B||_F {}'.format(get_classname(self), opnorm(A), jnp.linalg.norm(B, 'fro')))
        return A, B

    def update(self,
               prev_histories: Tuple[jnp.ndarray, jnp.ndarray],
               histories: Tuple[jnp.ndarray, jnp.ndarray]) -> float:
        """Store one `(prev_histories, histories)` transition; always returns 0."""
        self.buffer.append((prev_histories, histories))  # add the transition
        return 0.

### Run Experiment

In [12]:
import logging
logging.basicConfig(format='%(levelname)s: %(message)s', level=logging.INFO)  # set level to INFO for wordy
import matplotlib.pyplot as plt
from IPython.display import HTML

import numpy as np
import jax.numpy as jnp

from extravaganza.dynamical_systems import LDS

from extravaganza.controllers import LiftedBPC, LambdaController
from extravaganza.lifters import NoLift, RandomLift, LearnedLift
from extravaganza.sysid import SysID
from extravaganza.controllers import LQR, HINF, BPC, GPC, RBPC
from extravaganza.rescalers import ADAM, D_ADAM, DoWG
from extravaganza.utils import ylim, render, append, opnorm, dare_gain, least_squares
from extravaganza.experiments import Experiment

# RNG seeds; `None` means a random seed is used
SYSTEM_SEED = 23
CONTROLLER_SEED = LIFTER_AND_SYSID_SEED = None

# experiment name and the log file derived from it
name = 'lds_constant'
filename = f'../logs/{name}.pkl'

def get_experiment_args():
    """Build the keyword arguments for the experiment.

    The function first lets each `LiftedBPC` controller interact with a fresh
    system for `T0` steps and then queries its sysid for `(A, B)`. For every
    set of dynamics (including the ground-truth entry `'g.t.'`, whose `A`/`B`
    are taken from the system itself) it registers factories for LQR, HINF and
    GPC controllers that act on the lifted state.

    Returns:
        dict with keys `make_system`, `make_controllers`, `num_trials`, `T`,
        `reset_condition`, `reset_seed`, `use_multiprocessing`, `render_every`.

    Raises:
        AssertionError: if the state becomes NaN or the cost exceeds 1e20
            during the sysid phase.
    """
    # --------------------------------------------------------------------------------------
    # ------------------------    EXPERIMENT HYPERPARAMETERS    ----------------------------
    # --------------------------------------------------------------------------------------

    num_trials = 1
    T = 5000  # total timesteps
    T0 = 10000  # number of timesteps to just sysid for our methods
    reset_condition = lambda t: False  # how often to reset the system
    # Seed used whenever the system is reset. The sysid loop below used to refer
    # to an undefined `reset_seed`, which would raise NameError on the first reset.
    reset_seed = SYSTEM_SEED
    use_multiprocessing = False
    render_every = None

    # --------------------------------------------------------------------------------------
    # --------------------------    SYSTEM HYPERPARAMETERS    ------------------------------
    # --------------------------------------------------------------------------------------

    du = 1  # control dim
    ds = 1  # state dim

    disturbance_type = 'constant'
    cost_fn = 'quad'

    make_system = lambda : LDS(ds, du, disturbance_type, cost_fn, seed=SYSTEM_SEED)

    # --------------------------------------------------------------------------------------
    # ------------------------    LIFT/SYSID HYPERPARAMETERS    ----------------------------
    # --------------------------------------------------------------------------------------

    sysid_method = 'regression'
    sysid_scale = 1.

    learned_lift_args = {
        'lift_lr': 0.004,
        'sysid_lr': 0.004,
        'depth': 12,
        'buffer_maxlen': int(1e6),
        'num_epochs': 6,
        'batch_size': 64,
        'seed': LIFTER_AND_SYSID_SEED
    }

    # --------------------------------------------------------------------------------------
    # ------------------------    CONTROLLER HYPERPARAMETERS    ----------------------------
    # --------------------------------------------------------------------------------------

    h = 5  # controller memory length (# of w's to use on inference)
    hh = 15  # history length of the cost/control histories
    lift_dim = 64  # dimension to lift to

    lifted_bpc_args = {
        'h': h,
        'method': 'REINFORCE',
        'initial_scales': (0.0, 0., 0.0),  # M, M0, K   (uses M0's scale for REINFORCE)
        'rescalers': (lambda : 0., lambda : 0., lambda : 0.),
        'T0': T0,
    #     'bounds': None,
        'initial_u': jnp.zeros(du),
        'decay_scales': False,
        'use_tanh': False,
        'use_K_from_sysid': True,
        'seed': CONTROLLER_SEED
    }
    # `none` and `learned` are still constructed so they can be re-enabled below
    none = LiftedBPC(lifter=NoLift(hh, du, LIFTER_AND_SYSID_SEED), sysid=SysID(sysid_method, du, hh, sysid_scale, LIFTER_AND_SYSID_SEED), **lifted_bpc_args)
    learned = LiftedBPC(lifter=LearnedLift(hh, du, lift_dim, scale=sysid_scale, **learned_lift_args), **lifted_bpc_args)
    nm = LiftedBPC(lifter=NMLift(hh, du, lift_dim, scale=sysid_scale, **learned_lift_args), **lifted_bpc_args)
    controllers = {
#         'None': none,
#         'Learned': learned,
        'NM': nm
    }

    dynamics = {'g.t.': (None, None)}  # ground truth: A and B are read off the system later
    for k, controller in controllers.items(): # interact in order to perform sysid
        # make system and get initial control
        system = make_system()
        control = controller.initial_control if hasattr(controller, 'initial_control') else jnp.zeros(du)
        print(k)
        for t in tqdm.trange(T0):
            if reset_condition(t):
                logging.info('(EXPERIMENT): reset!')
                system.reset(reset_seed)

            cost, state = system.interact(control)  # state will be `None` for unobservable systems
            control = controller.get_control(cost, state)

            if (isinstance(state, jnp.ndarray) and jnp.any(jnp.isnan(state))) or (cost > 1e20):
                message = '(EXPERIMENT): state {} or cost {} diverged'.format(state, cost)
                logging.error(message)
                # raised explicitly so the check survives `python -O`
                raise AssertionError(message)

        A, B = controller.sysid.sysid()
        dynamics[k] = (A, B)

    # test: the helpers below do not depend on the loop variable, so define them once

    def init(controller):
        """Give a wrapped controller empty cost/control histories."""
        controller._cost_history = jnp.zeros(hh,)
        controller._control_history = jnp.zeros((hh, du))

    def get_control(key):
        """Return the control function used by the controller registered under `key`."""
        def _func(controller, cost, state):
            controller._cost_history = append(controller._cost_history, cost)
            # for learned dynamics, replace the observed state by the lifted one
            if key.split()[0] != 'g.t.': state = controllers[key.split()[0]].lifter.map(controller._cost_history, controller._control_history)
            controller.stats.update('state_norm', jnp.linalg.norm(state).item(), t=controller.t)
            control = controller._controller.get_control(cost, state)
            controller._control_history = append(controller._control_history, control)
            return control
        return _func

    def get_controller(key, controller_class):
        """Return a factory that builds `controller_class` on the dynamics stored under `key`."""
        def _func(sys):
            A, B = dynamics[key]
            if A is None: A = sys.A
            if B is None: B = sys.B
            Q = jnp.eye(A.shape[0])
            R = jnp.eye(du)
            controller = controller_class(A=A, B=B, Q=Q, R=R, seed=CONTROLLER_SEED)
            return LambdaController(controller, init_fn=init, get_control=get_control(key))
        return _func

    make_controllers = {}
    for k in dynamics:
        make_controllers.update({
            k + ' LQR': get_controller(k, LQR),
            k + ' HINF': get_controller(k, HINF),
            k + ' GPC': get_controller(k, GPC),
#             k + ' BPC': get_controller(k, BPC),
#             k + ' RBPC': get_controller(k, RBPC),
        })
    experiment_args = {
        'make_system': make_system,
        'make_controllers': make_controllers,
        'num_trials': num_trials,
        'T': T,
        'reset_condition': reset_condition,
        'reset_seed': reset_seed,
        'use_multiprocessing': use_multiprocessing,
        'render_every': render_every,
    }
    return experiment_args

# Create the experiment and run it. Note that the `get_experiment_args` function
# itself (not its return value) is handed to the experiment.
experiment = Experiment(name)
stats = experiment(get_experiment_args)

NM


100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▊| 9987/10000 [00:28<00:00, 338.25it/s]
100%|█████████████████████████████████████████████████████████████████████████████████████████████████████████▊| 9987/10000 [00:41<00:00, 338.25it/s][A
 17%|██████████████████▊                                                                                              | 1/6 [04:48<24:00, 288.15s/it][A
 33%|█████████████████████████████████████▋                                                                           | 2/6 [09:58<20:04, 301.06s/it][A
 50%|████████████████████████████████████████████████████████▌                                                        | 3/6 [15:33<15:33, 311.31s/it][A
100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████▉| 9998/10000 [16:05<00:00, 10.35it/s]


LinAlgError: Failed to find a finite solution.

### Save Experiment

In [None]:
# # save args and stats!  --  note that to save the args, we actually save the `get_args` function. we can print the 
# #                           source code later to see the hyperparameters we chose
# experiment.save(filename)

### Plot

In [None]:
def plot_lds(experiment: Experiment):
    """Plot the statistics recorded by `experiment` on a grid of axes.

    The first `n = 5` rows hold the shared panels (position, disturbances,
    costs, ||A||_op, ||B||_F, ||A-BK||_op, controls, state norm). Below them,
    each method gets its own panel showing the decomposition of its control.

    Raises:
        AssertionError: if the experiment has no stats yet.
    """
    assert experiment.stats is not None, 'cannot plot the results of an experiment that hasnt been run'
    all_stats = experiment.stats

    # clear plot and calc nrows
    plt.clf()
    n = 5
    nrows = n + (len(all_stats) + 1) // 2  # two per-method panels per extra row
    fig, ax = plt.subplots(nrows, 2, figsize=(16, 6 * nrows))

    # plot stats
    for i, (method, stats) in enumerate(all_stats.items()):
#         if 'g.t.' not in method: continue
        if stats is None:
            logging.warning('{} had no stats'.format(method))
            continue
        stats.plot(ax[0, 0], 'xs', label=method)
#         stats.plot(ax[0, 1], 'ws', label=method)
        stats.plot(ax[3, 1], 'us', label=method)
        stats.plot(ax[4, 0], 'state_norm', label=method)
        if 'costs' in stats:
            stats.plot(ax[1, 0], 'avg costs', label=method)
            stats.plot(ax[1, 1], 'costs', label=method)
        else:
            stats.plot(ax[1, 0], 'avg fs', label=method)
            stats.plot(ax[1, 1], 'fs', label=method)

        stats.plot(ax[2, 0], '||A||_op', label=method)
        stats.plot(ax[2, 1], '||B||_F', label=method)
        stats.plot(ax[3, 0], '||A-BK||_op', label=method)
        i_ax = ax[n + i // 2, i % 2]  # this method's own panel
        stats.plot(ax[0, 1], 'disturbances', label=method)
        stats.plot(i_ax, 'K @ state', label='K @ state')
        # raw strings: '\c' is an invalid escape sequence in a normal string literal
        stats.plot(i_ax, r'M \cdot w', label=r'M \cdot w')
        stats.plot(i_ax, 'M0', label='M0')
        i_ax.set_title('u decomp for {}'.format(method))
        i_ax.legend()

    # set titles and legends and limits and such
    # (note: `ylim()` is so useful! because sometimes one thing blows up and then autoscale messes up all plots)
    _ax = ax[0, 0]; _ax.set_title('position'); _ax.legend()
    _ax = ax[0, 1]; _ax.set_title('disturbances'); _ax.legend()
    _ax = ax[1, 0]; _ax.set_title('avg costs'); _ax.legend()
    _ax = ax[1, 1]; _ax.set_title('costs'); _ax.legend()

    _ax = ax[2, 0]; _ax.set_title('||A||_op'); _ax.legend()
    _ax = ax[2, 1]; _ax.set_title('||B||_F'); _ax.legend()

    _ax = ax[3, 0]; _ax.set_title('||A-BK||_op'); _ax.legend()
    _ax = ax[3, 1]; _ax.set_title('controls'); _ax.legend()

    _ax = ax[4, 0]; _ax.set_title('state norm'); _ax.legend()

# draw the summary figure for the experiment object created above
plot_lds(experiment)