In [7]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, grad

import numpy.random as rand
import seaborn as sns
import pandas as pd
from scipy.linalg import solve_discrete_are as dare
import matplotlib.pyplot as plt
from tqdm import tqdm

from IPython.display import Image
from IPython.core.display import display, HTML
# Inject a CSS override so the notebook's `.container` element spans the full page width.
display(HTML("<style>.container { width:100% !important; }</style>"))

import flax

In [8]:
# Quadratic Loss
def quad_loss(x, u, Q = None, R = None):
    """Quadratic cost ``x^T Q x + u^T R u`` reduced to a single scalar.

    Args:
        x: State array; enters the cost as ``x.T @ Q @ x``.
        u: Control array; enters the cost as ``u.T @ R @ u``.
        Q: Optional state weight. When ``None`` the plain product
            ``x.T @ x`` is used, i.e. an identity weight.
        R: Optional control weight. When ``None`` the plain product
            ``u.T @ u`` is used, i.e. an identity weight.

    Returns:
        The sum of every entry of the state term plus the control term,
        as a JAX scalar array.
    """
    x_contrib = x.T @ x if Q is None else x.T @ Q @ x
    u_contrib = u.T @ u if R is None else u.T @ R @ u

    # Reduce with jnp rather than NumPy: this function is traced by JAX
    # (it is called from a lax.scan body), and jnp.sum handles tracers
    # directly instead of relying on NumPy's duck-typing fallback.
    return jnp.sum(x_contrib + u_contrib)

In [41]:
def buzz_noise(n, t, scale = 0.3, horizon = None):
    """Disturbance for step ``t``: sinusoidal, with two Gaussian windows.

    The horizon is cut into tenths (``horizon // 10`` steps each). Steps in
    tenths [2, 4) and [6, 7) get i.i.d. Gaussian noise; every other step
    gets a deterministic sinusoid.

    Args:
        n: Number of rows of the returned column vector.
        t: Time step index.
        scale: Amplitude of the sinusoid and standard deviation of the
            Gaussian noise.
        horizon: Total number of steps used to place the windows. When
            ``None`` the notebook-level global ``T`` is used, which must
            already be defined at call time.

    Returns:
        A NumPy array of shape ``(n, 1)``.
    """
    if horizon is None:
        horizon = T
    tenth = horizon // 10

    in_gaussian_window = (2 * tenth <= t < 4 * tenth) or (6 * tenth <= t < 7 * tenth)
    if in_gaussian_window:
        return rand.normal(scale = scale, size = (n, 1))

    # The sinusoid is sampled at n consecutive integers starting at n*t, so
    # successive steps continue the same wave without overlap.
    phase = np.arange(start=n*t, stop=n*(t+1)) / (2*np.pi)
    return scale * np.sin(phase).reshape((n, 1))

