<a href="https://colab.research.google.com/github/dnguyend/jax-rb/blob/main/tests/notebooks/SE_RB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TESTING RIEMANNIAN BROWNIAN MOTION on the SPECIAL EUCLIDEAN MANIFOLD
  * Since the package is not yet on PyPI, install it from GitHub with the cell below. Otherwise, download the repository in a terminal and install it locally.
  

  We show the steps one by one here; for the other groups, we run a single Python script from the tests folder (e.g. python test_so.py).


In [5]:

!pip install git+https://github.com/dnguyend/jax-rb


Collecting git+https://github.com/dnguyend/jax-rb
  Cloning https://github.com/dnguyend/jax-rb to /tmp/pip-req-build-wa1xleye
  Running command git clone --filter=blob:none --quiet https://github.com/dnguyend/jax-rb /tmp/pip-req-build-wa1xleye
  Resolved https://github.com/dnguyend/jax-rb to commit 1cd0274fe4a7808190deb9eee715703a4dd12d09
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: jax_rb
  Building wheel for jax_rb (pyproject.toml) ... done
  Created wheel for jax_rb: filename=jax_rb-0.1.dev34+g1cd0274-py3-none-any.whl size=33202 sha256=0d4e8fb124b122dea7061165a5f2efb8b4844e3536f9e732fe97652cf0d1a535
  Stored in directory: /tmp/pip-ephem-wheel-cache-ftl7z_e4/wheels/0f/76/88/65e675f8bcca47be98c588d9a787a4c1c9b0a5044517ba6490
Successfully built jax_rb
Insta

In [6]:
import jax
import jax.numpy as jnp
from jax import random, jvp, grad

from jax_rb.manifolds.se_left_invariant import SELeftInvariant
from jax_rb.utils.utils import (grand, sym, rand_positive_definite)
# use float64 so that the residuals printed below are at double-precision rounding level
jax.config.update("jax_enable_x64", True)

## Test that inv_g_metric is the inverse of g_metric, and that $g$ is the operator representing the inner product

In [7]:

key = random.PRNGKey(0)

n = 4

metric_mat, key = rand_positive_definite(key, (n*(n+1))//2)
mnf = SELeftInvariant(n, metric_mat)

x, key = mnf.rand_point(key)
print("check that rand_point generates a point on the manifold")

print(x[:-1, :-1]@x[:-1, :-1].T - jnp.eye(n))

v, key = mnf.rand_vec(key, x)

print("check that rand_vec generates a tangent vector")
print(sym(x[:-1, :-1].T@v[:-1, :-1]))

# check that inv_g_metric inverts g_metric and that g_metric represents inner
va, key = mnf.rand_vec(key, x)
vb, key = mnf.rand_vec(key, x)
omg, key = mnf.rand_ambient(key)

omg1 = mnf.g_metric(x, omg)
omg2 = mnf.inv_g_metric(x, omg1)
print("This is the difference between omg and g^{-1}g omg")
print(omg2 - omg)
print("This is the difference between va^T g omg and inner(va, omg)")
print(jnp.sum(va*omg1) - mnf.inner(x, va, omg))



check that rand_point generates a point on the manifold
[[ 4.44089210e-16 -4.44089210e-16 -3.53883589e-16 -1.52655666e-16]
 [-4.44089210e-16  4.44089210e-16  1.11022302e-16 -2.77555756e-17]
 [-3.53883589e-16  1.11022302e-16  4.44089210e-16  1.11022302e-16]
 [-1.52655666e-16 -2.77555756e-17  1.11022302e-16 -1.11022302e-16]]
check that rand_vec generates a tangent vector
[[ 7.77156117e-16  1.11022302e-16 -8.32667268e-17 -2.08166817e-16]
 [ 1.11022302e-16 -5.55111512e-16 -3.19189120e-16 -1.66533454e-16]
 [-8.32667268e-17 -3.19189120e-16 -2.22044605e-16  5.55111512e-17]
 [-2.08166817e-16 -1.66533454e-16  5.55111512e-17  1.11022302e-16]]
This is the difference between omg and g^{-1}g omg
[[ 4.16333634e-14 -4.22994972e-14  2.31481501e-13 -3.76365605e-14
  -2.38420395e-14]
 [-1.98063788e-13  3.75255382e-14 -2.24265051e-13 -1.51406665e-13
  -2.05391260e-14]
 [ 1.72084569e-13  1.22624133e-13 -2.16576757e-13 -4.25659508e-13
   6.86672941e-14]
 [ 1.70974346e-14 -5.32907052e-14  5.41788836e-14 -2.213
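
Both checks in this cell are pure linear algebra: if the metric at $x$ is represented by a symmetric positive-definite operator $G$ on flattened ambient matrices, then inv_g_metric should apply $G^{-1}$, and inner(x, va, omg) should equal jnp.sum(va*g_metric(x, omg)). A minimal NumPy sketch of the same two identities, assuming a constant random SPD operator G_op (hypothetical; it is not the library's left-invariant metric). Variable names end in _demo so they do not overwrite the notebook's x, va, omg:

```python
import numpy as np

rng = np.random.default_rng(0)
p = 5                                   # ambient matrices are (n+1) x (n+1) with n = 4
d = p * p                               # dimension of the flattened ambient space

A = rng.standard_normal((d, d))
G_op = A @ A.T + d * np.eye(d)          # hypothetical SPD operator standing in for g at x

def g_metric(omg):
    """Apply G to an ambient matrix (the role played by mnf.g_metric(x, .))."""
    return (G_op @ omg.reshape(-1)).reshape(omg.shape)

def inv_g_metric(omg):
    """Apply G^{-1} by solving a linear system rather than forming the inverse."""
    return np.linalg.solve(G_op, omg.reshape(-1)).reshape(omg.shape)

def inner(va, vb):
    """<va, vb>_G = vec(va)^T G vec(vb)."""
    return va.reshape(-1) @ G_op @ vb.reshape(-1)

omg_demo = rng.standard_normal((p, p))
va_demo = rng.standard_normal((p, p))

err_inv = np.max(np.abs(inv_g_metric(g_metric(omg_demo)) - omg_demo))
err_inner = abs(np.sum(va_demo * g_metric(omg_demo)) - inner(va_demo, omg_demo))
print(err_inv, err_inner)
```

Solving with G_op instead of inverting it is the usual numerically safer choice; both printed errors should be tiny compared with the size of the entries involved.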

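## Test the retraction

The rotation block of the retracted point should stay orthogonal, so the output below should be the identity up to rounding.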

In [8]:
xa = mnf.retract(x, va)
xa[:-1, :-1].T@xa[:-1, :-1]

Array([[ 1.00000000e+00, -1.22124533e-15, -1.53262819e-15,
         1.45022883e-15],
       [-1.22124533e-15,  1.00000000e+00,  1.66533454e-16,
        -8.81239526e-16],
       [-1.53262819e-15,  1.66533454e-16,  1.00000000e+00,
        -4.99600361e-16],
       [ 1.45022883e-15, -8.81239526e-16, -4.99600361e-16,
         1.00000000e+00]], dtype=float64)
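
For reference, here is one way to build a point, a tangent vector and a retraction on SE(n) in the homogeneous form used above. This is a minimal NumPy sketch, assuming a Cayley map on the rotation block and a straight-line step on the translation column; cayley_retract and the other helpers are hypothetical and are not jax_rb's retract, rand_point or rand_vec. The Cayley map of an antisymmetric $\Omega$ is orthogonal with determinant $+1$, and its derivative at $0$ is $\Omega$, which is what makes this a retraction:

```python
import numpy as np

def rand_se_point(rng, n):
    """Random SE(n) element in homogeneous form [[R, t], [0, 1]]."""
    q, _ = np.linalg.qr(rng.standard_normal((n, n)))
    if np.linalg.det(q) < 0:            # flip one column so that det(R) = +1
        q[:, 0] = -q[:, 0]
    x = np.eye(n + 1)
    x[:n, :n] = q
    x[:n, n] = rng.standard_normal(n)
    return x

def rand_se_tangent(rng, x):
    """Tangent vector v = x @ xi with xi = [[Omega, w], [0, 0]], Omega antisymmetric."""
    n = x.shape[0] - 1
    a = rng.standard_normal((n, n))
    xi = np.zeros((n + 1, n + 1))
    xi[:n, :n] = a - a.T
    xi[:n, n] = rng.standard_normal(n)
    return x @ xi

def cayley_retract(x, v):
    """Hypothetical retraction: Cayley map on the rotation block,
    straight-line step on the translation column."""
    n = x.shape[0] - 1
    xi = np.linalg.solve(x, v)                      # left-translate: xi = x^{-1} v in se(n)
    omega = 0.5 * (xi[:n, :n] - xi[:n, :n].T)      # remove rounding asymmetry
    eye = np.eye(n)
    cay = np.linalg.solve(eye - 0.5 * omega, eye + 0.5 * omega)  # orthogonal, det +1
    y = np.eye(n + 1)
    y[:n, :n] = x[:n, :n] @ cay
    y[:n, n] = x[:n, n] + v[:n, n]
    return y

rng = np.random.default_rng(1)
x_demo = rand_se_point(rng, 4)
v_demo = rand_se_tangent(rng, x_demo)
xa_demo = cayley_retract(x_demo, v_demo)
print(xa_demo[:-1, :-1].T @ xa_demo[:-1, :-1])   # identity up to rounding
```

The Cayley map needs only one linear solve, which is why it is a common cheap alternative to the matrix exponential for the rotation part.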

Now check that the projection and connection are metric compatible.

The first print shows $\langle\omega, v_a\rangle_{g} = \langle\Pi(x)\omega, v_a\rangle_{g}$, if $v_a$ is a tangent vector at $x\in\mathcal{M}$.

The second and third prints show that
$$D_{v_a}\langle \Pi(x)v_b, \Pi(x)v_b\rangle_{g} = 2 \langle v_b, D_{v_a}\Pi(x)v_b + \Gamma(x; v_a, v_b)\rangle_{g}
$$
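The last print checks that the covariant derivative $D_{v_a}\Pi(x)v_b + \Gamma(x; v_a, v_b)$ is a tangent vector, i.e. that its left translation to the identity lies in the Lie algebra: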
$$x^{-1}(D_{v_a}\Pi(x)v_b + \Gamma(x; v_a, v_b))\in T_I SE(n)
$$
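
The second identity is metric compatibility of the connection, applied to the vector field $Y(x) = \Pi(x)v_b$. Assuming, as in lb_test further down, that the covariant derivative in ambient coordinates is $\nabla_{v_a}Y = D_{v_a}Y + \Gamma(x; v_a, Y)$, and using $Y(x) = \Pi(x)v_b = v_b$ for a tangent vector $v_b$:
$$
D_{v_a}\langle \Pi(x)v_b, \Pi(x)v_b\rangle_{g} = 2\langle Y(x), \nabla_{v_a}Y\rangle_{g} = 2\langle v_b, D_{v_a}\Pi(x)v_b + \Gamma(x; v_a, v_b)\rangle_{g},
$$
which is exactly the quantity computed by the third print.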


In [9]:

print(mnf.inner(x, omg, va) - mnf.inner(x, mnf.proj(x, omg), va))

print(jvp(lambda x: mnf.inner(x, mnf.proj(x, vb), mnf.proj(x, vb)), (x,), (va,))[1])
print(2*mnf.inner(x, vb,
                  jvp(lambda x: mnf.proj(x, vb), (x,), (va,))[1]
                  + mnf.gamma(x, va, vb)))

D1 = jvp(lambda x: mnf.proj(x, vb), (x,), (va,))[1] + mnf.gamma(x, va, vb)
print(sym(x[:-1, :-1].T@D1[:-1, :-1]))



0.0
-68.72337600288301
-68.72337600290322
[[ 2.02504680e-13  2.55795385e-13 -7.95807864e-13 -5.11590770e-13]
 [ 2.55795385e-13  7.10542736e-14 -3.41060513e-13 -4.12114787e-13]
 [-7.95807864e-13 -3.41060513e-13 -7.95807864e-13  1.13686838e-13]
 [-5.11590770e-13 -4.12114787e-13  1.13686838e-13  4.54747351e-13]]
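
The first print holds whenever $\Pi(x)$ is the projection onto the tangent space that is orthogonal with respect to $g$: then $\omega - \Pi(x)\omega$ is $g$-orthogonal to every tangent vector. A minimal NumPy sketch of that property, assuming a generic subspace spanned by the columns of a hypothetical matrix U and a constant SPD matrix G (neither is the library's SE(n) data):

```python
import numpy as np

rng = np.random.default_rng(2)
d, k = 10, 4                              # ambient and subspace dimensions

A = rng.standard_normal((d, d))
G = A @ A.T + d * np.eye(d)               # hypothetical SPD metric matrix
U = rng.standard_normal((d, k))           # columns span a hypothetical tangent space

def inner(a, b):
    """<a, b>_G = a^T G b."""
    return a @ G @ b

def proj(omg):
    """G-orthogonal projection onto span(U): U (U^T G U)^{-1} U^T G omg."""
    return U @ np.linalg.solve(U.T @ G @ U, U.T @ G @ omg)

omg_demo = rng.standard_normal(d)         # arbitrary ambient vector
va_demo = U @ rng.standard_normal(k)      # a tangent vector

print(inner(omg_demo, va_demo) - inner(proj(omg_demo), va_demo))   # rounding level
```

The factor $(U^\top G U)^{-1}$ is what makes the projection self-adjoint for $\langle\cdot,\cdot\rangle_G$ rather than for the Euclidean inner product.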


## Check that the Stratonovich and Itô drifts given in the library agree with the summation in the main theorem

In [10]:
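def check_v0(self):
    # rebuild v0 = -1/2 * sum_k (xi_k @ xi_k + Gamma(I; xi_k, xi_k)) over the basis xi_k below
    p, i_sqrt_mat = self.shape[0], self._i_sqrt_g_mat
    vv = jnp.zeros((p, p))
    zr = jnp.zeros((p, p))
    # rotation generators (E_ij - E_ji)/sqrt(2), mapped through i_sqrt_mat
    for i in range(1, p-1):
        for j in range(i):
            eij = zr.at[i, j].set(1.).at[j, i].set(-1.)
            eij = 1/jnp.sqrt(2)*self._mat_apply(i_sqrt_mat, eij)
            vv += eij@eij + self.gamma(jnp.eye(p), eij, eij)
    # translation generators E_{i, last}, mapped through i_sqrt_mat
    for i in range(p-1):
        eij = self._mat_apply(i_sqrt_mat, zr.at[i, -1].set(1.))
        vv += eij@eij + self.gamma(jnp.eye(p), eij, eij)

    return -0.5*vv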
print(check_v0(mnf))
print(mnf.v0)

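def check_ito_drift(self, x):
    s1 = self.ito_drift(x)
    n, p = self.shape
    # s = -sum over all ambient unit matrices E_ij of Gamma(x; E_ij, Pi(x) g^{-1}(x) E_ij)
    s = jnp.zeros((n, p))
    for i in range(n):
        for j in range(p):
            eij = jnp.zeros(self.shape).at[i, j].set(1.)
            s -= self.gamma(x, eij, self.proj(x, self.inv_g_metric(x, eij)))
    # a rounding-level output means ito_drift(x) equals s/2
    print(2*s1 - s)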
check_ito_drift(mnf, x)


[[-2.77555756e-17  4.97379915e-14  5.79092330e-13  4.86721774e-13
  -5.50670620e-13]
 [-4.97379915e-14  6.93889390e-17  3.05533376e-13  3.46389584e-14
   1.68753900e-13]
 [-5.79092330e-13 -3.05533376e-13 -0.00000000e+00  9.37916411e-13
  -2.06057393e-13]
 [-4.83169060e-13 -3.37507799e-14 -9.37916411e-13  6.93889390e-17
   2.99316127e-13]
 [-0.00000000e+00 -0.00000000e+00 -0.00000000e+00 -0.00000000e+00
  -0.00000000e+00]]
[[-0. -0. -0. -0. -0.]
 [-0. -0. -0. -0. -0.]
 [-0. -0. -0. -0. -0.]
 [-0. -0. -0. -0. -0.]
 [-0. -0. -0. -0. -0.]]
[[-7.28306304e-14  2.67341704e-13  7.44293516e-13 -1.20792265e-13
  -1.49213975e-13]
 [ 2.33035813e-13  2.08610906e-13  1.81188398e-13 -3.14415161e-13
  -1.04805054e-13]
 [-6.39488462e-14 -4.92939023e-14  4.44089210e-14 -5.70210545e-13
   6.20392626e-13]
 [-1.40332190e-13  2.50910404e-13 -4.47641924e-13 -4.59188243e-13
   3.13526982e-13]
 [-8.60422844e-16 -4.16333634e-16 -7.49400542e-16 -5.68989300e-16
   0.00000000e+00]]
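
Written out, the two sums coded in this cell are as follows, where $\mathcal{A}$ stands for the library's private helper self._mat_apply(self._i_sqrt_g_mat, ·), $E_{ij}$ is the unit matrix with a $1$ in entry $(i, j)$, $N = n + 1$, and $\mu(x)$ is the value of ito_drift(x). This only restates the code above:
$$
v_0 = -\frac{1}{2}\sum_{k}\left(\xi_k\xi_k + \Gamma(I; \xi_k, \xi_k)\right),\qquad
\{\xi_k\} = \left\{\tfrac{1}{\sqrt{2}}\mathcal{A}(E_{ij} - E_{ji})\right\}_{1\le j<i\le n}\cup\left\{\mathcal{A}(E_{iN})\right\}_{1\le i\le n},
$$
$$
2\,\mu(x) = -\sum_{i=1}^{N}\sum_{j=1}^{N}\Gamma\left(x; E_{ij}, \Pi(x)\,g^{-1}(x)E_{ij}\right).
$$
The printed outputs show that the first sum reproduces mnf.v0 (here the zero matrix) and that the second identity holds to within about $10^{-12}$.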


* Testing the Laplacian. We construct a scalar polynomial function of degree 4, then compare the Laplace-Beltrami operator computed from its definition, as the sum
$\sum_i \langle \xi_i, \nabla_{\xi_i} \mathrm{rgrad}\, f\rangle_{g}
$ over a local orthonormal frame $\xi_i$ (here built from left-invariant vector fields), with the Laplace-Beltrami operator given by the library.
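
In the simplest setting, a constant metric $G$ on $\mathbb{R}^d$, the Christoffel term vanishes and the frame sum reduces to $\mathrm{tr}(G^{-1}\,\mathrm{Hess}\, f)$. A minimal NumPy sketch of the same comparison, assuming a hypothetical constant SPD matrix G, the orthonormal frame $\xi_i = G^{-1/2}e_i$, a quartic $f$ of the same shape as the one below, and a central finite difference in place of jvp:

```python
import numpy as np

rng = np.random.default_rng(3)
d = 6

# coefficients of f(u) = a.u + u^T B u + (u^T C u)^2, a quartic polynomial
a = rng.standard_normal(d)
B = rng.standard_normal((d, d))
C = rng.standard_normal((d, d))
S = C + C.T                              # gradient of u^T C u is S u

M = rng.standard_normal((d, d))
G = M @ M.T + d * np.eye(d)              # hypothetical constant SPD metric

def egrad(u):
    """Euclidean gradient of f."""
    q = u @ C @ u
    return a + (B + B.T) @ u + 2.0 * q * (S @ u)

def ehess(u):
    """Euclidean Hessian of f."""
    q = u @ C @ u
    su = S @ u
    return B + B.T + 2.0 * np.outer(su, su) + 2.0 * q * S

def rgrad(u):
    """Riemannian gradient for the constant metric G."""
    return np.linalg.solve(G, egrad(u))

# orthonormal frame xi_i = G^{-1/2} e_i, i.e. the columns of G^{-1/2}
w, V = np.linalg.eigh(G)
G_isqrt = V @ np.diag(1.0 / np.sqrt(w)) @ V.T

def lb_frame_sum(u, h=1e-5):
    """sum_i <xi_i, nabla_{xi_i} rgrad f>_G, with the directional derivative
    taken by central differences; the Christoffel term is zero for constant G."""
    total = 0.0
    for i in range(d):
        xi = G_isqrt[:, i]
        dgrad = (rgrad(u + h * xi) - rgrad(u - h * xi)) / (2.0 * h)
        total += xi @ G @ dgrad
    return total

u_demo = rng.standard_normal(d)
lb_ref = np.trace(np.linalg.solve(G, ehess(u_demo)))   # tr(G^{-1} Hess f)
print(lb_frame_sum(u_demo), lb_ref)
```

On SE(n) the metric is not constant in ambient coordinates, which is why lb_test below adds the term self.gamma(x, vij, tmp[0]) to the directional derivative.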

In [11]:
# now test Laplacian

f1, key = grand(key, (n+1, n+1))
f2, key = grand(key, ((n+1)**2, (n+1)**2))
f3, key = grand(key, ((n+1)**2, (n+1)**2))

@jax.jit
def f(U):
    return jnp.sum(f1*U) + jnp.sum(U.reshape(-1)*(f2@U.reshape(-1))) \
        + jnp.sum(U.reshape(-1)*(f3@U.reshape(-1)))**2

egradf = jax.jit(jax.grad(f))

@jax.jit
def ehessf(U, omg):
    return jvp(egradf, (U,), (omg,))[1]

def lb_test(self, x, f):
    n1 = self.shape[0]
    isqrt2 = 1/jnp.sqrt(2)
    ret = 0
    rgradf = jax.jit(lambda x: self.proj(
        x,
        self.inv_g_metric(x, grad(f)(x))))
    zr = jnp.zeros((n1, n1))
    for i in range(1, n1-1):
        for j in range(i):
            vij = self.left_invariant_vector_field(x,
                                                    isqrt2*zr.at[i, j].set(1.).at[j, i].set(-1))
            tmp = jvp(rgradf, (x,), (vij,))
            nxi = tmp[1] + self.gamma(x, vij, tmp[0])
            ret += self.inner(x, vij, nxi)
    for i in range(n1-1):
        vij = self.left_invariant_vector_field(x, zr.at[i, -1].set(1.))
        tmp = jvp(rgradf, (x,), (vij,))
        nxi = tmp[1] + self.gamma(x, vij, tmp[0])
        ret += self.inner(x, vij, nxi)

    return ret

print(lb_test(mnf, x, f))
print(mnf.laplace_beltrami(x, egradf(x), ehessf))





2098.291025538745
2098.2910255389975
