In [118]:
import jax.numpy as jnp
from jax.scipy.linalg import block_diag
import jax

def solve_continuous_are_jax(A, B, Q, R):
    """
    Solves the continuous-time algebraic Riccati equation (CARE) using JAX.
    
    Parameters:
        A: (M, M) array
        B: (M, N) array
        Q: (M, M) array
        R: (N, N) array (nonsingular)
        e: (M, M) array, optional
        s: (M, N) array, optional
        balanced: bool, optional (default=True)
        
    Returns:
        X: (M, M) ndarray, solution to the CARE
    """
    M, N = A.shape[0], B.shape[1]
    
    # Default matrices for e and s
    e = jnp.eye(M)
    s = jnp.zeros((M, N))
    
    # Construct Hamiltonian matrix H and identity matrix J
    H = jnp.zeros((2 * M + N, 2 * M + N))
    H = H.at[:M, :M].set(A)
    H = H.at[:M, 2 * M:].set(B)
    H = H.at[M:2 * M, :M].set(-Q)
    H = H.at[M:2 * M, M:2 * M].set(jnp.conjugate(-A).T)
    H = H.at[2 * M:, M:2 * M].set(B.conjugate().T)
    H = H.at[2 * M:, 2 * M:].set(R)

    J = block_diag(jnp.eye(2*M), jnp.zeros_like(R))
    
    # Balancing step
    M_abs = jnp.abs(H) + jnp.abs(J)
    M_abs = M_abs - jnp.diag(jnp.diag(M_abs))  # Remove diagonal elements to avoid scaling them

    # Calculate scaling factors by summing row and column norms
    row_scales = jnp.sum(M_abs, axis=1)
    col_scales = jnp.sum(M_abs, axis=0)

    # Avoid division by zero in scaling by adding a small value
    scale_factors = jnp.sqrt(row_scales * col_scales)

    # Check if scaling is required
    if not jnp.allclose(scale_factors, jnp.ones_like(scale_factors)):
        # Transform scale_factors into a logarithmic scale as per the original logic
        log_scales = jnp.log2(scale_factors)
        log_scales_left = log_scales[:M]
        log_scales_right = log_scales[M:2*M]
        s = jnp.round((log_scales_right - log_scales_left) / 2)
        # Construct a scaling vector in the form [D, inv(D)] to apply element-wise
        sca = 2 ** jnp.concatenate([s, -s, log_scales[2*M:]])
        elwisescale = sca[:, None] * jnp.reciprocal(sca)

        # Apply the scaling element-wise to H and J
        H = H * elwisescale
        J = J * elwisescale
    
    
    # QR decomposition approximation for Hamiltonian matrix
    q, r = jax.scipy.linalg.qr(H[:, -N:])
    H = jnp.dot(q[:, N:].conjugate().T , H[:, :2 * M])
    J = jnp.dot(q[:2*M, N:].conjugate().T , J[:2 * M, :2 * M])
    
    # Eigenvalue decomposition of the Hamiltonian matrix
    J_inv = jnp.linalg.inv(J)
    H_transformed = J_inv @ H
    eigvals, eigvecs = jnp.linalg.eigh(H_transformed)
    stable_indices = jnp.where(jnp.real(eigvals) < 0)[0]  # Stable subspace

    # Select relevant eigenvectors for the stable subspace
    U = eigvecs[:, stable_indices]
    U00 = U[:M, :M]
    U10 = U[M:, :M]
    print(U00)
    up, ul, uu =  jax.scipy.linalg.lu(U00)
    x = jax.scipy.linalg.solve_triangular(ul.conjugate().T, 
                                          jax.scipy.linalg.solve_triangular(uu.conjugate().T, U10.conjugate().T, lower=True),
                                          unit_diagonal=True
                                        ).conjugate().T.dot(up.conjugate().T)
    x = x * sca[:M, None] * sca[:M]

    # Symmetrize the solution
    return (x + x.conjugate().T) / 2

In [119]:
import jax
import jax.numpy as jnp
# from scipy.linalg import solve_continuous_are
from jax import custom_jvp
import numpy as np
from scipy.linalg import block_diag, lu, solve_triangular, matrix_balance, qr, ordqz
from numpy.linalg import inv, LinAlgError, norm, cond, svd

