In [45]:
# Imports

import jax
import jax.numpy as jnp
import numpy as onp
import immrax as irx
import equinox as eqx
import equinox.nn as nn
from pathlib import Path
import optax
from functools import partial 
import matplotlib.pyplot as plt
import matplotlib.animation as animation
from matplotlib.patches import Rectangle
from immutabledict import immutabledict
import jax.experimental.compilation_cache.compilation_cache as cc
import os
from jax.scipy.linalg import block_diag
from time import time

device = 'gpu'

# if device == 'gpu' :
#     cc.initialize_cache('cache')

def jit (f, *args, **kwargs) :
    kwargs.setdefault('backend', device)
    return eqx.filter_jit(f, *args, **kwargs)


In [46]:
U_LIM = 10.

class Platoon(irx.OpenLoopSystem) :
    """A platoon of N vehicles

    The state is [x1, v1, x2, v2, ..., xN, vN]
    """

    def __init__ (self, N) :
        self.evolution = 'continuous'
        self.xlen = 2*N
        self.N = N

    def f (self, t, x, u, w) :
        xdot = jnp.zeros(2*self.N)
        xdot = xdot.at[0::2].set(x[1::2])
        xdot = xdot.at[1::2].set(U_LIM*jnp.tanh(u/U_LIM)*(1+w))
        return xdot

# f for one vehicle
def f_veh (t, x, u, w) :
    return jnp.array([x[1], U_LIM*jnp.tanh(u[0]/U_LIM)*(1+w[0])])

veh_mjacM = irx.mjacM(f_veh)

In [47]:
N = 3*9 + 1
sys = Platoon(N)

print(f'Number of vehicles: {N}')
print(f'Number of states: {sys.xlen}')

Hblk = jnp.array([
    [1., 0.],
    [0., 1.],
    [1., 1.]
])
H = jnp.kron(jnp.eye(N), Hblk)
yblk = 0.1*jnp.array([1., 1., 0.8])
# ypre = irx.icentpert(jnp.zeros(len(H)), jnp.kron(jnp.arange(N)**2+1, yblk))
_ypre = jnp.empty(N)
_ypre = _ypre.at[::3].set(1.); _ypre = _ypre.at[1::3].set(3.); _ypre = _ypre.at[2::3].set(9.)
ypre = irx.icentpert(jnp.zeros(len(H)), jnp.kron(_ypre, yblk))
# ypre = irx.icentpert(jnp.zeros(len(H)), jnp.kron(jnp.array([1., 3., 1.]), yblk))
# y = irx.icentpert(jnp.zeros(len(H)), jnp.kron(jnp.array([1., 1.5, 1.]), yblk))
# y = irx.icentpert(jnp.zeros(len(H)), jnp.kron(jnp.array([1., 3.]), yblk))
# ypre = irx.icentpert(jnp.zeros(len(H)), jnp.kron(jnp.array([1., 3., 20., 3., 1.]), yblk))
# y = irx.icentpert(jnp.zeros(len(H)), jnp.kron(jnp.ones(N), yblk))
# y = irx.icentpert(jnp.zeros(len(H)), jnp.kron(jnp.array([1., 3., 9.]), yblk))
# y = irx.icentpert(jnp.zeros(len(H)), jnp.kron(jnp.ones(N), yblk))
trainable = jnp.kron(jnp.ones(N), jnp.array([0,0,0])).astype(bool)
# print(trainable)
ywhere = jnp.where(trainable)
w = irx.icentpert(jnp.zeros(N), 0.1)

print(f'Number of lifted states: {len(H)}')

Number of vehicles: 28
Number of states: 56
Number of lifted states: 84


In [48]:
# Setup pseudoinverse, get null space, and make IH function.

Hblkdag = jnp.linalg.pinv(Hblk)
Hdag = jnp.kron(jnp.eye(N), Hblkdag)
NHblk = irx.utils.null_space(Hblk.T)
NH = jnp.kron(jnp.eye(N), NHblk)

print(jnp.all(jnp.isclose(NH.T @ H, 0., atol=1e-5)))