In [36]:
class LQR(flax.nn.Module):
    """Time-varying LQR controller expressed as a stateful flax module.

    ``init_K`` precomputes one feedback gain per time step. ``apply`` plays
    those gains back one per call, tracking the current step in a flax
    state variable named ``"t"``.
    """

    @classmethod
    def init_K(cls, T, A, B, Q=None, R=None, refresh_every=10):
        """Compute the LQR gain for each of ``T`` time steps.

        The discrete algebraic Riccati equation is solved only on steps
        where ``t % refresh_every == 0``; the gain from the most recent
        solve is reused for the steps in between.

        Args:
            T: Number of time steps.
            A: Per-step state matrices, indexed as ``A[t]``.
            B: Per-step input matrices, indexed as ``B[t]``; ``B[0].shape``
                is read as ``(n, m)``.
            Q: Optional per-step state weights ``Q[t]``; identity if ``None``.
            R: Optional per-step control weights ``R[t]``; identity if ``None``.
            refresh_every: Positive number of steps between Riccati solves.

        Returns:
            Array of shape ``(T, m, n)`` holding the gain for every step.
        """
        n, m = B[0].shape
        # Fixes the output dtype and doubles as the result when T == 0.
        K = jnp.zeros((T, m, n))

        gains = []
        Kt = None
        for t in range(T):
            if t % refresh_every == 0:
                # Get system at current time
                At, Bt = A[t], B[t]
                Qt = jnp.eye(n, dtype=jnp.float32) if Q is None else Q[t]
                Rt = jnp.eye(m, dtype=jnp.float32) if R is None else R[t]

                # Solve the Riccati equation
                Xt = dare(At, Bt, Qt, Rt)

                # Compute LQR gain
                Kt = jnp.linalg.inv(Bt.T @ Xt @ Bt + Rt) @ (Bt.T @ Xt @ At)
            gains.append(Kt)

        if not gains:
            return K
        # Stack once instead of rewriting the whole (T, m, n) array on every
        # step; the cast keeps the dtype the preallocated array would have had.
        return jnp.stack(gains).astype(K.dtype)

    def apply(self, x, T, A, B, K, Q=None, R=None):
        """Return the control ``-K[t] @ x`` and advance the step counter.

        Args:
            x: Current state.
            T, A, B, Q, R: Accepted as part of the module signature but not
                read by this method.
            K: Per-step gains, indexed by the internal counter.

        Returns:
            The control action for the current step.
        """
        self.t = self.state("t", shape=())

        # The counter starts at zero when the module is being initialized.
        if self.is_initializing():
            self.t.value = 0

        action = -K[self.t.value] @ x
        self.t.value += 1

        return action

In [42]:
# Horizon and per-step system matrices: A is the same at every step, while the
# non-zero entry of B varies with t.
T = 1000
A = jnp.array([[[1., 1.], [0., 1.]] for t in range(T)])
B = jnp.array([[[0.], [2. + jnp.sin(np.pi * t/T)]] for t in range(T)])

# State / control dimensions and the initial state.
n, m = 2, 1
x0 = np.zeros((n, 1))

# Seed NumPy's global generator so the Gaussian segments of the disturbance
# are identical on every Restart & Run All.
rand.seed(0)
buzz = jnp.asarray(np.asarray([buzz_noise(n, t) for t in range(T)]))

In [32]:
# Precompute the per-step LQR gains with the default (identity) Q and R weights.
init_K = LQR.init_K(T, A, B)

In [37]:
# Fix every argument of LQR.apply except the state x (presumably what
# `partial` does in the legacy flax.nn API — confirm against the flax version in use).
model_def = LQR.partial(T=T, A=A, B=B, K=init_K)
# Initialize inside a stateful context so the module's "t" counter is captured
# in `state`; `state` is reused later as the initial carry of the rollout.
with flax.nn.stateful() as state:
    # Only the shape of x0 is passed; the first returned value is discarded.
    _, params = model_def.init_by_shape(jax.random.PRNGKey(0), [x0.shape])
model = flax.nn.Model(model_def, params)

In [45]:
def func(xstate, inputs):
    """Single closed-loop step, shaped as a ``jax.lax.scan`` body.

    Args:
        xstate: Carry ``(x, state)`` — the current system state and the
            flax state collection of the controller.
        inputs: Per-step slice ``(a, b, z)`` — state matrix, input matrix
            and additive disturbance.

    Returns:
        ``((next_x, next_state), loss)`` where ``loss`` is the quadratic
        cost of the current state and the control applied to it.
    """
    x_now, ctrl_state = xstate
    A_t, B_t, w_t = inputs

    # Run the controller with its state collection active so the step
    # counter update is recorded in the collection returned as carry.
    with flax.nn.stateful(ctrl_state) as ctrl_state:
        u_t = model(x_now)
        step_cost = quad_loss(x_now, u_t)
        x_next = A_t @ x_now + B_t @ u_t + w_t

    return (x_next, ctrl_state), step_cost

In [46]:
%timeit x, loss = jax.lax.scan(func, (x0, state), (A, B, buzz))

51.9 ms ± 243 µs per loop (mean ± std. dev. of 7 runs, 10 loops each)