def solve_continuous_are(a, b, q, r, e=None, s=None, balanced=True):
    a, b, q, r, e, s, m, n, r_or_c, gen_are = _are_validate_args(a, b, q, r, e, s, 'care')

    H = np.empty((2 * m + n, 2 * m + n), dtype=r_or_c)
    H[:m, :m] = a
    H[:m, m:2 * m] = 0.
    H[:m, 2 * m:] = b
    H[m:2 * m, :m] = -q
    H[m:2 * m, m:2 * m] = -a.conj().T
    H[m:2 * m, 2 * m:] = 0. if s is None else -s
    H[2 * m:, :m] = 0. if s is None else s.conj().T
    H[2 * m:, m:2 * m] = b.conj().T
    H[2 * m:, 2 * m:] = r

    if gen_are and e is not None:
        J = block_diag(e, e.conj().T, np.zeros_like(r, dtype=r_or_c))
    else:
        J = block_diag(np.eye(2 * m), np.zeros_like(r, dtype=r_or_c))

    if balanced:
        M = np.abs(H) + np.abs(J)

        np.fill_diagonal(M, 0.)
        _, (sca, _) = matrix_balance(M, separate=1, permute=0)
        if not np.allclose(sca, np.ones_like(sca)):
            sca = np.log2(sca)
            s = np.round((sca[m:2 * m] - sca[:m]) / 2)
            sca = 2 ** np.r_[s, -s, sca[2 * m:]]
            elwisescale = sca[:, None] * np.reciprocal(sca)
            H *= elwisescale
            J *= elwisescale
    
    q, r = qr(H[:, -n:])
    H = q[:, n:].conj().T.dot(H[:, :2 * m])
    J = q[:2 * m, n:].conj().T.dot(J[:2 * m, :2 * m])
    
    out_str = 'real' if r_or_c == float else 'complex'
    _, _, _, _, _, u = ordqz(H, J, sort='lhp', overwrite_a=True, overwrite_b=True, check_finite=False, output=out_str)

    if e is not None:
        u, _ = qr(np.vstack((e.dot(u[:m, :m]), u[m:, :m])))
    u00 = u[:m, :m]
    u10 = u[m:, :m]
    print(u00)
    up, ul, uu = lu(u00)
    if 1 / cond(uu) < np.spacing(1.):
        raise np.linalg.LinAlgError('Failed to find a finite solution.')

    x = solve_triangular(ul.conj().T, solve_triangular(uu.conj().T, u10.conj().T, lower=True), unit_diagonal=True).conj().T.dot(up.conj().T)
    if balanced:
        x *= sca[:m, None] * sca[:m]

    u_sym = u00.conj().T.dot(u10)
    n_u_sym = norm(u_sym, 1)
    u_sym = u_sym - u_sym.conj().T
    sym_threshold = np.max([np.spacing(1000.), 0.1 * n_u_sym])

    if norm(u_sym, 1) > sym_threshold:
        raise np.linalg.LinAlgError('The associated Hamiltonian pencil has eigenvalues too close to the imaginary axis')

    return (x + x.conj().T) / 2

def _are_validate_args(a, b, q, r, e, s, eq_type):
    m, n = a.shape[0], b.shape[1]
    r_or_c = np.common_type(a, b, q, r)
    gen_are = e is not None or s is not None
    if e is None:
        e = np.eye(m, dtype=r_or_c)
    if s is None:
        s = np.zeros((m, n), dtype=r_or_c)
    return a, b, q, r, e, s, m, n, r_or_c, gen_are


def solve_care(A, B, Q, R):
    """
    Wrapper to solve the continuous-time Algebraic Riccati Equation (CARE).
    This uses SciPy's solver but converts results to JAX arrays.
    """
    P = solve_continuous_are(A, B, Q, R)
    return jnp.array(P)

@jax.custom_vjp
def lqr_solution(A, B, Q, R):
    """
    Returns the LQR solution P (Riccati matrix) and sets up implicit differentiation.
    
    Args:
        A (jax.numpy.ndarray): State transition matrix.
        B (jax.numpy.ndarray): Control input matrix.
        Q (jax.numpy.ndarray): State cost matrix.
        R (jax.numpy.ndarray): Control cost matrix.
    
    Returns:
        P (jax.numpy.ndarray): Solution to the CARE, matrix P.
    """
    P = solve_care(A, B, Q, R)

    # Attach the custom VJP
    return P

def care_residual(P, A, B, Q, R):
    """CARE residual function, F(P; A, B, Q, R) = 0."""
    return A.T @ P + P @ A - P @ B @ jnp.linalg.inv(R) @ B.T @ P + Q

# Define backward pass for custom VJP
def lqr_solution_bwd(fwd_vars, out_grad):
    P, A, B, Q, R = fwd_vars  # Unpack saved values
    # Define the CARE residual function

    # Compute Jacobians of the residual function with respect to each argument
    dres_dp = jax.jacobian(care_residual, 0)(*fwd_vars)
    dres_da = jax.jacobian(care_residual, 1)(*fwd_vars)
    dres_db = jax.jacobian(care_residual, 2)(*fwd_vars)
    dres_dq = jax.jacobian(care_residual, 3)(*fwd_vars)
    dres_dr = jax.jacobian(care_residual, 4)(*fwd_vars)
    
    # Solve for the adjoint (Lagrange multiplier)
    adj = jnp.linalg.tensorsolve(dres_dp.T, out_grad.T)
    N = adj.ndim

    # Compute the gradients for A, B, Q, and R
    a_grad = -jnp.tensordot(dres_da.T, adj, N).T
    b_grad = -jnp.tensordot(dres_db.T, adj, N).T
    q_grad = -jnp.tensordot(dres_dq.T, adj, N).T
    q_grad = (q_grad + q_grad.T) / 2 
    r_grad = -jnp.tensordot(dres_dr.T, adj, N).T
    r_grad = (r_grad + r_grad.T) / 2 

    return (a_grad, b_grad, q_grad, r_grad)