IH = jit(irx.utils.I_refine(NH.T))
# ypre = IH(ypre)

True


In [49]:
# Setup neural network controller

os.system(f'mkdir -p {N}-platoon')
arch = '6 32 relu 32 relu 32 relu 1'
os.system('echo ' + arch + f' > {N}-platoon/arch.txt')

class PlatoonControl (irx.Control, eqx.Module) :
    net: list
    out_len:int = eqx.field(static=True)

    def __init__ (self, key=jax.random.key(0)) :
        self.net = irx.NeuralNetwork(f'{N}-platoon', False, key=key)
        self.out_len = N
    
    def __call__ (self, x) :
        x = x.reshape((N, 2))

        # Setup relative inputs for the network for vmap
        X = jnp.zeros((N, 6))
        X = X.at[1:,2:4].set(x[:-1] - x[1:])  # relative position to the previous vehicle
        X = X.at[:-1,4:6].set(x[:-1] - x[1:])  # relative position to the next vehicle
        X = X.at[::3, :2].set(x[::3])  # absolute position for every 3-rd vehicle
        X = X.at[::3, 2:].set(0.) 
        X = X.at[3::3, 2:4].set(x[:-1:3] - x[3::3])  # relative position to the previous third vehicle
        X = X.at[:-1:3, 4:6].set(x[:-1:3] - x[3::3])  # relative position to the next third vehicle
       
        # vmap over neural net
        def apply_net (X) :
            return self.net(X)[0]
        return jax.vmap(apply_net)(X)

    def u (self, t, x) :
        return self(x)

    def save (self) :
        self.net.save()

net = PlatoonControl()
print(net(jnp.zeros(2*N)))


[-0.06237878 -0.06237878 -0.06237878 -0.06237878 -0.06237878 -0.06237878
 -0.06237878 -0.06237878 -0.06237878 -0.06237878 -0.06237878 -0.06237878
 -0.06237878 -0.06237878 -0.06237878 -0.06237878 -0.06237878 -0.06237878
 -0.06237878 -0.06237878 -0.06237878 -0.06237878 -0.06237878 -0.06237878
 -0.06237878 -0.06237878 -0.06237878 -0.06237878]


In [50]:
# Setup the lifted embedding system

# Single lifted vehicle dynamics
def g_veh (t, x, u, w, *, Hblkp) :
    return Hblk @ f_veh(t, Hblkp@x, u, w)

g_veh_mjacM = irx.jacM(g_veh)

def wrapped_g_veh_mjacM (t, x, u, w, Hblkp) :
    return g_veh_mjacM(t, x, u, w, Hblkp=Hblkp)

def build_lifted_clsys (net, eta) :
    Hp = Hdag + eta@NH.T
    lifted_sys = irx.LiftedSystem(sys, H, Hp)

    def lifted_net (y) :
        return net(Hp @ y)
    lifted_net.u = lambda t, y : lifted_net(y)

    lifted_net.out_len = net.out_len
    return irx.NNCSystem(lifted_sys, lifted_net)


def build_lifted_embsys (net, eta) :
    Hp = Hdag + eta@NH.T
    lifted_sys = irx.LiftedSystem(sys, H, Hp)

    Hblkps = jnp.array([Hblkdag + eta[i:i+2,(i,)] @ NHblk.T for i in range(0, 2*N, 2)])

    # Block diagonal mjacM for the platoon, more efficient to build block diagonally
    def lifted_mjacM (t, y, u, w, **kwargs) :
        # vmap over vehicles using g_veh_mjacM and Sblkps
        inputs = (t, y.reshape((N, 3)), u.reshape((N, 1)), w.reshape((N, 1)), Hblkps)
        # mjacMs = jax.vmap(wrapped_g_veh_mjacM, in_axes=(None, 0, 0, 0, 0))(*inputs)[0]
        mjacMs = jax.vmap(wrapped_g_veh_mjacM, in_axes=(None, 0, 0, 0, 0))(*inputs)
        Mt = jnp.zeros((3*N, 1))
        Mx = irx.natif(block_diag)(*[mjacMs[1][i] for i in range(N)])
        Mu = irx.natif(block_diag)(*[mjacMs[2][i] for i in range(N)])
        Mw = irx.natif(block_diag)(*[mjacMs[3][i] for i in range(N)])
        return [[Mt, Mx, Mu, Mw]]

    def lifted_net (y) :
        return net(Hp @ y)

    lifted_net.out_len = net.out_len
    lifted_net.u = lambda t, y : lifted_net(y)
    lifted_clsys = irx.NNCSystem(lifted_sys, lifted_net)
    lifted_embsys = irx.NNCEmbeddingSystem(lifted_clsys, 'crown', 'local', 'local', lifted_mjacM)
    return lifted_embsys

