https://rexlab.ri.cmu.edu/papers/iLQR_Tutorial.pdf

https://bjack205.github.io/papers/AL_iLQR_Tutorial.pdf

In [None]:
import jax.numpy as jnp
from jax import grad, jit, vmap
from jax import random
from jax import jacfwd, jacrev
import jax

In [None]:
class ILQR():
  """Iterative LQR (iLQR) trajectory optimizer driven by JAX autodiff.

  The trajectory holds ``steps`` knot points ``x[0] .. x[steps-1]``. Stage
  ``k`` (for ``k < steps - 1``) applies ``u[k]`` at ``x[k]`` to reach
  ``x[k+1]``. The last knot point is the terminal state: it only seeds the
  cost-to-go, so ``u[steps-1]`` is never applied and keeps its initial value.

  Args:
    cost: callable ``cost(x, u)`` returning the stage cost; differentiated
      twice with ``jax.jacfwd``.
    dynamics: callable ``dynamics(x, u)`` returning the next state.
    x: nominal states indexed as ``x[step]``; ``x.shape[1]`` is the state
      dimension and each ``x[step]`` is used as a column vector.
    u: nominal controls indexed as ``u[step]``; ``u.shape[1]`` is the control
      dimension and each ``u[step]`` is used as a column vector.
    xgoal: goal state used to seed the terminal cost-to-go gradient.
    steps: number of knot points in the trajectory.
  """

  def __init__(self, cost, dynamics, x, u, xgoal, steps):
    self.x = x
    self.u = u
    self.xgoal = xgoal
    self.cost = cost
    self.dynamics = dynamics
    # Action-value expansion terms for the stage currently being processed.
    self.Qu = None
    self.Quu = None
    self.Qux = None
    # Dynamics Jacobians for the stage currently being processed.
    self.A = None
    self.B = None
    # Cost-to-go gradient (s) and Hessian (S) carried through the backward pass.
    self.s = None
    self.S = None
    # Feedforward terms d[k] and feedback gains K[k], one slot per step.
    self.d = jnp.zeros((steps, self.u.shape[1], 1))
    self.K = jnp.zeros((steps, self.u.shape[1], self.x.shape[1]))
    # Per-step [linear, quadratic] coefficients of the expected cost change.
    self.deltaV = jnp.zeros((steps, 1, 2))
    self.steps = steps
    # Regularization added to the diagonal of Quu before it is inverted.
    self.p = 1e-5

  def getlx(self, step):
    """Return the cost gradient w.r.t. the state at ``step`` as an (n, 1) column."""
    return jacfwd(lambda x: self.cost(x, self.u[step]))(self.x[step]).reshape(self.x.shape[1], 1)

  def getlu(self, step):
    """Return the cost gradient w.r.t. the control at ``step`` as an (m, 1) column."""
    return jacfwd(lambda u: self.cost(self.x[step], u))(self.u[step]).reshape(self.u.shape[1], 1)

  def getlxx(self, step):
    """Return the (n, n) cost Hessian w.r.t. the state at ``step``."""
    return jacfwd(jacfwd(lambda x: self.cost(x, self.u[step])))(self.x[step]).reshape(self.x.shape[1], self.x.shape[1])

  def getluu(self, step):
    """Return the (m, m) cost Hessian w.r.t. the control at ``step``."""
    return jacfwd(jacfwd(lambda u: self.cost(self.x[step], u)))(self.u[step]).reshape(self.u.shape[1], self.u.shape[1])

  def getlux(self, step):
    """Return the (m, n) mixed control/state second derivative of the cost at ``step``."""
    return jacfwd(lambda x: jacfwd(lambda u: self.cost(x, u))(self.u[step]))(self.x[step]).reshape(self.u.shape[1], self.x.shape[1])

  def getA(self, step):
    """Store the (n, n) dynamics Jacobian w.r.t. the state at ``step`` in ``self.A``."""
    self.A = jacfwd(lambda x: self.dynamics(x, self.u[step]))(self.x[step]).reshape(self.x.shape[1], self.x.shape[1])

  def getB(self, step):
    """Store the (n, m) dynamics Jacobian w.r.t. the control at ``step`` in ``self.B``."""
    self.B = jacfwd(lambda u: self.dynamics(self.x[step], u))(self.u[step]).reshape(self.x.shape[1], self.u.shape[1])

  def getQx(self, step):
    """Return ``lx + A^T s``; not cached because it is used only once per stage."""
    return self.getlx(step) + self.A.T@self.s

  def getQu(self, step):
    """Compute, cache and return ``Qu = lu + B^T s``."""
    self.Qu = self.getlu(step) + self.B.T@self.s
    return self.Qu

  def getQxx(self, step):
    """Return ``lxx + A^T S A``; not cached because it is used only once per stage."""
    return self.getlxx(step) + self.A.T@self.S@self.A

  def getQuu(self, step):
    """Compute, cache and return ``Quu = luu + B^T S B``."""
    self.Quu = self.getluu(step) + self.B.T@self.S@self.B
    return self.Quu

  def getQux(self, step):
    """Compute, cache and return ``Qux = lux + B^T S A``."""
    self.Qux = self.getlux(step) + self.B.T@self.S@self.A
    return self.Qux

  def getd(self, step):
    """Store and return the feedforward term ``d[step] = -(Quu + p I)^-1 Qu``.

    Reads the cached ``self.Quu``, so ``getQuu``/``getK`` must have been
    called for this ``step`` first.
    """
    self.d = self.d.at[step, :, :].set(-jnp.linalg.inv(self.Quu + self.p*jnp.eye(self.u.shape[1]))@self.getQu(step))
    return self.d[step]

  def getK(self, step):
    """Store and return the feedback gain ``K[step] = -(Quu + p I)^-1 Qux``.

    Refreshes the cached ``self.Quu`` and ``self.Qux`` as a side effect.
    """
    self.K = self.K.at[step, :, :].set(-jnp.linalg.inv(self.getQuu(step) + self.p*jnp.eye(self.u.shape[1]))@self.getQux(step))
    return self.K[step]

  def getDeltaV(self, step):
    """Store ``[d^T Qu, 1/2 d^T Quu d]`` for ``step`` in ``self.deltaV``.

    Uses the cached ``self.Qu``/``self.Quu``, so it must follow ``gets(step)``.
    """
    self.deltaV = self.deltaV.at[step, :, :].set(jnp.array([[self.d[step].T@self.Qu, 1/2*self.d[step].T@self.Quu@self.d[step]]]).reshape(1,2))

  def getSumDeltaV(self, a):
    """Return the summed expected cost change for step size ``a``.

    Each step contributes ``a * deltaV[k, 0, 0] + a**2 * deltaV[k, 0, 1]``;
    the result has shape (1, 1).
    """
    aArray = jnp.array([[a, a**2]]).T
    # Sum the per-step coefficients first, then apply the step-size weights once.
    return jnp.sum(self.deltaV, axis=0)@aArray

  def getTotalCost(self):
    """Return the sum of ``cost(x[i], u[i])`` over all ``steps`` knot points."""
    return sum(self.cost(self.x[i], self.u[i]) for i in range(self.steps))

  def initializes(self):
    """Seed the cost-to-go gradient ``s`` at the terminal knot point.

    Uses ``lxx @ (x_terminal - xgoal)``, i.e. it treats the cost as
    quadratic about ``xgoal``.
    """
    self.s = self.getlxx(self.steps - 1)@(self.x[self.steps - 1]-self.xgoal)

  def initializeS(self):
    """Seed the cost-to-go Hessian ``S`` with the terminal state Hessian of the cost."""
    self.S = self.getlxx(self.steps - 1)

  def gets(self, step):
    """Linearize stage ``step``, compute ``K``/``d`` and update ``self.s``.

    The right-hand side is evaluated left to right while ``self.s`` and
    ``self.S`` still hold the values of the following stage; ``getK`` must
    run before ``getd`` because ``getd`` reads the cached ``Quu``.
    """
    self.getA(step)
    self.getB(step)

    self.s = self.getQx(step) + self.getK(step).T@self.Quu@self.getd(step) + self.K[step].T@self.Qu + self.Qux.T@self.d[step]

  def getS(self, step):
    """Update ``self.S`` for stage ``step``; must be called right after ``gets(step)``."""
    self.S = self.getQxx(step) + self.K[step].T@self.Quu@self.K[step] + self.K[step].T@self.Qux + self.Qux.T@self.K[step]

  def rollout(self):
    """Simulate the dynamics forward, updating ``self.u`` and ``self.x`` in place.

    Applies ``u[i] += K[i] (x - x_nominal[i]) + d[i]`` with a full step
    (no line search). Only ``steps - 1`` transitions are taken, so every
    stored state ``x[i+1]`` equals ``dynamics(x[i], u[i])``.
    """
    x = self.x[0]
    last = self.x.shape[0] - 1
    for i in range(last):
      # self.x[i] still holds the previous nominal state when it is read here.
      self.u = self.u.at[i, :, :].set(self.u[i] + self.K[i]@(x - self.x[i]) + self.d[i])
      self.x = self.x.at[i, :, :].set(x)
      x = self.dynamics(x, self.u[i])
    # The terminal knot point is the result of the final transition.
    self.x = self.x.at[last, :, :].set(x)

  def run(self):
    """Perform one iLQR iteration: a backward pass followed by a rollout."""
    self.initializes()
    self.initializeS()
    # Knot point steps-1 only seeds s and S, so the recursion covers stages
    # steps-2 .. 0. Indexing at step == steps would be out of bounds; JAX
    # clamps such reads and drops such writes instead of raising, which
    # would silently insert a phantom stage into the recursion.
    # NOTE: getDeltaV is not invoked here; call it per stage if the expected
    # cost change is needed.
    for i in range(self.steps - 2, -1, -1):
      self.gets(i)
      self.getS(i)
    self.rollout()





