# **Online [Nonstochastic](https://sites.google.com/view/nsc-tutorial/home) Control**

[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/MinRegret/nsc-tutorial/blob/main/online-control.ipynb) 

## Housekeeping
Imports [jax](https://github.com/google/jax), NumPy, SciPy, and plotting utilities.

In [1]:
#@title

import jax
import itertools
import numpy as onp
import jax.numpy as np
import matplotlib.pyplot as plt
import ipywidgets as widgets

from jax.numpy.linalg import inv, pinv
from scipy.linalg import solve_discrete_are as dare
from jax import jit, grad
from IPython import display
from toolz.dicttoolz import valmap, itemmap
from itertools import chain

def liveplot(costss, xss, cmax=30, cumcmax=15, wmax=2, xmax=20, logcmax=100, logcumcmax=1000,
             perturbations=None):
    """Draw cost/state diagnostics for several controllers and animate them.

    Six panels: instantaneous cost, running-average cost, the first
    perturbation coordinate, the controllers' positions (first state
    coordinate) with a short trail, and log-scale versions of both cost plots.
    Every panel shows the first 100 time steps.

    Args:
        costss: dict mapping controller name -> per-step costs.
        xss: dict mapping controller name -> per-step states; only the first
            coordinate of each state is drawn.
        cmax: upper y-limit of the instantaneous-cost panel.
        cumcmax: upper y-limit of the average-cost panel.
        wmax: symmetric y-limit of the perturbation panel.
        xmax: symmetric x-limit of the position panel.
        logcmax, logcumcmax: upper y-limits of the two log-scale panels.
        perturbations: 2-D array whose first column is drawn in the
            perturbation panel. Defaults to the notebook-global ``W``.

    Controller names must be keys of the ``colors`` table below.

    Returns:
        ``IPython.display.HTML`` embedding the animation as an HTML5 video.
    """
    cummean = lambda x: np.cumsum(np.array(x))/np.arange(1, len(x)+1)
    cumcostss = valmap(cummean, costss)
    # Fall back to the notebook-global W so existing calls keep working.
    wss = W if perturbations is None else perturbations

    # The bundled seaborn style was renamed 'seaborn-v0_8' in newer Matplotlib
    # (the old name was deprecated and later removed); use whichever exists.
    for style in ("seaborn", "seaborn-v0_8"):
        if style in plt.style.available:
            plt.style.use(style)
            break
    colors = {
        "Zero Control": "gray",
        "LQR / H2": "green",
        "Finite-horizon LQR / H2": "teal",
        "Optimal LQG for GRW": "aqua",
        "Robust / Hinf Control": "orange",
        "GPC": "red"
    }

    fig, ax = plt.subplots(3, 2, figsize=(21, 12))

    costssline = {}
    for Cstr in costss:
        costssline[Cstr], = ax[0, 0].plot([], label=Cstr, color=colors[Cstr])
    ax[0, 0].set_xlabel("Time")
    ax[0, 0].set_ylabel("Instantaneous Cost")
    ax[0, 0].set_ylim([-1, cmax])
    ax[0, 0].set_xlim([0, 100])
    ax[0, 0].legend()

    cumcostssline = {}
    for Cstr in cumcostss:
        cumcostssline[Cstr], = ax[0, 1].plot([], label=Cstr, color=colors[Cstr])
    ax[0, 1].set_xlabel("Time")
    ax[0, 1].set_ylabel("Average Cost")
    ax[0, 1].set_ylim([-1, cumcmax])
    ax[0, 1].set_xlim([0, 100])
    ax[0, 1].legend()

    perturbline, = ax[1, 0].plot([])
    ax[1, 0].set_xlabel("Time")
    ax[1, 0].set_ylabel("Perturbation")
    ax[1, 0].set_ylim([-wmax, wmax])
    ax[1, 0].set_xlim([0, 100])

    pointssline, trailssline = {}, {}
    for Cstr in xss:
        pointssline[Cstr], = ax[1, 1].plot([], [], label=Cstr, color=colors[Cstr], ms=20, marker='s')
        # A leading underscore keeps the trail out of the legend, so each
        # controller appears there once (as its marker).
        trailssline[Cstr], = ax[1, 1].plot([], [], label="_" + Cstr, color=colors[Cstr], lw=2)
    ax[1, 1].set_xlabel("Position")
    ax[1, 1].set_ylabel("")
    ax[1, 1].set_ylim([-1, 7])
    ax[1, 1].set_xlim([-xmax, xmax])
    ax[1, 1].legend()

    logcostssline = {}
    for Cstr in costss:
        logcostssline[Cstr], = ax[2, 0].plot([1], label=Cstr, color=colors[Cstr])
    ax[2, 0].set_xlabel("Time")
    ax[2, 0].set_ylabel("Instantaneous Cost (Log Scale)")
    ax[2, 0].set_xlim([0, 100])
    ax[2, 0].set_ylim([0.1, logcmax])
    ax[2, 0].set_yscale('log')
    ax[2, 0].legend()

    logcumcostssline = {}
    for Cstr in cumcostss:
        logcumcostssline[Cstr], = ax[2, 1].plot([1], label=Cstr, color=colors[Cstr])
    ax[2, 1].set_xlabel("Time")
    ax[2, 1].set_ylabel("Average Cost (Log Scale)")
    ax[2, 1].set_xlim([0, 100])
    ax[2, 1].set_ylim([0.1, logcumcmax])
    ax[2, 1].set_yscale('log')
    ax[2, 1].legend()

    def livedraw(t):
        """Update every artist to show the first t steps; return the artists."""
        for Cstr, line in costssline.items():
            line.set_data(np.arange(t), costss[Cstr][:t])
        for Cstr, line in cumcostssline.items():
            line.set_data(np.arange(t), cumcostss[Cstr][:t])
        perturbline.set_data(np.arange(t), wss[:t, 0])
        for i, (Cstr, line) in enumerate(pointssline.items()):
            # Line2D.set_data expects sequences; bare scalars are deprecated /
            # rejected by newer Matplotlib releases.
            line.set_data([xss[Cstr][t][0]], [i])
        for i, (Cstr, line) in enumerate(trailssline.items()):
            # Last (up to) 10 positions, drawn on the controller's own row i.
            trail = [x[0] for x in xss[Cstr][max(t-10, 0):t]]
            line.set_data(trail, [i] * len(trail))
        for Cstr, line in logcostssline.items():
            line.set_data(np.arange(t), costss[Cstr][:t])
        for Cstr, line in logcumcostssline.items():
            line.set_data(np.arange(t), cumcostss[Cstr][:t])
        # FuncAnimation(blit=True) may iterate the returned artists more than
        # once, so hand back a list rather than a one-shot iterator.
        return list(chain(costssline.values(), cumcostssline.values(), [perturbline],
                          pointssline.values(), trailssline.values(),
                          logcostssline.values(), logcumcostssline.values()))

    print("🧛 reanimating :) meanwhile...")
    livedraw(99)
    plt.show()

    from matplotlib import animation
    anim = animation.FuncAnimation(fig, livedraw, frames=100, interval=50, blit=True)
    from IPython.display import HTML
    display.clear_output(wait=True)
    return HTML(anim.to_html5_video())

## A simple dynamical system
Defines a discrete-time [double-integrator](https://en.wikipedia.org/wiki/Double_integrator) -- a simple linear dynamical system that mirrors 1d kinematics -- along with a quadratic cost.

Below, $\mathbf{x}_t$ is the state, $\mathbf{u}_t$ is the control input (or action), and $\mathbf{w}_t$ is the perturbation.

$$ \mathbf{x}_{t+1} = A\mathbf{x}_t + B\mathbf{u}_t + \mathbf{w}_t, \qquad c(\mathbf{x},\mathbf{u}) = \mathbf{x}^\top Q \mathbf{x} + \mathbf{u}^\top R \mathbf{u}$$

$$ A = \begin{bmatrix}
1 & 1\\
0 & 1
\end{bmatrix},\quad B = \begin{bmatrix}
0\\
1
\end{bmatrix}, \quad Q = \begin{bmatrix}
1 & 0\\
0 & 1
\end{bmatrix}, \quad R = \begin{bmatrix}
1
\end{bmatrix}$$

In [2]:
dx, du, T = 2, 1, 100
A, B = np.array([[1.0, 1.0], [0.0, 1.0]]), np.array([[0.0], [1.0]])
Q, R = np.eye(dx), np.eye(du)


def dyn(x, u, w, t):
    """Linear dynamics x_{t+1} = A x_t + B u_t + w_t (time-invariant; t is unused)."""
    return A @ x + B @ u + w


def cost(x, u, t):
    """Quadratic cost c(x, u) = x^T Q x + u^T R u (time-invariant; t is unused)."""
    # The state penalty must use Q (not the dynamics matrix A) to match the
    # cost definition in the markdown above.
    return x.T @ Q @ x + u.T @ R @ u

# A basic control loop.
# (x, z) is the environ-controller state.
def eval(control, W):
    """Simulate `control` for T steps against the perturbation sequence W.

    Starts from x = 0 with controller state z = None. At each step t the
    controller maps (x, z, t) to (u, z'). The tuple (x, u, W[t], cost) is
    yielded before the state is advanced with the dynamics.

    Note: inside this notebook the name shadows the builtin ``eval``.
    """
    state, memory = np.zeros(dx), None
    horizon = T
    t = 0
    while t < horizon:
        action, memory = control(state, memory, t)
        step_cost = cost(state, action, t)
        disturbance = W[t]
        yield (state, action, disturbance, step_cost)
        state = dyn(state, action, disturbance, t)
        t += 1



## Control Algorithms
The segment below puts forth a few basic control strategies.

+ **Zero Control**: Executes $\mathbf{u}=\mathbf{0}$.
+ **LQR / H2**: A discrete-time [linear-quadratic regulator](https://en.wikipedia.org/wiki/Linear%E2%80%93quadratic_regulator).
+ **Finite-horizon LQR / H2**: A finite-horizon (time-varying) variant of the above.
+ **Robust / $H_\infty$ Control**: A worst-case [robust](https://en.wikipedia.org/wiki/H-infinity_methods_in_control_theory) controller.
+ **GPC**: [Gradient-perturbation](https://arxiv.org/abs/1902.08721) controller.

In [3]:
#@title

def zero():
    """Return a controller that always outputs the zero action.

    The controller state ``z`` is passed through unchanged.
    """
    def zero_u(x, z, t):
        return np.zeros(du), z

    return zero_u


def h2(A=A, B=B, Q=Q, R=R):
    """Infinite-horizon discrete-time LQR (H2) controller.

    Solves the discrete algebraic Riccati equation for P and returns the
    static policy u = K x with K = -(R + B^T P B)^{-1} B^T P A.
    The controller state ``z`` is passed through unchanged.
    """
    riccati = dare(A, B, Q, R)
    gain = -inv(R + B.T @ riccati @ B) @ (B.T @ riccati @ A)

    def h2_u(x, z, t):
        return gain @ x, z

    return h2_u


def h2nonstat(A=A, B=B, Q=Q, R=R, T=T):
    """Finite-horizon (time-varying) LQR controller over T steps.

    Runs the backward Riccati recursion from P_T = Q:
        P_t = Q + A^T P A - A^T P B (R + B^T P B)^{-1} B^T P A,
        K_t = -(R + B^T P B)^{-1} B^T P A,      with P = P_{t+1},
    and returns the policy u_t = K_t x_t (valid for 0 <= t < T).
    """
    n_x, n_u = B.shape
    Ps = [np.zeros((n_x, n_x)) for _ in range(T + 1)]
    Ks = [np.zeros((n_u, n_x)) for _ in range(T)]
    Ps[T] = Q
    for step in reversed(range(T)):
        P_next = Ps[step + 1]
        AtPB = A.T @ P_next @ B
        S_inv = inv(R + B.T @ P_next @ B)
        BtPA = B.T @ P_next @ A
        Ps[step] = Q + A.T @ P_next @ A - AtPB @ S_inv @ BtPA
        Ks[step] = -S_inv @ BtPA
    return lambda x, z, t: (Ks[t] @ x, z)


def hinf(A=A, B=B, Q=Q, R=R, T=T, gamma=1.0):
    """Finite-horizon H-infinity (robust) state-feedback controller.

    Backward recursion from P_T = Q, with P = P_{t+1}:
        Lambda_t = I + (B R^{-1} B^T - gamma^{-2} I) P,
        P_t      = Q + A^T P Lambda_t^+ A,
        K_t      = -R^{-1} B^T P Lambda_t^+ A,
    where ^+ is the Moore-Penrose pseudo-inverse. Returns u_t = K_t x_t.
    """
    n_x, n_u = B.shape
    R_inv = inv(R)
    eye = np.eye(n_x)
    Ps = [np.zeros((n_x, n_x)) for _ in range(T + 1)]
    Ks = [np.zeros((n_u, n_x)) for _ in range(T)]
    Ps[T] = Q
    for step in reversed(range(T)):
        P_next = Ps[step + 1]
        Lambda = eye + (B @ R_inv @ B.T - gamma ** -2 * eye) @ P_next
        Lambda_pinv = pinv(Lambda)
        Ps[step] = Q + A.T @ P_next @ Lambda_pinv @ A
        Ks[step] = -R_inv @ B.T @ P_next @ Lambda_pinv @ A
    return lambda x, z, t: (Ks[t] @ x, z)


def gpc(A=A, B=B, Q=Q, R=R, T=T, H=3, M=3, lr=0.01, dyn=dyn, cost=cost):
    """Gradient-perturbation controller (GPC).

    Policy: u_t = K x_t + sum_i E[i] w_{(i-th of the last M disturbances)} + off,
    where K is the infinite-horizon LQR gain (its minus sign is already folded
    in, so K @ x is the LQR action). The disturbances are reconstructed from
    observed transitions. (E, off) are learned online by gradient descent on
    an H-step counterfactual ("proxy") cost.

    Args:
        A, B, Q, R: system and cost matrices (also define the LQR gain K).
        T: unused; kept for interface parity with the other controllers.
        H: rollout length of the proxy cost.
        M: number of past disturbances the policy conditions on.
        lr: gradient-descent step size.
        dyn, cost: dynamics and cost functions used in the proxy rollout.

    Returns:
        Controller ``gpc_u(x, z, t) -> (u, z')``. ``z`` is
        (x_prev, u_prev, W_buffer, E, off); None or t == 0 resets it.
    """
    dx, du = B.shape
    P = dare(A, B, Q, R)
    K = - np.array(inv(R + B.T @ P @ B) @ (B.T @ P @ A))

    def _action(E, off, y, W_window):
        # Shared policy form for both the real controller and the proxy rollout.
        return K @ y + np.tensordot(E, W_window, axes=([0, 2], [0, 1])) + off

    def proxy(E, off, W):
        # Counterfactual cost: start from y = 0, play policy (E, off) for H
        # steps against the buffered disturbances W (length H + M, oldest
        # first), then charge the cost of the final state/action pair.
        y = np.zeros(dx)
        for h in range(H):
            v = _action(E, off, y, W[h: h + M])
            y = dyn(y, v, W[h + M], h + M)
        # y has now absorbed W[H + M - 1], so the final action conditions on
        # the most recent M disturbances, W[H: H + M], exactly as gpc_u does.
        v = _action(E, off, y, W[H: H + M])
        return cost(y, v, None)

    proxygrad = jit(grad(proxy, argnums=(0, 1)))

    def gpc_u(x, z, t):
        # (Re)initialise the controller state at the start of an episode.
        if z is None or t == 0:
            z = np.zeros(dx), np.zeros(du), np.zeros((H + M, dx)), np.zeros((M, du, dx)), np.zeros(du)
        xprev, uprev, W, E, off = z
        # Recover the disturbance that produced x and push it onto the end of
        # the buffer (index -1 = most recent). `.at[].set()` is JAX's
        # functional update (jax.ops.index_update no longer exists).
        W = W.at[0].set(x - A @ xprev - B @ uprev)
        W = np.roll(W, -1, axis=0)
        # Learn only once the disturbance buffer has been completely filled.
        if t >= H + M:
            Edelta, offdelta = proxygrad(E, off, W)
            E -= lr * Edelta
            off -= lr * offdelta
        u = _action(E, off, x, W[-M:])
        return u, (x, u, W, E, off)

    return gpc_u

def controllers(gamma, H, M, lr):
    """Build the standard suite of controllers, keyed by display name.

    Args:
        gamma: attenuation level for the H-infinity controller.
        H, M, lr: GPC proxy-rollout length, memory length and learning rate.
    """
    suite = {}
    suite["Zero Control"] = zero()
    suite["LQR / H2"] = h2()
    suite["Finite-horizon LQR / H2"] = h2nonstat()
    suite["Robust / Hinf Control"] = hinf(gamma=gamma)
    suite["GPC"] = gpc(H=H, M=M, lr=lr)
    return suite

## Constant Perturbation
$$w = \texttt{magnitude} \times \begin{bmatrix}
1 \\
1
\end{bmatrix}$$

In [4]:
#@title Constant Perturbation
#@markdown Environment Parameters
magnitude = 1.0 #@param {type:"slider", min:0, max:10, step:0.01}

# w_t = magnitude * [1, ..., 1] for every t, as in the formula above.
W = magnitude * np.ones((T, dx))

#@markdown Constant Perturbation: Control parameters
hinf_log_gamma = -2 #@param {type:"slider", min:-7, max:5, step:0.01}
hinf_gamma = 10**(hinf_log_gamma)
gpc_lookback = 5 #@param {type:"slider", min:1, max:20, step:1}
gpc_memory = 5 #@param {type:"slider", min:1, max:20, step:1}
gpc_log_lr = -6 #@param {type:"slider", min:-10, max:-1, step:0.01}
gpc_lr = 10**(gpc_log_lr)

Cs = controllers(hinf_gamma, gpc_lookback, gpc_memory, gpc_lr)

print("🧛 evaluating controllers")
# Each trace is (states, actions, perturbations, costs), one entry per step.
traces = {Cstr: list(zip(*eval(C, W))) for Cstr, C in Cs.items()}
xss = valmap(lambda x: x[0], traces)
uss = valmap(lambda x: x[1], traces)
costss = valmap(lambda x: x[3], traces)

liveplot(costss, xss, 30, 30, 2, 20, 10**7, 10**7)

## Sine Perturbation
$$w = \texttt{magnitude} \times \sin\left(\frac{2\pi t\times \texttt{freq}}{T}\right) \times \begin{bmatrix}
1 \\
1
\end{bmatrix}$$

In [5]:
#@title Sine Perturbation
#@markdown Environment Parameters
magnitude = 1.0 #@param {type:"slider", min:0, max:5, step:0.01}
freq = 4.0 #@param {type:"slider", min:0, max:10, step:0.01}

# w_t = magnitude * sin(2*pi*t*freq/T) in every coordinate; shape (T, dx).
W = magnitude * np.tile(np.sin(np.arange(T) * 2 * np.pi * freq / T), (dx, 1)).T

#@markdown Sine Perturbation: Control parameters
hinf_log_gamma = 1 #@param {type:"slider", min:-7, max:5, step:0.01}
hinf_gamma = 10**(hinf_log_gamma)
gpc_lookback = 5 #@param {type:"slider", min:1, max:20, step:1}
gpc_memory = 5 #@param {type:"slider", min:1, max:20, step:1}
gpc_log_lr = -5 #@param {type:"slider", min:-10, max:-1, step:0.01}
gpc_lr = 10**(gpc_log_lr)

Cs = controllers(hinf_gamma, gpc_lookback, gpc_memory, gpc_lr)

print("🧛 evaluating controllers")
# Each trace is (states, actions, perturbations, costs), one entry per step.
traces = {Cstr: list(zip(*eval(C, W))) for Cstr, C in Cs.items()}
xss = valmap(lambda x: x[0], traces)
uss = valmap(lambda x: x[1], traces)
costss = valmap(lambda x: x[3], traces)

liveplot(costss, xss, 30, 15, 2, 20, 10**5, 10**5)

## Amplitude Modulation
$$w = \texttt{magnitude} \times \sin\left(\frac{2\pi t\times \texttt{freq}_1}{T}\right) \times \sin\left(\frac{2\pi t\times \texttt{freq}_2}{T}\right) \times \begin{bmatrix}
1 \\
1
\end{bmatrix}$$

In [6]:
#@title Amplitude Modulation
#@markdown Environment Parameters
magnitude = 1.0 #@param {type:"slider", min:0, max:5, step:0.01}
freq1 = 4.0 #@param {type:"slider", min:0, max:10, step:0.01}
freq2 = 3.0 #@param {type:"slider", min:0, max:10, step:0.01}

# w_t = magnitude * sin(2*pi*t*freq1/T) * sin(2*pi*t*freq2/T) in every
# coordinate; shape (T, dx).
carrier = np.sin(np.arange(T) * 2 * np.pi * freq1 / T)
envelope = np.sin(np.arange(T) * 2 * np.pi * freq2 / T)
W = magnitude * np.tile(carrier * envelope, (dx, 1)).T

#@markdown Amplitude Modulation: Control parameters
hinf_log_gamma = 1 #@param {type:"slider", min:-7, max:5, step:0.01}
hinf_gamma = 10**(hinf_log_gamma)
gpc_lookback = 5 #@param {type:"slider", min:1, max:20, step:1}
gpc_memory = 5 #@param {type:"slider", min:1, max:20, step:1}
gpc_log_lr = -5 #@param {type:"slider", min:-10, max:-1, step:0.01}
gpc_lr = 10**(gpc_log_lr)

Cs = controllers(hinf_gamma, gpc_lookback, gpc_memory, gpc_lr)

print("🧛 evaluating controllers")
# Each trace is (states, actions, perturbations, costs), one entry per step.
traces = {Cstr: list(zip(*eval(C, W))) for Cstr, C in Cs.items()}
xss = valmap(lambda x: x[0], traces)
uss = valmap(lambda x: x[1], traces)
costss = valmap(lambda x: x[3], traces)

# Log-scale limits match each other, as in the other experiments.
liveplot(costss, xss, 30, 15, 2, 20, 10**4, 10**4)

## Uniform Perturbation
$$w \sim \texttt{magnitude} \times \begin{bmatrix} 
\textit{Unif}\left(0, 1\right)\\
\textit{Unif}\left(0, 1\right)
\end{bmatrix}$$

In [7]:
#@title Uniform Perturbation
#@markdown Environment Parameters
magnitude = 1.0 #@param {type:"slider", min:0, max:10, step:0.01}

# Each coordinate of w_t is i.i.d. magnitude * Unif(0, 1).
W = magnitude * onp.random.random((T, dx))

#@markdown Uniform Perturbation: Control parameters
hinf_log_gamma = 1 #@param {type:"slider", min:-7, max:5, step:0.01}
hinf_gamma = 10**(hinf_log_gamma)
gpc_lookback = 5 #@param {type:"slider", min:1, max:20, step:1}
gpc_memory = 5 #@param {type:"slider", min:1, max:20, step:1}
gpc_log_lr = -6 #@param {type:"slider", min:-10, max:-1, step:0.01}
gpc_lr = 10**(gpc_log_lr)

Cs = controllers(hinf_gamma, gpc_lookback, gpc_memory, gpc_lr)

print("🧛 evaluating controllers")
# Each trace is (states, actions, perturbations, costs), one entry per step.
traces = {Cstr: list(zip(*eval(C, W))) for Cstr, C in Cs.items()}
xss = valmap(lambda x: x[0], traces)
uss = valmap(lambda x: x[1], traces)
costss = valmap(lambda x: x[3], traces)

liveplot(costss, xss, 50, 50, 4, 20, 10**2, 10**2)

## Gaussian Perturbation
$$w \sim \mathcal{N}\left(\mathbf{0}, \texttt{magnitude}^2 \, \mathbf{I}\right)$$

For Gaussian noise, LQR is an optimal controller.

In [8]:
#@title Gaussian Perturbation
#@markdown Environment Parameters
magnitude = 1.0 #@param {type:"slider", min:0, max:10, step:0.01}

# w_t ~ N(0, magnitude^2 I), i.i.d. over time.
W = magnitude * onp.random.normal(size=(T, dx))

#@markdown Gaussian Perturbation: Control parameters
hinf_log_gamma = 2 #@param {type:"slider", min:-7, max:5, step:0.01}
hinf_gamma = 10**(hinf_log_gamma)
gpc_lookback = 5 #@param {type:"slider", min:1, max:20, step:1}
gpc_memory = 5 #@param {type:"slider", min:1, max:20, step:1}
gpc_log_lr = -6 #@param {type:"slider", min:-10, max:-1, step:0.01}
gpc_lr = 10**(gpc_log_lr)

Cs = controllers(hinf_gamma, gpc_lookback, gpc_memory, gpc_lr)

print("🧛 evaluating controllers")
# Each trace is (states, actions, perturbations, costs), one entry per step.
traces = {Cstr: list(zip(*eval(C, W))) for Cstr, C in Cs.items()}
xss = valmap(lambda x: x[0], traces)
uss = valmap(lambda x: x[1], traces)
costss = valmap(lambda x: x[3], traces)

liveplot(costss, xss, 50, 50, 4, 20, 10**3, 10**3)

## Random-walk Perturbation
$$\Delta \mathbf{w}_t = \mathbf{w}_{t+1} - \mathbf{w}_{t} \sim \mathcal{N}\left(\mathbf{0}, \texttt{magnitude}^2 \, \mathbf{I}\right)$$

This comparison also includes **LQG**, a [linear-quadratic Gaussian](https://en.wikipedia.org/wiki/Linear%E2%80%93quadratic%E2%80%93Gaussian_control) controller that is optimal for Gaussian random-walk perturbations.


In [9]:
#@title Random-walk Perturbation
#@markdown Environment Parameters
magnitude = 1.0 #@param {type:"slider", min:0, max:10, step:0.01}

# Gaussian random walk: w_t is the running sum of i.i.d. N(0, magnitude^2 I) steps.
W = magnitude * np.cumsum(onp.random.normal(size=(T, dx)), axis=0)

#@markdown Random-walk Perturbation: Control parameters
hinf_log_gamma = 5 #@param {type:"slider", min:-7, max:5, step:0.01}
hinf_gamma = 10**(hinf_log_gamma)
gpc_lookback = 5 #@param {type:"slider", min:1, max:20, step:1}
gpc_memory = 5 #@param {type:"slider", min:1, max:20, step:1}
gpc_log_lr = -7.5 #@param {type:"slider", min:-10, max:-1, step:0.01}
gpc_lr = 10**(gpc_log_lr)

def h2ext(A=A, B=B, Q=Q, R=R, T=T):
    """Finite-horizon LQR on a state augmented with the last disturbance.

    The most recent disturbance w is modelled as persisting:
        [x; w]_{t+1} = [[A, I], [0, I]] [x; w]_t + [B; 0] u_t,
    only the x block is penalised, and planning uses ``h2nonstat`` on this
    extended system. At run time w is recovered as x_t - A x_{t-1} - B u_{t-1}.

    Returns:
        dict mapping the display name "Optimal LQG for GRW" to the controller.
    """
    n_x, n_u = B.shape
    zeros = np.zeros((n_x, n_x))
    identity = np.eye(n_x)
    Aext = np.block([[A, identity], [zeros, identity]])
    Bext = np.block([[B], [np.zeros((n_x, n_u))]])
    Qext = np.block([[Q, zeros], [zeros, zeros]])
    planner = h2nonstat(Aext, Bext, Qext, R, T)

    def h2ext_u(x, z, t):
        x_prev, u_prev = (np.zeros(n_x), np.zeros(n_u)) if z is None else z
        # Disturbance that produced the current state.
        w = x - A @ x_prev - B @ u_prev
        action, _ = planner(np.block([x, w]), None, t)
        return action, (x, action)

    return {"Optimal LQG for GRW": h2ext_u}

Cs = controllers(hinf_gamma, gpc_lookback, gpc_memory, gpc_lr)
Cs.update(h2ext())

print("🧛 evaluating controllers")
# Each trace is (states, actions, perturbations, costs), one entry per step.
traces = {name: list(zip(*eval(controller, W))) for name, controller in Cs.items()}
xss = {name: trace[0] for name, trace in traces.items()}
uss = {name: trace[1] for name, trace in traces.items()}
costss = {name: trace[3] for name, trace in traces.items()}

liveplot(costss, xss, 1000, 1000, 20, 20, 10**6, 10**6)