<a href="https://colab.research.google.com/github/dnguyend/EmbeddedGeometry/blob/main/colab/JaxRigidBodyDynamics.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Applying the Curvature Formulas to Rigid-Body Dynamics

We use [JAX](https://jax.readthedocs.io/en/latest/index.html) for automatic differentiation; in particular, [jvp](https://jax.readthedocs.io/en/latest/_autosummary/jax.jvp.html) (Jacobian-vector product) computes directional derivatives. It returns the pair (value, derivative), so `[1]` selects the derivative.

For example, the directional derivative of $c(x, y)$ with respect to $x$, at $(x, y)$, in direction $\xi$ is given by

`jvp(lambda z: c(z, y), (x,), (xi,))[1]`
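A minimal sketch of this pattern on a toy function. The function `c(x, y) = sum(x**2 * y)` and the particular vectors are hypothetical, chosen so that the derivative in $x$ along $\xi$, namely $\sum 2 x y \xi$, can be checked by hand:

```python
import jax
import jax.numpy as jnp
from jax import jvp

jax.config.update("jax_enable_x64", True)

def c(x, y):
  # hypothetical toy function of two arguments
  return jnp.sum(x**2 * y)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.5, -1.0, 2.0])
xi = jnp.array([1.0, 0.0, -1.0])

# jvp returns (c(x, y), derivative in x along xi); [1] keeps the derivative
dc = jvp(lambda z: c(z, y), (x,), (xi,))[1]

# hand-computed derivative: sum(2 * x * y * xi)
expected = jnp.sum(2 * x * y * xi)
print(dc, expected)   # both equal -11.0
```

Forward mode evaluates the function and its directional derivative in a single pass, which is why the notebook always indexes the result with `[1]`.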


In [None]:
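import jax
import jax.numpy as jnp
import jax.numpy.linalg as jla
import jax.random as random
from jax import jvp

# double precision, so the identity checks below hold to roundoff
jax.config.update("jax_enable_x64", True)

def sym(A):
  # symmetric part of a square matrix
  return 0.5*(A+A.T)

def asym(A):
  # antisymmetric part of a square matrix
  return 0.5*(A-A.T)

def genImat(sk, n):
  # random symmetric weight matrix (entries almost surely positive)
  # with unit diagonal; it defines the entrywise inertia operator
  Imat = jnp.abs(random.normal(sk, (n , n)))
  Imat = sym(Imat)
  Imat = Imat.at[jnp.diag_indices(n)].set(1.)
  return Imat

def Iop0(A):
  # inertia operator: entrywise product with the global Imat
  return A*Imat

def Iinv0(A):
  # inverse inertia operator: entrywise division by the global Imat
  return A/Imat

def Lie(A, B):
  # matrix commutator [A, B]
  return A@B - B@A

def inner(U, xi, eta):
  # left-invariant metric <xi, eta>_{g, U}:
  # weighted entrywise pairing of U^T xi and U^T eta
  return jnp.sum((U.T@xi)*Iop0(U.T@eta))

def GammaU(U, xi, eta):
  # Christoffel function Gamma(U; xi, eta)
  return U@sym(xi.T@eta) + 0.5*U@Iinv0(Lie(U.T@xi, Iop0(U.T@eta)) + Lie(U.T@eta, Iop0(U.T@xi)))

def Pi(U, Omg):
  # projection of an ambient matrix Omg to the tangent space
  # {U A : A antisymmetric} at U
  return U@asym(U.T@Omg)

def cz(mat):
  # largest absolute entry, used to measure errors
  return jnp.max(jnp.abs(mat))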


# Check metric compatibility
The JAX expression

`jvp(lambda U: inner(U, Pi(U, eta), Pi(U, eta)), (U,), (xi,))[1]`

is the directional derivative, in direction $\xi$, of
$$\langle \Pi(U)\eta, \Pi(U)\eta\rangle_{g, U}.$$
Note that $U\mapsto \Pi(U)\eta$ is a vector field. We check that this derivative equals $2\langle \eta, \Gamma(U; \xi, \eta)\rangle_{g, U}$, given by the Python expression

`2*inner(U, eta, GammaU(U, xi, eta))`
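Why the derivative is compared with $2\langle \eta, \Gamma(U; \xi, \eta)\rangle_{g, U}$ alone: with $Y(U) = \Pi(U)\eta$, metric compatibility reads $D_\xi\langle Y, Y\rangle_{g} = 2\langle D_\xi Y + \Gamma(U; \xi, Y), Y\rangle_{g, U}$, and for orthogonal $U$ with $\xi = UB$, $\eta = UA$ ($A$, $B$ antisymmetric) one computes $U^{\mathsf T}D_\xi Y = \frac{1}{2}(AB + BA)$. This matrix is symmetric, and its weighted entrywise pairing with the antisymmetric $A$ vanishes, so $\langle D_\xi Y, \eta\rangle_{g, U} = 0$. A minimal sketch checking this numerically; the seed, $n = 4$, and passing `Imat` to `inner` as an explicit argument are our choices, not the notebook's:

```python
import jax
import jax.numpy as jnp
import jax.numpy.linalg as jla
import jax.random as random
from jax import jvp

jax.config.update("jax_enable_x64", True)

def asym(A):
  return 0.5*(A - A.T)

def Pi(U, Omg):
  # projection to the tangent space {U A : A antisymmetric} at U
  return U@asym(U.T@Omg)

def inner(Imat, U, xi, eta):
  # same weighted metric as the notebook, with Imat passed explicitly
  return jnp.sum((U.T@xi)*Imat*(U.T@eta))

n = 4
k1, k2, k3, k4 = random.split(random.PRNGKey(1), 4)
U, _ = jla.qr(random.normal(k1, (n, n)))
Imat = jnp.abs(random.normal(k2, (n, n)))
Imat = 0.5*(Imat + Imat.T)
xi = U@asym(random.normal(k3, (n, n)))
eta = U@asym(random.normal(k4, (n, n)))

# D_xi of the vector field U -> Pi(U) eta
DY = jvp(lambda q: Pi(q, eta), (U,), (xi,))[1]

# U^T DY should be symmetric, so its weighted pairing with eta vanishes
S = U.T@DY
print(jnp.max(jnp.abs(S - S.T)))   # close to zero (roundoff)
print(inner(Imat, U, DY, eta))     # close to zero (roundoff)
```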

In [None]:
n = 3

# random orthogonal matrix U
key = random.PRNGKey(0)
key, sk = random.split(key)
U, _ = jla.qr(random.normal(sk, (n, n)))

def randvec(sk, n, U=None):
  # random antisymmetric matrix (a tangent vector at the identity);
  # if U is given, left-translate it to the tangent vector U @ A at U
  if U is None:
    return asym(random.normal(sk, (n, n)))
  else:
    return U@asym(random.normal(sk, (n, n)))

key, sk = random.split(key)
Imat = genImat(sk, n)

key, sk = random.split(key)
xi = randvec(sk, n, U)
key, sk = random.split(key)
eta = randvec(sk, n, U)

# U^T (D_xi Pi(U) eta + Gamma(U; xi, eta)) should be antisymmetric,
# i.e. the covariant derivative is again tangent at U
print("CHECK GAMMA IS A CHRISTOFFEL FUNCTION - Nabla produces a vector field")
display(U.T@(jvp(lambda U: Pi(U, eta), (U,), (xi,))[1] + GammaU(U, xi, eta)))

# the two numbers displayed should agree
print("CHECK GAMMA IS METRIC COMPATIBLE")
display(jvp(lambda U: inner(U, Pi(U, eta), Pi(U, eta)), (U,), (xi,))[1])
display(2*inner(U, eta, GammaU(U, xi, eta)))





CHECK GAMMA IS A CHRISTOFFEL FUNCTION - Nabla produces a vector field


Array([[-1.69776750e-16,  2.71117699e-03, -1.30355786e-01],
       [-2.71117699e-03,  1.86375182e-16,  7.98805202e-02],
       [ 1.30355786e-01, -7.98805202e-02,  5.01206066e-18]],      dtype=float64)

CHECK GAMMA IS METRIC COMPATIBLE


Array(0.26541833, dtype=float64)

Array(0.26541833, dtype=float64)

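# Curvature
The curvature is
$$R_{\xi_1\xi_2}\xi_3 = D_{\xi_1}\Gamma(\xi_2, \xi_3) - D_{\xi_2}\Gamma(\xi_1, \xi_3) + \Gamma(\xi_1, \Gamma(\xi_2, \xi_3)) - \Gamma(\xi_2, \Gamma(\xi_1, \xi_3)).$$
* We implement the directional derivatives with JAX, then verify the [symmetries of the curvature tensor](https://en.wikipedia.org/wiki/Riemann_curvature_tensor#Symmetries_and_identities) below: for the $(0,4)$ tensor $R(\xi_1, \xi_2, \xi_3, \xi_4) = \langle R_{\xi_1\xi_2}\xi_3, \xi_4\rangle_{g, U}$, swapping $\xi_1, \xi_2$ or swapping $\xi_3, \xi_4$ flips the sign, while swapping the two pairs leaves it unchanged.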
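Besides the sign checks on `CurvJax4` that follow, the first Bianchi identity $R_{\xi_1\xi_2}\xi_3 + R_{\xi_2\xi_3}\xi_1 + R_{\xi_3\xi_1}\xi_2 = 0$ can be tested with the same ingredients. It should hold because $\Gamma(U; \xi, \eta)$ is symmetric in $\xi, \eta$, so $D + \Gamma$ is torsion-free. A minimal self-contained sketch, re-declaring the notebook's `GammaU` and curvature with the weights `Imat` passed explicitly (our change, to avoid the global variable); the seed and $n = 4$ are arbitrary:

```python
import jax
import jax.numpy as jnp
import jax.numpy.linalg as jla
import jax.random as random
from jax import jvp

jax.config.update("jax_enable_x64", True)

def sym(A):
  return 0.5*(A + A.T)

def asym(A):
  return 0.5*(A - A.T)

def Lie(A, B):
  return A@B - B@A

def GammaU(Imat, U, xi, eta):
  # the notebook's Christoffel function, with Imat explicit
  a, b = U.T@xi, U.T@eta
  return U@sym(xi.T@eta) \
    + 0.5*U@((Lie(a, b*Imat) + Lie(b, a*Imat))/Imat)

def Curv31(Imat, U, X1, X2, X3):
  # R_{X1 X2} X3 = D_{X1}Gamma(X2, X3) - D_{X2}Gamma(X1, X3)
  #                + Gamma(X1, Gamma(X2, X3)) - Gamma(X2, Gamma(X1, X3))
  DG1 = jvp(lambda q: GammaU(Imat, q, X2, X3), (U,), (X1,))[1]
  DG2 = jvp(lambda q: GammaU(Imat, q, X1, X3), (U,), (X2,))[1]
  GG1 = GammaU(Imat, U, X1, GammaU(Imat, U, X2, X3))
  GG2 = GammaU(Imat, U, X2, GammaU(Imat, U, X1, X3))
  return DG1 - DG2 + GG1 - GG2

n = 4
keys = random.split(random.PRNGKey(2), 5)
U, _ = jla.qr(random.normal(keys[0], (n, n)))
Imat = sym(jnp.abs(random.normal(keys[1], (n, n))))
Imat = Imat.at[jnp.diag_indices(n)].set(1.)
X1, X2, X3 = (U@asym(random.normal(k, (n, n))) for k in keys[2:])

# cyclic sum over (X1, X2, X3); should be at roundoff level
bianchi1 = Curv31(Imat, U, X1, X2, X3) \
  + Curv31(Imat, U, X2, X3, X1) \
  + Curv31(Imat, U, X3, X1, X2)
print(jnp.max(jnp.abs(bianchi1)))
```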

In [None]:
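def CurvJax31(U, Xi1, Xi2, Xi3):
  # R_{Xi1 Xi2} Xi3; derivatives of Gamma in U are computed by jvp
  DG1 = jvp(lambda q: GammaU(q, Xi2, Xi3), (U, ), (Xi1, ))[1]
  DG2 = jvp(lambda q: GammaU(q, Xi1, Xi3), (U, ), (Xi2, ))[1]
  GG1 = GammaU(U, Xi1, GammaU(U, Xi2, Xi3))
  GG2 = GammaU(U, Xi2, GammaU(U, Xi1, Xi3))

  return DG1 - DG2 + GG1 - GG2


def CurvJax4(U, Xi1, Xi2, Xi3, Xi4):
  # (0,4) curvature tensor <R_{Xi1 Xi2} Xi3, Xi4>_{g, U}
  return inner(U, CurvJax31(U, Xi1, Xi2, Xi3), Xi4)

# check the symmetries of the (0,4) curvature tensor:
# expect equal magnitudes with signs +, -, -, +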
print("CHECK BIANCHI")
key, sk = random.split(key)
Xi1 = randvec(sk, n, U)

key, sk = random.split(key)
Xi2 = randvec(sk, n, U)

key, sk = random.split(key)
Xi3 = randvec(sk, n, U)

key, sk = random.split(key)
Xi4 = randvec(sk, n, U)

display(CurvJax4(U, Xi1, Xi2, Xi3, Xi4))
display(CurvJax4(U, Xi2, Xi1, Xi3, Xi4))
display(CurvJax4(U, Xi1, Xi2, Xi4, Xi3))
display(CurvJax4(U, Xi3, Xi4, Xi1, Xi2))


CHECK BIANCHI


Array(0.66070969, dtype=float64)

Array(-0.66070969, dtype=float64)

Array(-0.66070969, dtype=float64)

Array(0.66070969, dtype=float64)

For the second Bianchi identity we need $\nabla R$, the covariant derivative of the curvature tensor $R$, which we implement below using `jvp`.
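Concretely, with each $\xi_i$ extended to the vector field $U \mapsto \Pi(U)\xi_i$ and $\nabla_{\xi}Y = D_{\xi}Y + \Gamma(U; \xi, Y)$, the quantity `nablaR(U, Xi1, Xi2, Xi3, Xi4)` is the usual covariant derivative of the curvature tensor, and the printed sum `f1 + f2 + f3` is the cyclic sum in the second Bianchi identity:

$$(\nabla_{\xi_1}R)_{\xi_2\xi_3}\xi_4 = \nabla_{\xi_1}\left(R_{\xi_2\xi_3}\xi_4\right) - R_{\nabla_{\xi_1}\xi_2,\,\xi_3}\xi_4 - R_{\xi_2,\,\nabla_{\xi_1}\xi_3}\xi_4 - R_{\xi_2\xi_3}\nabla_{\xi_1}\xi_4,$$

$$(\nabla_{\xi_1}R)_{\xi_2\xi_3}\xi_4 + (\nabla_{\xi_2}R)_{\xi_3\xi_1}\xi_4 + (\nabla_{\xi_3}R)_{\xi_1\xi_2}\xi_4 = 0.$$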

In [None]:
print("CHECK THE SECOND BIANCHI Identity")
def nablaR(U, Xi1, Xi2, Xi3, Xi4):
  # (nabla_{Xi1} R)_{Xi2 Xi3} Xi4: extend Xi2, Xi3, Xi4 to vector fields
  # by projection, take the covariant derivative of R_{Xi2 Xi3} Xi4
  # along Xi1, then subtract the terms where it falls on each argument.

  def cov(q, Xi1, Xi2):
    # covariant derivative along Xi1 of the vector field x -> Pi(x, Xi2)
    return jvp(lambda x: Pi(x, Xi2), (q,), (Xi1,))[1] \
      + GammaU(q, Xi1, Xi2)

  # nabla_{Xi1} (R_{Xi2 Xi3} Xi4)
  Tot = jvp(lambda x: CurvJax31(x, Pi(x, Xi2), Pi(x, Xi3), Pi(x, Xi4)),
            (U,), (Xi1,))[1] \
    + GammaU(U, Xi1, CurvJax31(U, Xi2, Xi3, Xi4))
  R12 = CurvJax31(U, cov(U, Xi1, Xi2), Xi3, Xi4)
  R13 = CurvJax31(U, Xi2, cov(U, Xi1, Xi3), Xi4)
  R14 = CurvJax31(U, Xi2, Xi3, cov(U, Xi1, Xi4))
  return Tot - R12 - R13 - R14

# cyclic sum over (Xi1, Xi2, Xi3); should vanish up to roundoff
f1 = nablaR(U, Xi1, Xi2, Xi3, Xi4)
f2 = nablaR(U, Xi2, Xi3, Xi1, Xi4)
f3 = nablaR(U, Xi3, Xi1, Xi2, Xi4)

print(f1 + f2 + f3)


CHECK THE SECOND BIANCHI Identity
[[-5.13478149e-16  1.28369537e-15 -1.01134379e-15]
 [-7.77156117e-16 -2.63677968e-16 -5.96744876e-16]
 [-9.36750677e-17  2.77555756e-17  8.32667268e-17]]


$\DeclareMathOperator{\ad}{ad}$
$\newcommand{\cI}{\mathcal{I}}$
$\DeclareMathOperator{\adI}{ad_{\cI}}$
$\newcommand{\mrGamma}{\mathring{\Gamma}}$
# Comparison with an Euler-Arnold-type formula for the curvature of $SE(n)$ with a left-invariant metric
* Only the $SO(n)$ component contributes to the metric.
* Euler-Arnold: for antisymmetric matrices $A, B, C$ define the $\cI$-bracket
$$[A, B]_{\cI} := [A, B] + \cI^{-1}[A, \cI(B)] + \cI^{-1}[B, \cI(A)]$$
* Curvature using Euler-Arnold:
$$R_{AB}C = -\frac{1}{2}[[A, B], C]_{\cI} + \frac{1}{4}[A, [B, C]_{\cI}]_{\cI} -
\frac{1}{4}[B, [A, C]_{\cI}]_{\cI}
$$
* Set $\adI(A, B) := \cI^{-1}[A,\cI(B)]$, so that
$[A, B]_{\cI} = [A, B] + \adI(A, B) +\adI(B, A)$.
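A quick consistency check on these formulas, assuming all weights equal 1 (so $\cI$ is the identity and the metric is bi-invariant): then $\adI(A, B) = [A, B]$, hence $[A, B]_{\cI} = [A, B]$, and the Jacobi identity gives $[A, [B, C]] - [B, [A, C]] = [[A, B], C]$, so the Euler-Arnold curvature reduces to $R_{AB}C = -\frac{1}{4}[[A, B], C]$. The sketch below re-declares bracket helpers mirroring `adI`, `LieI`, and `CurvI` from the code later in the notebook; the seed and $n = 4$ are arbitrary:

```python
import jax
import jax.numpy as jnp
import jax.random as random

jax.config.update("jax_enable_x64", True)

def asym(A):
  return 0.5*(A - A.T)

def Lie(A, B):
  return A@B - B@A

def adI(Imat, A, B):
  # I^{-1}[A, I(B)] with the entrywise inertia operator
  return Lie(A, B*Imat)/Imat

def LieI(Imat, A, B):
  # the I-bracket [A, B]_I
  return Lie(A, B) + adI(Imat, A, B) + adI(Imat, B, A)

def CurvI(Imat, A, B, C):
  # Euler-Arnold curvature R_{AB}C
  return -0.5*LieI(Imat, Lie(A, B), C) \
    + 0.25*LieI(Imat, A, LieI(Imat, B, C)) \
    - 0.25*LieI(Imat, B, LieI(Imat, A, C))

n = 4
kA, kB, kC = random.split(random.PRNGKey(3), 3)
A, B, C = (asym(random.normal(k, (n, n))) for k in (kA, kB, kC))
ones = jnp.ones((n, n))   # bi-invariant case: every weight equals 1

err = jnp.max(jnp.abs(CurvI(ones, A, B, C) + 0.25*Lie(Lie(A, B), C)))
print(err)   # roundoff level
```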

For tangent vectors $A, B, C$ to $SO(n)$ at $I_n$:
* $\Gamma(A, B) = -(AB)_{\mathrm{sym}} + A\,B_{\mathrm{sym}} + \frac{1}{2}(\adI(B, A) + \adI(A, B))$
* $\mrGamma(A, B) = \frac{1}{2}(\adI(B, A) + \adI(A, B))$
* $D_A\mrGamma(B, C) = \frac{1}{2}A\{\adI(B, C)+ \adI(C, B)\} +
\frac{1}{2}\cI^{-1}\{[-AB, \cI(C)] + [B,-\cI(AC)] + [-AC, \cI(B)] +[C, -\cI(AB)]\}$
* Curvature using the embedded method:
$$R^{Emb}_{AB}C = D_A\mrGamma(B, C) - D_B\mrGamma(A, C)
+ \Gamma(A, \Gamma(B, C)) - \Gamma(B, \Gamma(A, C))$$
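The formula for $D_A\mrGamma(B, C)$ can be checked against automatic differentiation. A minimal sketch, assuming the $U$-dependent version of $\mrGamma$ is the $\cI^{-1}$ term of the notebook's `GammaU`, namely $\mrGamma(U; B, C) = \frac{1}{2}U\cI^{-1}\{[U^{\mathsf T}B, \cI(U^{\mathsf T}C)] + [U^{\mathsf T}C, \cI(U^{\mathsf T}B)]\}$ (the name `mrGammaU` is hypothetical). Differentiating at $U = I_n$ along an antisymmetric $A$ replaces $U^{\mathsf T}$ by $A^{\mathsf T} = -A$, which should reproduce the bullet above term by term:

```python
import jax
import jax.numpy as jnp
import jax.random as random
from jax import jvp

jax.config.update("jax_enable_x64", True)

def asym(A):
  return 0.5*(A - A.T)

def Lie(A, B):
  return A@B - B@A

def adI(Imat, A, B):
  # I^{-1}[A, I(B)] with the entrywise inertia operator
  return Lie(A, B*Imat)/Imat

def mrGammaU(Imat, U, B, C):
  # hypothetical U-dependent version of mrGamma (assumption, see lead-in)
  a, c = U.T@B, U.T@C
  return 0.5*U@(adI(Imat, a, c) + adI(Imat, c, a))

def DmrGamma(Imat, A, B, C):
  # closed-form D_A mrGamma(B, C) at U = I_n, as in the formula above
  return 0.5*A@adI(Imat, B, C) + 0.5*A@adI(Imat, C, B) \
    + 0.5*adI(Imat, -A@B, C) + 0.5*adI(Imat, B, -A@C) \
    + 0.5*adI(Imat, -A@C, B) + 0.5*adI(Imat, C, -A@B)

n = 4
keys = random.split(random.PRNGKey(4), 4)
Imat = jnp.abs(random.normal(keys[0], (n, n)))
Imat = 0.5*(Imat + Imat.T)
A, B, C = (asym(random.normal(k, (n, n))) for k in keys[1:])

# directional derivative of U -> mrGammaU(U; B, C) at U = I_n along A
D_jvp = jvp(lambda U: mrGammaU(Imat, U, B, C), (jnp.eye(n),), (A,))[1]
err = jnp.max(jnp.abs(D_jvp - DmrGamma(Imat, A, B, C)))
print(err)   # roundoff level
```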

Below we verify numerically that the two curvature formulas agree. An algebraic verification would give further confirmation, but both formulas are long when expanded.

The sign convention for curvature in [1] is opposite to the one used here.

[1] D. Nguyen, Curvatures of Stiefel manifolds with deformation metrics, Journal of Lie Theory 32 (2022) 563–600.

In [None]:
def Iop(Imat, A):
  # inertia operator: entrywise product with Imat
  return A*Imat

def Iinv(Imat, A):
  # inverse inertia operator: entrywise division by Imat
  return A/Imat


def adI(Imat, A, B):
  # ad_I(A, B) = I^{-1}[A, I(B)]
  return Iinv(Imat, Lie(A, Iop(Imat, B)))

def LieI(Imat, A, B):
  # I-bracket [A, B]_I = [A, B] + ad_I(A, B) + ad_I(B, A)
  return Lie(A, B) + adI(Imat, A, B) + adI(Imat, B, A)

def mrGamma(Imat, A, B):
  # the ad_I part of Gamma
  return 0.5*(adI(Imat, A, B) + adI(Imat, B, A))

def Gamma(Imat, A, B):
  # Christoffel function at the identity I_n
  return -sym(A@B) + A@sym(B) + 0.5*(adI(Imat, A, B) + adI(Imat, B, A))

def DmrGamma(Imat, A, B, C):
  # directional derivative D_A mrGamma(B, C)
  return 0.5*A@adI(Imat, B, C) \
    + 0.5*A@adI(Imat, C, B) \
    + 0.5*adI(Imat, -A@B, C) \
    + 0.5*adI(Imat, B, -A@C) \
    + 0.5*adI(Imat, -A@C, B) \
    + 0.5*adI(Imat, C, -A@B)

def CurvI(Imat, A, B, C):
  # Euler-Arnold curvature R_{AB}C
  return -0.5*LieI(Imat, Lie(A, B), C) \
    + 0.25*LieI(Imat, A, LieI(Imat, B, C)) \
    - 0.25*LieI(Imat, B, LieI(Imat, A, C))

def CurvEmbedded(Imat, A, B, C):
  # embedded curvature R^{Emb}_{AB}C
  return DmrGamma(Imat, A, B, C) \
    - DmrGamma(Imat, B, A, C) \
    + Gamma(Imat, A, Gamma(Imat, B, C)) \
    - Gamma(Imat, B, Gamma(Imat, A, C))


# Compare the JAX curvature with the embedded curvature formula (left-translated to $U$)

In [None]:
print(CurvJax31(U, Xi1, Xi2, Xi3) -  U@CurvEmbedded(Imat, U.T@Xi1, U.T@Xi2, U.T@Xi3))

[[ 5.55111512e-17 -5.55111512e-17 -1.38777878e-17]
 [-3.94215910e-16 -2.77555756e-17  8.32667268e-17]
 [ 1.17527515e-16  5.89805982e-17 -2.35922393e-16]]
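The comparison above goes through `CurvEmbedded`. The JAX curvature at a point $U$ can also be compared directly with the Euler-Arnold formula `CurvI`, using left-invariance: for tangent vectors $UA$, $UB$, $UC$ one expects $R_{UA\,UB}(UC) = U\,R_{AB}C$. A self-contained sketch, re-declaring the notebook's helpers with `Imat` passed explicitly (our change); the seed and $n = 3$ are arbitrary:

```python
import jax
import jax.numpy as jnp
import jax.numpy.linalg as jla
import jax.random as random
from jax import jvp

jax.config.update("jax_enable_x64", True)

def sym(A):
  return 0.5*(A + A.T)

def asym(A):
  return 0.5*(A - A.T)

def Lie(A, B):
  return A@B - B@A

def adI(Imat, A, B):
  # I^{-1}[A, I(B)] with the entrywise inertia operator
  return Lie(A, B*Imat)/Imat

def LieI(Imat, A, B):
  return Lie(A, B) + adI(Imat, A, B) + adI(Imat, B, A)

def GammaU(Imat, U, xi, eta):
  # the notebook's Christoffel function, with Imat passed explicitly
  a, b = U.T@xi, U.T@eta
  return U@sym(xi.T@eta) + 0.5*U@(adI(Imat, a, b) + adI(Imat, b, a))

def CurvJax31(Imat, U, X1, X2, X3):
  # curvature from Gamma, derivatives in U computed by jvp
  DG1 = jvp(lambda q: GammaU(Imat, q, X2, X3), (U,), (X1,))[1]
  DG2 = jvp(lambda q: GammaU(Imat, q, X1, X3), (U,), (X2,))[1]
  return DG1 - DG2 + GammaU(Imat, U, X1, GammaU(Imat, U, X2, X3)) \
    - GammaU(Imat, U, X2, GammaU(Imat, U, X1, X3))

def CurvI(Imat, A, B, C):
  # Euler-Arnold curvature at the identity
  return -0.5*LieI(Imat, Lie(A, B), C) \
    + 0.25*LieI(Imat, A, LieI(Imat, B, C)) \
    - 0.25*LieI(Imat, B, LieI(Imat, A, C))

n = 3
keys = random.split(random.PRNGKey(5), 5)
U, _ = jla.qr(random.normal(keys[0], (n, n)))
Imat = sym(jnp.abs(random.normal(keys[1], (n, n))))
Imat = Imat.at[jnp.diag_indices(n)].set(1.)
A, B, C = (asym(random.normal(k, (n, n))) for k in keys[2:])

# tangent vectors at U are U @ (antisymmetric); compare with U @ R_{AB}C
diff = CurvJax31(Imat, U, U@A, U@B, U@C) - U@CurvI(Imat, A, B, C)
print(jnp.max(jnp.abs(diff)))   # expected at roundoff level
```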


## Compare the Euler-Arnold and embedded formulas for n = 2 to 4, with 10 random scenarios each

In [None]:
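for n in range(2, 5):
  for i in range(10):
    key, sk = random.split(key)
    A = randvec(sk, n)
    key, sk = random.split(key)
    B = randvec(sk, n)
    key, sk = random.split(key)
    C = randvec(sk, n)

    # fresh subkey, so that Imat is independent of C
    key, sk = random.split(key)
    Imat = genImat(sk, n)

    # largest entrywise difference between the two curvature formulas
    c1 = CurvI(Imat, A, B, C)
    c2 = CurvEmbedded(Imat, A, B, C)
    display(cz(c1 - c2))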




Array(1.11022302e-16, dtype=float64)

Array(0., dtype=float64)

Array(0., dtype=float64)

Array(0., dtype=float64)

Array(0., dtype=float64)

Array(0., dtype=float64)

Array(3.46944695e-18, dtype=float64)

Array(2.22044605e-16, dtype=float64)

Array(0., dtype=float64)

Array(0., dtype=float64)

Array(2.77555756e-16, dtype=float64)

Array(3.55271368e-15, dtype=float64)

Array(6.9388939e-17, dtype=float64)

Array(4.4408921e-16, dtype=float64)

Array(2.22044605e-16, dtype=float64)

Array(8.8817842e-16, dtype=float64)

Array(4.71844785e-16, dtype=float64)

Array(3.67761377e-16, dtype=float64)

Array(2.22044605e-16, dtype=float64)

Array(2.22044605e-16, dtype=float64)

Array(8.8817842e-16, dtype=float64)

Array(1.55431223e-15, dtype=float64)

Array(1.24344979e-14, dtype=float64)

Array(1.33226763e-15, dtype=float64)

Array(8.8817842e-16, dtype=float64)

Array(3.99680289e-15, dtype=float64)

Array(2.77555756e-16, dtype=float64)

Array(2.44249065e-15, dtype=float64)

Array(4.71844785e-16, dtype=float64)

Array(1.77635684e-15, dtype=float64)