In [4]:
import jax
ctx = jax.default_device(jax.devices("cpu")[0])
ctx.__enter__()

In [29]:
import jax.numpy as jp
from jax.tree_util import tree_map
import numpy as np

def lqr_continuous_time_infinite_horizon(A, B, Q, R, N = None):
  # Take the last dimension, in case we try to do some kind of broadcasting
  # thing in the future.
  x_dim = A.shape[-1]
#   Q.at[1,0].set(Q[0,1])
#   Q[1,0]=Q[0,1]
  Q  = Q.at[1,0].set(2*(Q[1,0]))
  Q  = Q.at[0,1].set(2*(Q[0,1]))
  Q = (Q.T+Q)/2

  # See https://en.wikipedia.org/wiki/Linear%E2%80%93quadratic_regulator#Infinite-horizon,_continuous-time_LQR.
  A1 = A - B @ jp.linalg.solve(R, N.T)
  Q1 = Q - N @ jp.linalg.solve(R, N.T)

  H = jp.block([[A1, -B @ jnp.linalg.inv(R)@B.T], [-Q1, -A1.T]])
  eigvals, eigvectors = eig(H)
  argsort = jp.argsort(eigvals)
  ix = argsort[:x_dim]
  U = eigvectors[:, ix]
  P = U[x_dim:, :] @ jp.linalg.inv(U[:x_dim, :])
 
  P = jp.real(P)
  K = jp.linalg.inv(R)@B.T @ P
  
  return K
  

# def _test_lqr(n):
#   import control
#   from jax.tree_util import tree_map

#   A = jp.eye(n)
#   B = jp.eye(n)
#   Q = jp.eye(n)
#   R = jp.eye(n)
#   N = jp.zeros((n, n))

#   actual = lqr_continuous_time_infinite_horizon(A, B, Q, R, N)
#   expected = control.lqr(A, B, Q, R, N)
#   print(tree_map(jp.allclose, actual, expected))
#   assert tree_map(jp.allclose, actual, expected)

# if __name__ == "__main__":
#   _test_lqr(10)

In [6]:
import jax.numpy as jnp
import jax.lax.linalg as lax_linalg
from jax import custom_jvp
from functools import partial

from jax import lax
from jax.numpy.linalg import solve
@custom_jvp
def eig(a):
    w, vl, vr = jax.numpy.linalg.eig(a)
    return w, vr


@eig.defjvp
def eig_jvp_rule(primals, tangents):
    a, = primals
    da, = tangents

    w, v = eig(a)

    eye = jnp.eye(a.shape[-1], dtype=a.dtype)
    # carefully build reciprocal delta-eigenvalue matrix, avoiding NaNs.
    Fmat = (jnp.reciprocal(eye + w[..., jnp.newaxis, :] - w[..., jnp.newaxis])
            - eye)
    dot = partial(lax.dot if a.ndim == 2 else lax.batch_matmul,
                  precision=lax.Precision.HIGHEST)
    vinv_da_v = dot(solve(v, da), v)
    du = dot(v, jnp.multiply(Fmat, vinv_da_v))
    corrections = (jnp.conj(v) * du).sum(-2, keepdims=True)
    dv = du - v * corrections
    dw = jnp.diagonal(vinv_da_v, axis1=-2, axis2=-1)
    return (w, v), (dw, dv)

In [32]:
# Define system matrices
# A = jp.array([[0.0, 1.0], [-1.0, -1.0]])
# B = jp.array([[0.0], [1.0]])
# Q = jp.array([[1.0, 0.0], [0.0, 1.0]])
# R = jp.array([[1.0]])
# N = jp.zeros((2,1))
# n=2
# A = jp.eye(n)
# B = jp.eye(n)
# Q = jp.eye(n)
# R = jp.eye(n)
# N = jp.zeros((n, n))

# A = jnp.array([[0, 1], [0, 0]])
# B = jnp.array([[0], [1]])
# Q = jnp.eye(2)
# R = jnp.eye(1)
# N = jp.zeros((2,1))