# Discrete time step shared by the matrices below and by dynamics().
dt = 0.05

# State-transition matrix: components 0-1 accumulate dt times components 2-3,
# while components 2-3 are scaled by (1 - dt) every step.
A = jnp.array(
    [
        [1, 0, dt, 0],
        [0, 1, 0, dt],
        [0, 0, (1 - dt), 0],
        [0, 0, 0, (1 - dt)],
    ],
    dtype="float32",
)

# Input matrix: each control drives one of components 2-3 with gain 200 * dt.
B = jnp.array(
    [
        [0, 0],
        [0, 0],
        [200 * dt, 0],
        [0, 200 * dt],
    ],
    dtype="float32",
)

def dynamics(x, u):
  """Return the next state ``A @ x + B @ u`` of the linear system.

  Both matrices are rebuilt from the module-level ``dt`` on every call.
  ``x`` is multiplied by a 4x4 matrix and ``u`` by a 4x2 matrix.
  """
  decay = 1 - dt
  gain = 200*dt

  transition = jnp.array(
      [[1, 0, dt, 0],
       [0, 1, 0, dt],
       [0, 0, decay, 0],
       [0, 0, 0, decay]],
      dtype="float32")
  control = jnp.array(
      [[0, 0],
       [0, 0],
       [gain, 0],
       [0, gain]],
      dtype="float32")

  free_response = transition@x
  forced_response = control@u
  return free_response + forced_response

