In [1]:
import jax
import jax.numpy as jnp
import numpy as np
from jax import jit, grad

import numpy.random as rand
import seaborn as sns
import pandas as pd
from scipy.linalg import solve_discrete_are as dare
import matplotlib.pyplot as plt
from tqdm import tqdm

from IPython.display import Image
from IPython.core.display import display, HTML
display(HTML("<style>.container { width:100% !important; }</style>"))

import flax

In [2]:
# Quadratic Loss
def quad_loss(x, u, Q = None, R = None):
    x_contrib = x.T @ x if Q is None else x.T @ Q @ x
    u_contrib = u.T @ u if R is None else u.T @ R @ u
    
    return np.sum(x_contrib + u_contrib)

In [3]:
def buzz_noise(n, t, scale = 0.3):
    if(t < 2 * (T // 10)):
        return scale * (jnp.sin(jnp.arange(start=n*t, stop=n*(t+1))/(2*np.pi))).reshape((n, 1))
    elif(t < 4 * (T // 10)):
        return rand.normal(scale = scale, size = (n, 1))
    elif(t < 6 * (T // 10)):
        return scale * (jnp.sin(jnp.arange(start=n*t, stop=n*(t+1))/(2*np.pi))).reshape((n, 1))
    elif(t < 7 * (T // 10)):
        return rand.normal(scale = scale, size = (n, 1))
    else:
        return scale * (jnp.sin(jnp.arange(start=n*t, stop=n*(t+1))/(2*np.pi))).reshape((n, 1))

In [106]:
def lifetime(x):
  l = 16
  while x % 2 == 0:
    l *= 2
    x /= 2

  return min(T//8, l + 1)

In [4]:
class LQR(flax.nn.Module):
    @classmethod
    def init_K(cls, T, A, B, Q=None, R=None):
        n, m = B[0].shape
        K = jnp.zeros((T, m, n))
        
        for t in range(T):
            if(t % 10 == 0):
                # Get system at current time
                At, Bt = A[t], B[t]
                Qt = jnp.eye(n, dtype=jnp.float32) if Q is None else Q[t]
                Rt = jnp.eye(m, dtype=jnp.float32) if R is None else R[t]

                # solve the ricatti equation 
                Xt = dare(At, Bt, Qt, Rt)

                #compute LQR gain
                Kt = jnp.linalg.inv(Bt.T @ Xt @ Bt + Rt) @ (Bt.T @ Xt @ At)
            K = jax.ops.index_update(K, t, Kt)
        return K
            
    def apply(self, x, T, A, B, K, Q=None, R=None):
        self.t = self.state("t")
        
        if self.is_initializing():
            self.t.value = 0
        
        action = -K[self.t.value] @ x
        self.t.value += 1
        
        return action

In [5]:
T = 1000
A = jnp.array([[[1., 1.], [0., 1.]] for t in range(T)])
B = jnp.array([[[0.], [2. + jnp.sin(np.pi * t/T)]] for t in range(T)])

n, m = 2, 1
x0 = jnp.zeros((n, 1))

buzz = jnp.asarray(np.asarray([buzz_noise(n, t) for t in range(T)]))



In [6]:
init_K = LQR.init_K(T, A, B)

In [7]:
model_def = LQR.partial(T=T, A=A, B=B, K=init_K)
with flax.nn.stateful() as state:
    _, params = model_def.init_by_shape(jax.random.PRNGKey(0), [x0.shape])
lqr = flax.nn.Model(model_def, params)

In [8]:
def func(carry, inputs):
    x, state, model = carry
    a, b, z = inputs
    with flax.nn.stateful(state) as state:
        u = model(x)
        loss = quad_loss(x, u)
        x = a @ x + b @ u + z
    return (x, state, model), loss

In [9]:
(x, state, lqr), loss = jax.lax.scan(func, (x0, state, lqr), (A, B, buzz))

In [191]:
class GPC(flax.nn.Module):
    def apply(self, x, T, A, B, u=None, Q=None, R=None, K=None, start_time = 0, cost_fn = quad_loss, \
        H = 3, HH = 2, lr_scale = 0.0001, lr_scale_decay = 1.0, decay = False, include_bias = True):
        """
        Description: Initialize the dynamics of the model
        Args:
            A,B (float/numpy.ndarray): system dynamics
            H (postive int): history of the controller 
            HH (positive int): history of the system 
            K (float/numpy.ndarray): Starting policy (optional). Defaults to LQR gain.
            x (float/numpy.ndarray): initial state (optional)
        """
        n, m = B[0].shape # State & Action Dimensions

        T -= start_time
        
        self.t = self.state("t")
        self.w = self.state("w", shape=(H + HH, n, 1), initializer=flax.nn.initializers.zeros)
        self.x = self.state("x", shape=(n, 1), initializer=flax.nn.initializers.zeros)
        self.u = self.state("u", shape=(m, 1), initializer=flax.nn.initializers.zeros)
        self.M = self.state("M", (H, m, n), initializer=flax.nn.initializers.zeros)
        self.bias = self.state("bias", (m, 1), initializer=flax.nn.initializers.zeros)
        
        if self.is_initializing():
            self.t.value = 0
        
        action = -K[self.t.value] @ x
        action += jnp.tensordot(self.M.value, self.w.value[-H:], axes=([0, 2], [0, 1]))
        action += self.bias.value * include_bias
                
        # The Surrogate Cost Function
        def policy_loss(M, bias, w, t):
            y = jnp.zeros((n, 1))
            t0 = t - HH + 1
            for h in range(HH - 1):
                v = -K[t0 + h] @ y 
                v += jnp.tensordot(M, w[h : h + H], axes = ([0, 2], [0, 1])) 
                v += bias
                y = A[t0 + h] @ y + B[t0 + h] @ v + w[h + H]
            # Don't update state at the end    
            v = -K[t] @ y + jnp.tensordot(M, w[h : h + H], axes=([0, 2], [0, 1])) + bias
            return cost_fn(y, v)
        
        if not self.is_initializing():
            print(self.t.value, HH)
            # 1. Get gradients
            delta_M, delta_bias = grad(policy_loss, (0, 1))(self.M.value,
                                                            self.bias.value,
                                                            self.w.value,
                                                            self.t.value)
            # 2. Execute updates
            lr = lr_scale_decay / ( 1+ self.t.value) if decay is True else lr_scale
            
            delta_M, delta_bias = jax.lax.cond(self.t.value < HH - 1,
                                               (jnp.zeros_like(delta_M), jnp.zeros_like(delta_bias)),
                                               lambda x: x,
                                               (delta_M, delta_bias),
                                               lambda x: x)
            self.M.value -= lr * delta_M
            self.bias.value -= lr * delta_bias

            val = x - A[self.t.value] @ self.x.value - B[self.t.value] @ self.u.value
            self.w.value = jnp.vstack((self.w.value, val[None, :]))[1:]

            # 2. Update x
            self.x.value = x

            # 3. Update u
            self.u.value = -K[self.t.value] @ x
            self.u.value += jnp.tensordot(self.M.value, self.w.value[-H:], axes=([0, 2], [0, 1]))
            self.u.value += (self.bias.value * include_bias) if u is None else u

            self.t.value += 1
        return action
    
    @flax.nn.module_method
    def get_state(self, key, **kwargs):
        return self.state(key).value

In [192]:
model_def = GPC.partial(T=T, A=A, B=B, K=init_K, lr_scale=1e-4, lr_scale_decay=1e-3, H=3, HH=3, decay=True, include_bias=True)
with flax.nn.stateful() as state:
    _, params = model_def.init_by_shape(jax.random.PRNGKey(0), [x0.shape])
gpc = flax.nn.Model(model_def, params)

In [193]:
with flax.nn.stateful(state) as state:
    print(gpc.get_state("M"))

[[[0. 0.]]

 [[0. 0.]]

 [[0. 0.]]]


In [105]:
%timeit x, loss = jax.lax.scan(func, (x0, state, gpc), (A, B, buzz))

Traced<ShapedArray(int32[]):JaxprTrace(level=0/0)> 3
287 ms ± 1.59 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
tod = [i + lifetime(i + 1) for i in range(T)]
tod[0] = T // 10
tod = jnp.array(np.array(tod))

In [153]:
def AdaGPC(x,
           tod,
           model,
           args,
           t,
           weights,
           alive,
           dummy,
           dummy_weight,
           learners,
           eta=1.0,
           eps=1e-5,
           sum_weight=1.0):
    state = {}
    T, A, B, K, H, HH = args["T"], args["A"], args["B"], args["K"], args["H"], args["HH"]
    
    n, m = B[0].shape
    W = 0
    
    # TODO: rewrite as vmap
    u = jnp.zeros((B[0].shape[1], 1))
    for i in alive:
        state, learner = learners[i]
        weight = weights[i]
        with flax.nn.stateful(state) as state:            
            u += weight * learner(x)
        W += weight
    Wtotal = W + dummy_weight
    u = (u + dummy_weight * K[t] @ x) / Wtotal
    
    def policy_loss(M, bias, w, t):
        y = jnp.zeros((n, 1))
        t0 = t - HH + 1
        for h in range(HH - 1):
            v = -K[t0 + h] @ y 
            v += jnp.tensordot(M, w[h : h + H], axes = ([0, 2], [0, 1])) 
            v += bias
            y = A[t0 + h] @ y + B[t0 + h] @ v + w[h + H]
        # Don't update state at the end    
        v = -K[t] @ y + jnp.tensordot(M, w[h : h + H], axes=([0, 2], [0, 1])) + bias
        return cost_fn(y, v)
    
    loss_zero = policy_loss(dummy.get_state("M"), dummy.get_state("bias"), dummy.get_state("w"), t)
    
    return u, state
        
        

In [157]:
a = set({0, 1})

In [168]:
list(a)

[0, 1]

In [159]:
args = {
    "tod": tod,
    "model": GPC
}

gpc_args = {
    "T": T,
    "A": A,
    "B": B,
    "K": init_K,
    "lr_scale": 1e-4,
    "lr_scale_decay": 1e-3,
    "H": 3,
    "HH": 3,
    "decay": True,
    "include_bias": True
}

model_def = GPC.partial(**gpc_args)
with flax.nn.stateful() as state:
    _, params = model_def.init_by_shape(jax.random.PRNGKey(0), [x0.shape])
dummy = (state, flax.nn.Model(model_def, params))
with flax.nn.stateful() as state:
    _, params = model_def.init_by_shape(jax.random.PRNGKey(0), [x0.shape])
learner = (state, flax.nn.Model(model_def, params))

state = {
    "t": 0,
    "weights": jnp.ones(T),
    "alive": set([0]),
    "dummy": dummy,
    "dummy_weight": 1.0,
    "learners": {0: learner}
}

In [181]:




class AdaGPC(flax.nn.Module):
    def apply(self, x, T, model, args, tod, eta=1.0, eps=1e-5, sum_weight=1.0):
        A, B, K = args["A"], args["B"], args["K"]
        
        self.t = self.state("t")
        self.u = self.state("u", (B[0].shape[1], 1), flax.nn.initializers.zeros)
        self.weights = self.state("weights", (T,), flax.nn.initializers.ones)
        self.alive = self.state("alive")
        self.dummy = self.state("dummy")
        
        learners = {0: model(x, **args)}
        print(dir(learners[0]))
        learners[0].test(self.u.value)
        
        if not self.is_initializing():
            self.t.value = 0
            self.alive.value = set([0])
#             self.learners.value = {0: model(**args)}
            self.dummy.value = model(**args)
            
            action = jnp.zeros_like(self.u.value)
            W = 0
            for i in self.alive.value:
                print(i)

model_def = AdaGPC.partial(T=T, model=GPC, tod=tod, args=gpc_args)
with flax.nn.stateful() as state:
    _, params = model_def.init_by_shape(jax.random.PRNGKey(0), [x0.shape])
adagpc = flax.nn.Model(model_def, params)

['__abs__', '__add__', '__and__', '__array__', '__array_priority__', '__bool__', '__class__', '__complex__', '__copy__', '__deepcopy__', '__delattr__', '__dir__', '__div__', '__divmod__', '__doc__', '__eq__', '__float__', '__floordiv__', '__format__', '__ge__', '__getattr__', '__getattribute__', '__getitem__', '__gt__', '__hash__', '__hex__', '__init__', '__init_subclass__', '__int__', '__invert__', '__iter__', '__le__', '__len__', '__long__', '__lshift__', '__lt__', '__matmul__', '__mod__', '__module__', '__mul__', '__ne__', '__neg__', '__new__', '__nonzero__', '__oct__', '__or__', '__pos__', '__pow__', '__radd__', '__rand__', '__rdiv__', '__rdivmod__', '__reduce__', '__reduce_ex__', '__repr__', '__rfloordiv__', '__rmatmul__', '__rmod__', '__rmul__', '__ror__', '__rpow__', '__rshift__', '__rsub__', '__rtruediv__', '__rxor__', '__setattr__', '__setitem__', '__sizeof__', '__slots__', '__str__', '__sub__', '__subclasshook__', '__truediv__', '__weakref__', '__xor__', '_contents', '_trace'

AttributeError: 'ShapedArray' object has no attribute 'test'

In [139]:
x, loss = jax.lax.scan(func, (x0, state, adagpc), (A, B, buzz))

ValueError: No module named GPC_0 was created during initialization.