Av = jnp.array([[0.0000,-0.0001,0.0002],[0.0030,-0.2047,-0.0202],[0.0012,-0.0002,-0.0721]])
Bv = jnp.eye(3)
Q = jnp.diag(np.array([0.01, 0.01, 1000.0, 0.0, 0.0, 0.0]))
R = jnp.diag(np.array([1.0, 1.0, 1.0]))
Cv = jnp.eye(3)
Dv = jnp.zeros((3, 3))
A = jnp.vstack(
    (jnp.hstack((jnp.zeros((3, 3)), Cv)), jnp.hstack((jnp.zeros((3, 3)), Av)))
)
B = jnp.vstack((Dv, Bv))
# P = lqr_continuous_time_infinite_horizon(A, B, Q, R)
# print("LQR Solution (P):")
# print(P)
# analytical_grad_P = jax.jacobian(lambda R: lqr_continuous_time_infinite_horizon(A, B, Q, R))(R)
# analytical_grad_Q = jax.jacobian(lambda Q: lqr_continuous_time_infinite_horizon(A, B, Q, R))(Q)
from control.matlab import *
Ki, P, CLP = lqr(A, B, Q, R)

print("LQR Solution of python lqr library (P):")
print(Ki)
# print(analytical_grad_P.reshape((-1)))


print("---------------------dP/dR-----------------------")
print("------(1) Finite Difference Approximation---------")
A = jnp.array([[0, 1], [0, 0]])
B = jnp.array([[0], [1]])
Q = jnp.eye(2)
R = jnp.eye(1) + 1e-4
N = jp.zeros((2,1))
P_1 = lqr_continuous_time_infinite_horizon(A, B, Q, R)
print(((P_1-P)/(1e-4)).reshape((-1)))
print("----------(2) JAX Auto Differentiation------------")
print(analytical_grad_P.reshape((-1)))
print("---------------------dP/dQ------------------------")
print("------(1) Finite Difference Approximation---------")
A = jnp.array([[0, 1], [0, 0]])
B = jnp.array([[0], [1]])
Q = jnp.array([[1+ 1e-4,0], [0, 1]])
R = jnp.eye(1)
N = jp.zeros((2,1))
P_1 = lqr(A, B, Q, R)[1]
print(((P_1-P)/(1e-4)).reshape((-1)))

A = jnp.array([[0, 1], [0, 0]])
B = jnp.array([[0], [1]])
Q = jnp.array([[1,1e-4], [1e-4, 1]])
R = jnp.eye(1)
N = jp.zeros((2,1))
P_1 = lqr(A, B, Q, R)[1]
print(((P_1-P)/(1e-4)).reshape((-1)))


A = jnp.array([[0, 1], [0, 0]])
B = jnp.array([[0], [1]])
Q = jnp.array([[1,1e-4], [1e-4, 1]])
R = jnp.eye(1)
N = jp.zeros((2,1))
P_1 = lqr(A, B, Q, R)[1]
print(((P_1-P)/(1e-4)).reshape((-1)))

A = jnp.array([[0, 1], [0, 0]])
B = jnp.array([[0], [1]])
Q = jnp.array([[1,0], [0, 1+1e-4]])
R = jnp.eye(1)
N = jp.zeros((2,1))
P_1 = lqr(A, B, Q, R)[1]
print(((P_1-P)/(1e-4)).reshape((-1)))

print("----------(2) JAX Auto Differentiation------------")
# print(analytical_grad_Q.reshape((4,4)).T)
print(analytical_grad_Q[:,:, 0, 0].reshape((-1)))
print(analytical_grad_Q[:,:, 0, 1].reshape((-1)))
print(analytical_grad_Q[:,:, 1, 0].reshape((-1)))
print(analytical_grad_Q[:,:, 1, 1].reshape((-1)))


LQR Solution of python lqr library (P):
[[ 9.99994540e-02  3.30118044e-04  5.16472351e-04  4.47218972e-01
   8.69427500e-04  1.19560881e-03]
 [-3.30118323e-04  9.99994525e-02  5.48210595e-03  8.69427500e-04
   2.87133415e-01 -2.26469153e-04]
 [-1.57599102e-06 -1.73412383e-05  3.16227761e+01  1.19560881e-03
  -2.26469153e-04  7.88093457e+00]]