# Number of knot points in the trajectory.
steps = 100

# Nominal controls start at zero.
u = jnp.zeros((steps, 2, 1), dtype="float32")

# Nominal states start at zero except for the first two components of the
# initial state, which are set to 5.
x = jnp.zeros((steps, 4, 1), dtype="float32").at[0, 0:2, 0].set(5)

# Goal state as a 4x1 column at the origin.
xgoal = jnp.array([[0], [0], [0], [0]])

# Control weight: 0.01 on the diagonal, sized from the control dimension.
R = .01 * jnp.eye(u.shape[1])

# State weight: 100 on the diagonal, sized from the state dimension.
Q = 100 * jnp.eye(x.shape[1])

def cost(x, u):
  """Return the quadratic stage cost ``x^T Q x + u^T R u``.

  ``Q`` and ``R`` are the module-level weight matrices.
  """
  state_penalty = (x.T@Q)@x
  control_penalty = (u.T@R)@u
  return state_penalty + control_penalty

# Build the solver from the cost, dynamics, nominal trajectory (x, u), goal
# state and number of steps defined above.
ilqr = ILQR(cost, dynamics, x, u, xgoal, steps)

In [None]:
# One solver iteration: seed the cost-to-go terms, sweep the backward
# recursion, then roll the trajectory forward (updates ilqr.x and ilqr.u).
ilqr.run()

In [None]:
# Display the state trajectory stored on the solver after the iteration.
print(ilqr.x)

[[[ 5.00000000e+00]
  [ 5.00000000e+00]
  [ 0.00000000e+00]
  [ 0.00000000e+00]]

 [[ 5.00000000e+00]
  [ 5.00000000e+00]
  [-4.87612534e+00]
  [-4.87612534e+00]]

 [[ 4.75619364e+00]
  [ 4.75619364e+00]
  [-4.63831997e+00]
  [-4.63831997e+00]]

 [[ 4.52427769e+00]
  [ 4.52427769e+00]
  [-4.41210604e+00]
  [-4.41210604e+00]]

 [[ 4.30367231e+00]
  [ 4.30367231e+00]
  [-4.19692230e+00]
  [-4.19692230e+00]]

 [[ 4.09382629e+00]
  [ 4.09382629e+00]
  [-3.99223042e+00]
  [-3.99223042e+00]]

 [[ 3.89421487e+00]
  [ 3.89421487e+00]
  [-3.79751992e+00]
  [-3.79751992e+00]]

 [[ 3.70433879e+00]
  [ 3.70433879e+00]
  [-3.61230302e+00]
  [-3.61230302e+00]]

 [[ 3.52372360e+00]
  [ 3.52372360e+00]
  [-3.43611670e+00]
  [-3.43611670e+00]]

 [[ 3.35191774e+00]
  [ 3.35191774e+00]
  [-3.26852059e+00]
  [-3.26852059e+00]]

 [[ 3.18849182e+00]
  [ 3.18849182e+00]
  [-3.10909653e+00]
  [-3.10909653e+00]]

 [[ 3.03303695e+00]
  [ 3.03303695e+00]
  [-2.95744467e+00]
  [-2.95744467e+00]]

 [[ 2.88516474e+