eta = jnp.zeros((2*N, len(NH.T)))
oly = ypre.upper[ywhere]

In [51]:

# Fake permutations and corners for the lifted embedding system
permutations = irx.standard_permutation(1 + len(H) + N + N)
corners = irx.bot_corner(1 + len(H) + N + N)

# Lifted embedding system evaluation
def EH (net:irx.NeuralNetwork, eta:jax.Array=eta, y=ypre) :
    lifted_embsys = build_lifted_embsys(net, eta)
    return lifted_embsys.E(irx.interval([0.]), irx.i2ut(y), w,
        permutations=permutations, corners=corners, refine=IH)

def relu_eps (x, eps) :
    return jax.nn.relu(x + eps)

# Loss from the lifted embedding system evaluation
def LH (net:irx.NeuralNetwork, eta:jax.Array=eta, y=ypre, epsl:float=0.02, epsu:float=0.02) :
    E = EH(net, eta, y)
    return jnp.sum(jax.vmap(partial(relu_eps, eps=epsl))(-E[:len(H)])) \
         + jnp.sum(jax.vmap(partial(relu_eps, eps=epsu))( E[len(H):])) 

def oly_to_y (oly) :
    return irx.interval(ypre.lower.at[trainable].set(-oly), ypre.upper.at[trainable].set(oly))

def loss(params) :
    net, eta, oly = params
    return LH(net, eta, oly_to_y(oly)) + 100.*jnp.sum(jax.vmap(partial(relu_eps, eps=0.05))(-oly)) 

In [52]:
optim = optax.adam(0.001)

# A step of the optimizer
@jit
def make_step (params, opt_state) :
    loss_value, grads = eqx.filter_value_and_grad(loss)(params)
    updates, opt_state = optim.update(grads, opt_state, params)
    params = eqx.apply_updates(params, updates)
    net, eta, oly = params
    return (params, opt_state, loss_value, EH(net, eta, oly_to_y(oly)))

params = (net, eta, oly)

t0 = time()
make_step(params, optim.init(eqx.filter(params, eqx.is_array)))
tf = time()
print(f'Time to JIT make_step: {tf-t0}')

2024-05-22 10:21:39.628425: E external/xla/xla/service/slow_operation_alarm.cc:65] 
********************************
[Compiling module jit_make_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************
2024-05-22 10:21:45.380135: E external/xla/xla/service/slow_operation_alarm.cc:133] The operation took 2m5.751888362s

********************************
[Compiling module jit_make_step] Very slow compile? If you want to file a bug, run with envvar XLA_FLAGS=--xla_dump_to=/tmp/foo and attach the results.
********************************


Time to JIT make_step: 477.7850890159607


In [53]:
def train(params, optim, steps, minsteps, print_every=1) -> irx.NeuralNetwork :
    opt_state = optim.init(eqx.filter(params, eqx.is_array))

    for step in range(steps) :
        params, opt_state, train_loss, EHval = make_step(params, opt_state)
        if (step % print_every) == 0 or (step == steps - 1) :
            net, eta, oly = params
            net.save()
            print(
                f'{step=}, train_loss={train_loss.item()}, '
                f'\nEHl={EHval[:len(H)]}, \nEHu={EHval[len(H):]}'
                # f'\ny={oly_to_y(oly)}'
                # f'\neta={eta} \n'
            )
        if (jnp.all(EHval[:len(H)] >= 0) 
            and jnp.all(EHval[len(H):] <= 0)
            and step >= minsteps) :
            print('EH constraints satisfied, stopping training')
            print(
                f'{step=}, train_loss={train_loss.item()}, '
                f'\nEHl={EHval[:len(H)]}, \nEHu={EHval[len(H):]}'
                f'\neta={eta} \n'
                f'\ny={oly_to_y(oly)}'
            )
            return params
    
    return params