def lqr_solution_fwd(A,B,Q,R):
    P = lqr_solution(A,B,Q,R)
    return P, (P, A, B, Q, R)

lqr_solution.defvjp(lqr_solution_fwd, lqr_solution_bwd)



In [120]:
# Define system matrices
A = jnp.array([[0.0, 1.0], [-1.0, -1.0]])
B = jnp.array([[0.0], [1.0]])
Q = jnp.array([[1.0, 0.0], [0.0, 1.0]])
R = jnp.array([[1.0]])

# Get the LQR solution and implicit differentiation function
P_solution = lqr_solution(A, B, Q, R)
print("LQR Solution (P):", P_solution)

from control.matlab import *
Kc, P, CLP = lqr(A, B, Q, R)

# print("LQR Solution of python lqr library (P):", P)
print("LQR Solution of python jax (P):", solve_continuous_are_jax(A,B,Q,R))

# Analytical gradient of P with respect to Q and R (full Jacobian)
# analytical_grad_Q = jax.jacobian(lambda Q: lqr_solution(A, B, Q, R))(Q)
# analytical_grad_R = jax.jacobian(lambda R: lqr_solution(A, B, Q, R))(R)

# print("\nAnalytical Jacobian of P with respect to Q:")
# print(analytical_grad_Q)

# print("\nAnalytical Jacobian of P with respect to R:")
# print(analytical_grad_R)

[[-0.47035969-0.17856214j -0.25435849+0.250803j  ]
 [ 0.24537151+0.54567603j -0.44688663-0.39617092j]]
LQR Solution (P): [[1.3784142 +0.0000000e+00j 0.41421357+1.3877788e-17j]
 [0.41421357-1.3877788e-17j 0.6817928 +0.0000000e+00j]]
[[0.         0.70710677]
 [0.9238795  0.        ]]
LQR Solution of python jax (P): [[0.99999994 0.        ]
 [0.         0.41421354]]


In [121]:
# Finite difference check for each entry of Q and R
def numerical_jacobian(f, x, epsilon=1e-4):
    """Compute numerical Jacobian for each entry of x on matrix output f(x)."""
    jacobian = jnp.zeros((f(x).shape[0], f(x).shape[1], x.shape[0], x.shape[1]))
    for i in range(x.shape[0]):
        for j in range(x.shape[1]):
            x_perturb_plus = x.at[i, j].set(x[i, j] + epsilon)
            x_perturb_minus = x.at[i, j].set(x[i, j] - epsilon)
            if i != j:
                x_perturb_plus = x_perturb_plus.at[j,i].set(x_perturb_plus[j,i] + epsilon)
                x_perturb_minus = x_perturb_minus.at[j,i].set(x_perturb_minus[j,i] - epsilon) 
            f_plus = f(x_perturb_plus)
            f_minus = f(x_perturb_minus)
            jacobian = jacobian.at[:, :, i, j].set((f_plus - f_minus) / (2 * epsilon))
    return jacobian

# Define function to get P with fixed A, B
def get_P_with_Q(Q):
    return lqr_solution(A, B, Q, R)

def get_P_with_R(R):
    return lqr_solution(A, B, Q, R)

# Compute numerical gradients
numerical_grad_Q = numerical_jacobian(get_P_with_Q, Q)
numerical_grad_R = numerical_jacobian(get_P_with_R, R)

print("\nNumerical gradient of P with respect to Q:")
print(numerical_grad_Q)

print("\nNumerical gradient of P with respect to R:")
print(numerical_grad_R)

[[-0.47035969-0.17856214j -0.25435849+0.250803j  ]
 [ 0.24537151+0.54567603j -0.44688663-0.39617092j]]
[[-0.47035969-0.17856214j -0.25435849+0.250803j  ]
 [ 0.24537151+0.54567603j -0.44688663-0.39617092j]]
[[-0.47044318-0.17830651j -0.25417242+0.25096848j]
 [ 0.24565976+0.54553811j -0.44713503-0.39589326j]]
[[-0.47027605-0.17881775j -0.25454447+0.25063737j]
 [ 0.24508318+0.54581381j -0.44663801-0.39644847j]]
[[-0.47019983-0.17901955j -0.2544548 +0.25074617j]
 [ 0.24485246+0.54592624j -0.44675897-0.39629021j]]
[[-0.47051909-0.17810464j -0.25426213+0.25085984j]
 [ 0.24589027+0.54542536j -0.44701431-0.39605154j]]
[[-0.47019983-0.17901955j -0.2544548 +0.25074617j]
 [ 0.24485246+0.54592624j -0.44675897-0.39629021j]]
[[-0.47051909-0.17810464j -0.25426213+0.25085984j]
 [ 0.24589027+0.54542536j -0.44701431-0.39605154j]]
[[-0.47040562-0.17842352j -0.25404272+0.25110764j]
 [ 0.24553636+0.54559375j -0.44737582-0.39560834j]]
[[-0.47031371-0.17870075j -0.25467388+0.25049798j]
 [ 0.24520665+0.545758

  return lax_internal._convert_element_type(out, dtype, weak_type)