---------------------dP/dR-----------------------
------(1) Finite Difference Approximation---------


AttributeError: 'NoneType' object has no attribute 'T'

In [266]:
"""Defines several utility functions.

Copyright (c) Meta Platforms, Inc. and affiliates.
"""

from typing import Tuple

import jax
import jax.numpy as jnp
import numpy as onp
from jax.experimental import host_callback

EPS_EIG = 1e-6


def diag(x: jnp.ndarray) -> jnp.ndarray:
    """A batch-compatible version of `numpy.diag`."""
    shape = x.shape + (x.shape[-1],)
    y = jnp.zeros(shape, x.dtype)
    i = jnp.arange(x.shape[-1])
    return y.at[..., i, i].set(x)


def angular_frequency_for_wavelength(wavelength: jnp.ndarray) -> jnp.ndarray:
    """Returns the angular frequency for the specified wavelength."""
    return 2 * jnp.pi / wavelength  # Since by our convention c == 1.


def matrix_adjoint(x: jnp.ndarray) -> jnp.ndarray:
    """Computes the adjoint for a batch of matrices."""
    axes = tuple(range(x.ndim - 2)) + (x.ndim - 1, x.ndim - 2)
    return jnp.conj(jnp.transpose(x, axes=axes))


def batch_compatible_shapes(*shapes: Tuple[int, ...]) -> bool:
    """Returns `True` if all the shapes are batch-compatible."""
    max_dims = max([len(s) for s in shapes])
    shapes = tuple([(1,) * (max_dims - len(s)) + s for s in shapes])
    max_shape = [max(dim_shapes) for dim_shapes in zip(*shapes)]
    for shape in shapes:
        if any([a not in (1, b) for a, b in zip(shape, max_shape)]):
            return False
    return True


def atleast_nd(x: jnp.ndarray, n: int) -> jnp.ndarray:
    """Adds leading dimensions to `x`, ensuring that it is at least n-dimensional."""
    dims_to_add = tuple(range(max(0, n - x.ndim)))
    return jnp.expand_dims(x, axis=dims_to_add)


def absolute_axes(axes: Tuple[int, ...], ndim: int) -> Tuple[int, ...]:
    """Returns the absolute axes for given relative axes and number of array dimensions."""
    if not all(a in list(range(-ndim, ndim)) for a in axes):
        raise ValueError(
            f"All elements of `axes` must be in the range ({ndim}, {ndim - 1}) "
            f"but got {axes}."
        )
    absolute_axes = tuple([d % ndim for d in axes])
    if len(absolute_axes) != len(set(absolute_axes)):
        raise ValueError(
            f"Found duplicates in `axes`; computed absolute axes are {absolute_axes}."
        )
    return absolute_axes


def interpolate_permittivity(
    permittivity_solid: jnp.ndarray,
    permittivity_void: jnp.ndarray,
    density: jnp.ndarray,
) -> jnp.ndarray:
    """Interpolates the permittivity with a scheme that avoids zero crossings.

    The interpolation uses the scheme introduced in [2019 Christiansen], which avoids
    zero crossings that can occur with metals or lossy materials having a negative
    real component of the permittivity. https://doi.org/10.1016/j.cma.2018.08.034

    Args:
        permittivity_solid: The permittivity of solid regions.
        permittivity_void: The permittivity of void regions.
        density: The density, specifying which locations are solid and which are void.

    Returns:
        The interpolated permittivity.
    """
    n_solid = jnp.real(jnp.sqrt(permittivity_solid))
    k_solid = jnp.imag(jnp.sqrt(permittivity_solid))
    n_void = jnp.real(jnp.sqrt(permittivity_void))
    k_void = jnp.imag(jnp.sqrt(permittivity_void))
    n = density * n_solid + (1 - density) * n_void
    k = density * k_solid + (1 - density) * k_void
    return (n + 1j * k) ** 2