t0 = time()
net, eta, oly = train((net, eta, oly), optim, 1000000, 100, 100)
tf = time()
net.save()
print(f'Finished training in {tf - t0} seconds.')
y = oly_to_y(oly)
onp.save(f'{N}-platoon/H.npy', H)
onp.save(f'{N}-platoon/ylyu.npy', irx.i2lu(y))

Saving model to 28-platoon/model.eqx... done.
step=0, train_loss=43.341426849365234, 
EHl=[ 0.01944574 -0.05904369 -0.16128539  0.0606235  -0.23480365 -0.57062554
  0.18016829 -0.08125772 -1.0651081   0.01909932 -0.06552102 -0.16911492
  0.06025036 -0.2354295  -0.57128537  0.1801683  -0.08127012 -1.0651878
  0.01903004 -0.06567524 -0.16936047  0.06025036 -0.23546375 -0.57131946
  0.18016827 -0.08132209 -1.0652773   0.01900046 -0.06568238 -0.16943778
  0.06025036 -0.235473   -0.571329    0.18016827 -0.08132208 -1.0652771
  0.01900047 -0.06568645 -0.16944396  0.06025036 -0.23540728 -0.5712636
  0.18034627 -0.08129533 -1.065523    0.01903005 -0.0656826  -0.16943896
  0.0606235  -0.23520131 -0.57081586  0.18034628 -0.08125319 -1.0654159
  0.01903005 -0.06568253 -0.16943838  0.06062349 -0.23520131 -0.5708159
  0.1803463  -0.08125319 -1.0654159   0.01903004 -0.06566624 -0.1693489
  0.0606235  -0.23520131 -0.5708159   0.1803463  -0.08124074 -1.0653808
  0.01909933 -0.06554879 -0.16912588  0.0

In [54]:
%matplotlib widget
if N == 4 :
    from pypoman import compute_polytope_vertices, plot_polygon

    plt.rcParams.update({
        "text.usetex": True,
        "font.family": "Helvetica",
        "font.size": 14
    })

    fig = plt.figure(figsize=(4,4), dpi=100)
    ax = fig.add_subplot(111)

    ll = [1., 3., 9, 1.]
    cc = ['tab:red', 'tab:blue', 'tab:green', 'tab:orange']

    vv = []

    for i, l in enumerate(ll) :
        vv.append(compute_polytope_vertices(
            onp.vstack((-Hblk, Hblk)), onp.concatenate((-y.lower[i*3:i*3+3], y.upper[i*3:i*3+3]))))
        plot_polygon(vv[-1], alpha=1., fill=False, linewidth=2., color=cc[i])
    # ax.plot(xx[:,i*2], xx[:,i*2+1], label=f'Vehicle {i+1}')

    # Get a random vertex from vv
    vert = [0,-3,5,3]
    print(vv)
    x0 = jnp.vstack([vv[i][vert[i]] for i in range(4)]).flatten()
    clsys = irx.ControlledSystem(sys, net)
    w_map = lambda t,x : jnp.zeros(1)
    traj = clsys.compute_trajectory(0., 10., x0, (w_map,), dt=0.01)
    tfinite = jnp.where(jnp.isfinite(traj.ts))
    xx = traj.ys[tfinite]
    for i in range(N) :
        ax.plot(xx[:,i*2], xx[:,i*2+1], label=f'Vehicle {i+1}', color=cc[i])
    ax.legend()

    fig.tight_layout()
    fig.savefig('4-platoon.pdf')
    plt.show()