# -----------------------------------------------------------------------------
# Functions related to a generalized eigensolve with custom vjp rule.
# -----------------------------------------------------------------------------


@jax.custom_vjp
def eig(matrix: jnp.ndarray, eps: float = EPS_EIG) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Wraps `jnp.linalg.eig` in a jit-compatible, differentiable manner.

    The custom vjp allows gradients with resepct to the eigenvectors, unlike the
    standard jax implementation of `eig`. We use an expression for the gradient
    given in [2019 Boeddeker] along with a regularization scheme used in [2021
    Colburn]. The method effectively applies a Lorentzian broadening to a term
    containing the inverse difference of eigenvalues.

    [2019 Boeddeker] https://arxiv.org/abs/1701.00392
    [2021 Coluburn] https://www.nature.com/articles/s42005-021-00568-6

    Args:
        matrix: The matrix for which eigenvalues and eigenvectors are sought.
        eps: Parameter which determines the degree of broadening.

    Returns:
        The eigenvalues and eigenvectors.
    """
    del eps
    return _eig_host(matrix)


def _eig_host(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
    """Wraps jnp.linalg.eig so that it can be jit-ed on a machine with GPUs."""
    eigenvalues_shape = jax.ShapeDtypeStruct(matrix.shape[:-1], complex)
    eigenvectors_shape = jax.ShapeDtypeStruct(matrix.shape, complex)

    def _eig_cpu(matrix: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]:
        # We force this computation to be performed on the cpu by jit-ing and
        # explicitly specifying the device.
        with jax.default_device(jax.devices("cpu")[0]):
            return jax.jit(jnp.linalg.eig)(matrix)

    return host_callback.call(
        _eig_cpu,
        matrix.astype(complex),
        result_shape=(eigenvalues_shape, eigenvectors_shape),
    )


def _eig_fwd(
    matrix: jnp.ndarray,
    eps: float,
) -> Tuple[Tuple[jnp.ndarray, jnp.ndarray], Tuple[jnp.ndarray, jnp.ndarray, float]]:
    """Implements the forward calculation for `eig`."""
    eigenvalues, eigenvectors = _eig_host(matrix)
    return (eigenvalues, eigenvectors), (eigenvalues, eigenvectors, eps)


def _eig_bwd(
    res: Tuple[jnp.ndarray, jnp.ndarray, float],
    grads: Tuple[jnp.ndarray, jnp.ndarray],
) -> Tuple[jnp.ndarray, None]:
    """Implements the backward calculation for `eig`."""
    eigenvalues, eigenvectors, eps = res
    grad_eigenvalues, grad_eigenvectors = grads

    # Compute the F-matrix, from equation 5 of [2021 Colburn]. This applies a
    # Lorentzian broadening to the matrix `f = 1 / (eigenvalues[i] - eigenvalues[j])`.
    eigenvalues_i = eigenvalues[..., jnp.newaxis, :]
    eigenvalues_j = eigenvalues[..., :, jnp.newaxis]
    f_broadened = (eigenvalues_i - eigenvalues_j) / (
        (eigenvalues_i - eigenvalues_j) ** 2 + eps
    )

    # Manually set the diagonal elements to zero, as we do not use broadening here.
    i = jnp.arange(f_broadened.shape[-1])
    f_broadened = f_broadened.at[..., i, i].set(0)

    # By jax convention, gradients are with respect to the complex parameters, not with
    # respect to their conjugates. Take the conjugates.
    grad_eigenvalues_conj = jnp.conj(grad_eigenvalues)
    grad_eigenvectors_conj = jnp.conj(grad_eigenvectors)

    eigenvectors_H = matrix_adjoint(eigenvectors)
    dim = eigenvalues.shape[-1]
    eye_mask = jnp.eye(dim, dtype=bool)
    eye_mask = eye_mask.reshape((1,) * (eigenvalues.ndim - 1) + (dim, dim))

    # Then, the gradient is found by equation 4.77 of [2019 Boeddeker].
    rhs = (
        diag(grad_eigenvalues_conj)
        + jnp.conj(f_broadened) * (eigenvectors_H @ grad_eigenvectors_conj)
        - jnp.conj(f_broadened)
        * (eigenvectors_H @ eigenvectors)
        @ jnp.where(eye_mask, jnp.real(eigenvectors_H @ grad_eigenvectors_conj), 0.0)
    ) @ eigenvectors_H
    grad_matrix = jnp.linalg.solve(eigenvectors_H, rhs)

    # Take the conjugate of the gradient, reverting to the jax convention
    # where gradients are with respect to complex parameters.
    grad_matrix = jnp.conj(grad_matrix)

    # Return `grad_matrix`, and `None` for the gradient with respect to `eps`.
    return grad_matrix, None


eig.defvjp(_eig_fwd, _eig_bwd)

In [175]:
import scipy
import scipy.linalg
lqr(A, B, Q, R)[1]

array([[1.37841423, 0.41421356],
       [0.41421356, 0.68179283]])

In [229]:
import jax
import jax.numpy as jnp
# from scipy.linalg import solve_continuous_are
import numpy as np
from jax.scipy.linalg import schur, inv
def solve_continuous_are(a, b, q, r):
    # Equivalent of np.dot(b, np.dot(g, b.conj().T))
    g = inv(r)
    g = jnp.dot(jnp.dot(b, g), b.conj().T)

    # Construct the blocks of the Hamiltonian matrix
    z11 = a
    z12 = -1.0 * g
    z21 = -1.0 * q
    z22 = -1.0 * a.conj().T

    # Combine the blocks to form the full Hamiltonian matrix
    z = jnp.vstack((jnp.hstack((z11, z12)), jnp.hstack((z21, z22))))

    # Compute the Schur decomposition
    # Note: JAX does not support sorting eigenvalues, so this is an approximation.
    # s, u = schur(z,)
    s,u,_ =scipy.linalg.schur(z, sort='lhp')

    # s, u = schur(z,)

    # Extract the top-left and bottom-left blocks of U
    (m, n) = u.shape
    u11 = u[0:m//2, 0:n//2]
    u21 = u[m//2:m, 0:n//2]

    # Compute the inverse of U11
    u11i = inv(u11)

    # Return the product of U21 and U11 inverse
    return jnp.dot(u21, u11i)

@jax.custom_vjp
def lqr_solution(A, B, Q, R):
    """
    Returns the LQR solution P (Riccati matrix) and sets up implicit differentiation.
    
    Args:
        A (jax.numpy.ndarray): State transition matrix.
        B (jax.numpy.ndarray): Control input matrix.
        Q (jax.numpy.ndarray): State cost matrix.
        R (jax.numpy.ndarray): Control cost matrix.
    
    Returns:
        P (jax.numpy.ndarray): Solution to the CARE, matrix P.
    """
    P = solve_continuous_are(A, B, Q, R)

    # Attach the custom VJP
    return P

def care_residual(P, A, B, Q, R):
    """CARE residual function, F(P; A, B, Q, R) = 0."""
    return A.T @ P + P @ A - P @ B @ jnp.linalg.inv(R) @ B.T @ P + Q

# Define backward pass for custom VJP
def lqr_solution_bwd(fwd_vars, out_grad):
    P, A, B, Q, R = fwd_vars  # Unpack saved values
    # Define the CARE residual function

    # Compute Jacobians of the residual function with respect to each argument
    dres_dp = jax.jacobian(care_residual, 0)(*fwd_vars)
    dres_da = jax.jacobian(care_residual, 1)(*fwd_vars)
    dres_db = jax.jacobian(care_residual, 2)(*fwd_vars)
    dres_dq = jax.jacobian(care_residual, 3)(*fwd_vars)
    dres_dr = jax.jacobian(care_residual, 4)(*fwd_vars)
    
    # Solve for the adjoint (Lagrange multiplier)
    adj = jnp.linalg.tensorsolve(dres_dp.T, out_grad.T)
    N = adj.ndim

    # Compute the gradients for A, B, Q, and R
    a_grad = -jnp.tensordot(dres_da.T, adj, N).T
    b_grad = -jnp.tensordot(dres_db.T, adj, N).T
    q_grad = -jnp.tensordot(dres_dq.T, adj, N).T
    q_grad = (q_grad + q_grad.T) / 2 
    r_grad = -jnp.tensordot(dres_dr.T, adj, N).T
    r_grad = (r_grad + r_grad.T) / 2 

    return (a_grad, b_grad, q_grad, r_grad)

def lqr_solution_fwd(A,B,Q,R):
    P = lqr_solution(A,B,Q,R)
    return P, (P, A, B, Q, R)

lqr_solution.defvjp(lqr_solution_fwd, lqr_solution_bwd)

# Define system matrices
A = jnp.array([[0.0, 1.0], [-1.0, -1.0]])
B = jnp.array([[0.0], [1.0]])
Q = jnp.array([[1.0, 0.0], [0.0, 1.0]])
R = jnp.array([[1.0]])

# Get the LQR solution and implicit differentiation function
P_solution = lqr_solution(A, B, Q, R)
print("LQR Solution (P):", P_solution)

from control.matlab import *
Kc, P, CLP = lqr(A, B, Q, R)

print("LQR Solution of python lqr library (P):", P)
print(jax.jacobian(lambda Q: lqr_solution(A, B, Q, R))(Q))

LQR Solution (P): [[1.3784151  0.41421378]
 [0.4142142  0.6817929 ]]
LQR Solution of python lqr library (P): [[1.37841423 0.41421356]
 [0.41421356 0.68179283]]
[[[[ 8.9190531e-01 -5.0000006e-01]
   [-5.0000006e-01  4.2044830e-01]]

  [[ 3.5355335e-01 -1.4901161e-08]
   [-1.4901161e-08  5.0121344e-08]]]


 [[[ 3.5355321e-01  1.4901161e-08]
   [ 1.4901161e-08 -5.0121344e-08]]

  [[ 2.1022405e-01 -1.5805060e-08]
   [-1.5805060e-08  2.9730177e-01]]]]


In [58]:
import jax.numpy as jnp

def lqr_continuous_time_infinite_horizon(A, B, Q, R, N):
  # Take the last dimension, in case we try to do some kind of broadcasting
  # thing in the future.
  x_dim = A.shape[-1]

  # See https://en.wikipedia.org/wiki/Linear%E2%80%93quadratic_regulator#Infinite-horizon,_continuous-time_LQR.
  A1 = A - B @ jp.linalg.solve(R, N.T)
  Q1 = Q - N @ jp.linalg.solve(R, N.T)

  # See https://en.wikipedia.org/wiki/Algebraic_Riccati_equation#Solution.
  H = jp.block([[A1, -B @ jp.linalg.solve(R, B.T)], [-Q1, -A1]])
  # print(tree_map(jp.allclose, H.T, H))
  # print(H)
  if tree_map(jp.allclose, H.T, H):
    eigvals, eigvectors = jp.linalg.eigh(H)
  else:
    eigvals, eigvectors = jp.linalg.eigh(H)
  argsort = jp.argsort(eigvals)

  ix = argsort[:x_dim]
  U = eigvectors[:, ix]
  P = U[x_dim:, :] @ jp.linalg.inv(U[:x_dim, :])

  K = jp.linalg.solve(R, (B.T @ P + N.T))
  return K, P, eigvals[ix]

# Define system matrices
A = jnp.array([[0.0, 1.0], [-1.0, -1.0]])
B = jnp.array([[0.0], [1.0]])
Q = jnp.array([[1.0, 0.0], [0.0, 1.0]])
R = jnp.array([[1.0]])
N = jnp.zeros((2,1))
# Get the LQR solution and implicit differentiation function
Kc, P, CLP = lqr_continuous_time_infinite_horizon(A, B, Q, R,N)
print("LQR Solution (P):", P)



LQR Solution (P): [[0.99999994 0.        ]
 [0.         0.41421354]]


In [60]:
from control.matlab import *
Kc, P, CLP = lqr(A, B, Q, R)

print("LQR Solution of python lqr library (P):", Kc, P, CLP)
print(CLP)

LQR Solution of python lqr library (P): [[0.41421356 0.68179283]] [[1.37841423 0.41421356]
 [0.41421356 0.68179283]] [-0.8408964+0.8408964j -0.8408964-0.8408964j]
[-0.8408964+0.8408964j -0.8408964-0.8408964j]


In [None]:
m=2
n=2
A = jnp.array([[0.0, 1.0], [-1.0, -1.0]])
B = jnp.array([[0.0], [1.0]])
Q = jnp.array([[1.0, 0.0], [0.0, 1.0]])
R = jnp.array([[1.0]])
N = jnp.zeros((2,1))
H = np.zeros((4,4))
H[:m, :m] = A
H[:m, m:2 * m] = 0.
H[:m, 2 * m:] = B
H[m:2 * m, :m] = -Q
H[m:2 * m, m:2 * m] = -A.conj().T
H[m:2 * m, 2 * m:] = 0
H[2 * m:, :m] = 0
H[2 * m:, m:2 * m] = B.conj().T
H[2 * m:, 2 * m:] = R

J = jnp.array(block_diag(np.eye(2 * m), np.zeros_like(R)))
q, r = jax.scipy.linalg.qr(H[:, -n:])
H = q[:, n:].conj().T.dot(H[:, :2 * m])
J = q[:2 * m, n:].conj().T.dot(J[:2 * m, :2 * m])



In [6]:
import scipy
import numpy as np
A = np.array([[0.0, 1.0], [-1.0, -1.0]])
B = np.array([[0.0], [1.0]])
Q = np.array([[1.0, 0.0], [0.0, 1.0]])
R = np.array([[1.0]])
N = np.zeros((2,1))
X = scipy.linalg.solve_continuous_are(A, B, Q, R)
K = np.linalg.solve(R, B.T @ X)
E, _ = np.linalg.eig(A - B @ K)
print(E)

[-0.84089642+0.84089642j -0.84089642-0.84089642j]


In [5]:
# Finite difference check for each entry of Q and R
def numerical_jacobian(f, x, epsilon=1e-4):
    """Compute numerical Jacobian for each entry of x on matrix output f(x)."""
    jacobian = jnp.zeros((f(x).shape[0], f(x).shape[1], x.shape[0], x.shape[1]))
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            x_perturb_plus = x.at[i, j].set(x[i, j] + epsilon)
            x_perturb_minus = x.at[i, j].set(x[i, j] - epsilon)
            if i != j:
                x_perturb_plus = x_perturb_plus.at[j,i].set(x_perturb_plus[j,i] + epsilon)
                x_perturb_minus = x_perturb_minus.at[j,i].set(x_perturb_minus[j,i] - epsilon) 
            f_plus = f(x_perturb_plus)
            f_minus = f(x_perturb_minus)
            jacobian = jacobian.at[:, :, i, j].set((f_plus - f_minus) / (2 * epsilon))
    return jacobian

# Define function to get P with fixed A, B
def get_P_with_Q(Q):
    return lqr_solution(A, B, Q, R)

def get_P_with_R(R):
    return lqr_solution(A, B, Q, R)

# Compute numerical gradients
numerical_grad_Q = numerical_jacobian(get_P_with_Q, Q)
numerical_grad_R = numerical_jacobian(get_P_with_R, R)

print("\nNumerical gradient of P with respect to Q:")
print(numerical_grad_Q)

print("\nNumerical gradient of P with respect to R:")
print(numerical_grad_R)


Numerical gradient of P with respect to Q:
[[[[ 0.5671382  -0.43272972]
   [-0.43272972  0.13142824]]

  [[ 0.1899898   0.1899898 ]
   [ 0.1899898  -0.05759299]]]


 [[[ 0.1899898   0.1899898 ]
   [ 0.1899898  -0.05759299]]

  [[ 0.1013279   0.1013279 ]
   [ 0.1013279   0.23558736]]]]

Numerical gradient of P with respect to R:
[[[[0.00685453]]

  [[0.00536442]]]


 [[[0.00536442]]

  [[0.04097819]]]]
