<a href="https://colab.research.google.com/github/dnguyend/jax-rb/blob/main/tests/notebooks/SE_RB.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TESTING RIEMANNIAN BROWNIAN MOTION on the SPECIAL EUCLIDEAN MANIFOLD
  * Since the package is not yet on PyPI, install it using the dialog box below. Otherwise, on a terminal, download the repository and install it locally.
  

  We show the tests step by step here; for the other groups, we run a single Python script in the `tests` folder (e.g. `python test_so.py`).


In [1]:
#@title Imports & Utils
import subprocess
import sys
from urllib.parse import quote, unquote

import ipywidgets as widgets
from IPython.display import display


class credentials_input():
    """Collect GitHub credentials through widgets and pip-install a repository.

    Include this snippet in Colab if you want to install from a private
    repository. A username text box is displayed first; submitting it shows a
    password box, and submitting the password runs
    ``python -m pip install git+https://<user>:<password>@<repo_name>``.

    Parameters
    ----------
    repo_name : str
        Repository location without the scheme,
        e.g. ``'github.com/owner/repo.git'``.
    """
    def __init__(self, repo_name):
        self.repo_name = repo_name
        self.username = widgets.Text(description='Username', value='')
        self.pwd = widgets.Password(
            description='Password', placeholder='password here')

        self.username.on_submit(self.handle_submit_username)
        self.pwd.on_submit(self.handle_submit_pwd)
        display("Use %40 for @ in email address:")
        display(self.username)

    def handle_submit_username(self, text):
        """Show the password box once a username has been submitted."""
        display(self.pwd)

    def handle_submit_pwd(self, text):
        """Run ``pip install`` against the repository with the entered credentials.

        Both credentials are percent-encoded so characters such as ``@``,
        ``:``, ``/``, ``#`` or spaces cannot corrupt the URL. The username is
        unquoted first, so a value typed with ``%40`` (as the prompt suggests)
        is not double-encoded. The widget values are cleared afterwards, even
        if the install raises.
        """
        try:
            username = quote(unquote(self.username.value), safe='')
            password = quote(self.pwd.value, safe='')
            url = f'git+https://{username}:{password}@{self.repo_name}'
            # An argument list (no shell, no str.split) keeps the URL in one
            # argument; `sys.executable -m pip` installs into the running
            # kernel's environment rather than whatever `pip` is on PATH.
            result = subprocess.run(
                [sys.executable, '-m', 'pip', 'install', url],
                capture_output=True, check=False)
            print(result.stdout.decode(errors='replace'),
                  result.stderr.decode(errors='replace'))
        finally:
            # Do not leave the secret sitting in the widgets.
            self.username.value, self.pwd.value = '', ''

# Keep a handle on the form; binding it also stops Jupyter echoing its repr.
credentials_form = credentials_input('github.com/dnguyend/jax-rb.git')




'Use %40 for @ in email address:'

Text(value='', description='Username')

<__main__.credentials_input at 0x7d46c30f3970>

Password(description='Password', placeholder='password here')



In [2]:
import jax
# Enable float64 immediately after importing JAX (JAX documents this flag as a
# startup setting), before any other module can create arrays; the checks
# below compare residuals of order 1e-12.
jax.config.update("jax_enable_x64", True)
import jax.numpy as jnp
from jax import random, jvp, grad

from jax_rb.manifolds.se_left_invariant import SELeftInvariant
from jax_rb.utils.utils import (grand, sym, rand_positive_definite)

## Test that `inv_g_metric` is the inverse of $g$, and that $g$ is the operator representing the inner product

In [3]:

# Fixed seed: the metric, point and vectors below are reproducible.
key = random.PRNGKey(0)

n = 4

# Random SPD matrix for the left-invariant metric; n*(n+1)/2 is the dimension
# of the Lie algebra of SE(n) (n(n-1)/2 rotations + n translations).
metric_mat, key = rand_positive_definite(key, (n*(n+1))//2)
mnf = SELeftInvariant(n, metric_mat)

x, key = mnf.rand_point(key)
print("check that rand_point generate a point on the manifold")
# The rotation block R = x[:-1, :-1] must be orthogonal: R R^T - I = 0.
rot_err = x[:-1, :-1]@x[:-1, :-1].T - jnp.eye(n)
print(rot_err)
assert jnp.allclose(rot_err, 0., atol=1e-8), "rotation block is not orthogonal"

v, key = mnf.rand_vec(key, x)

print("check that rand_vec generate a tangent vector")
# R^T v[:-1, :-1] must be antisymmetric, i.e. have a zero symmetric part.
tan_err = sym(x[:-1, :-1].T@v[:-1, :-1])
print(tan_err)
assert jnp.allclose(tan_err, 0., atol=1e-8), "rand_vec is not tangent"

# check metric compatibility
va, key = mnf.rand_vec(key, x)
vb, key = mnf.rand_vec(key, x)
omg, key = mnf.rand_ambient(key)

omg1 = mnf.g_metric(x, omg)
omg2 = mnf.inv_g_metric(x, omg1)
print("This is the difference betwen omg and g^{-1}g omg")
print(omg2 - omg)
assert jnp.allclose(omg2, omg, atol=1e-8), "inv_g_metric does not invert g_metric"
print("This is the difference betwen va^Tg omg and  inner(va, omg)")
print(jnp.sum(va*omg1) - mnf.inner(x, va, omg))
assert jnp.isclose(jnp.sum(va*omg1), mnf.inner(x, va, omg), rtol=1e-8), \
    "g_metric does not represent inner"



check that rand_point generate a point on the manifold
[[ 4.44089210e-16 -1.66533454e-16 -4.44089210e-16  0.00000000e+00]
 [-1.66533454e-16 -2.22044605e-16  1.52655666e-16 -1.11022302e-16]
 [-4.44089210e-16  1.52655666e-16  2.22044605e-16 -1.52655666e-16]
 [ 0.00000000e+00 -1.11022302e-16 -1.52655666e-16 -2.22044605e-16]]
check that rand_vec generate a tangent vector
[[ 6.66133815e-16 -6.66133815e-16 -3.33066907e-16 -4.44089210e-16]
 [-6.66133815e-16 -8.12717948e-16 -1.70002901e-16  1.11022302e-16]
 [-3.33066907e-16 -1.70002901e-16  5.55111512e-17  1.66533454e-16]
 [-4.44089210e-16  1.11022302e-16  1.66533454e-16  2.77555756e-16]]
This is the difference betwen omg and g^{-1}g omg
[[-4.37427872e-14  3.18967075e-13  3.94684285e-13 -2.44249065e-13
  -1.53765889e-14]
 [ 3.08642001e-13 -1.09467990e-13  3.58602037e-14 -3.81084053e-14
   2.57571742e-14]
 [-2.90878432e-14 -1.00897068e-12 -8.82155460e-13  3.19411164e-13
   1.04749542e-13]
 [ 2.35145237e-13  3.56159546e-13  1.51212376e-13 -3.861

Now check that the projection and connection are metric compatible.

The first print shows $\langle\omega, v_a\rangle_{g} = \langle\Pi(x)\omega, v_a\rangle_{g}$, if $v_a$ is a tangent vector at $x\in\mathcal{M}$.

The second and third prints show
$$D_{v_a}\langle \Pi(x)v_b, \Pi(x)v_b\rangle_{g} = 2 \langle v_b, D_{v_a}\Pi(x)v_b + \Gamma(x; v_a, v_b)\rangle_{g}
$$
The last print checks that the connection produces a tangent vector:
$$x^{-1}(D_{v_a}\Pi(x)v_b + \Gamma(x; v_a, v_b))\in T_I SE(n)
$$


In [4]:

# <omg, va>_g must equal <Pi(x) omg, va>_g for a tangent vector va.
proj_err = mnf.inner(x, omg, va) - mnf.inner(x, mnf.proj(x, omg), va)
print(proj_err)

# Directional derivative along va of y -> <Pi(y) vb, Pi(y) vb>_g at x ...
lhs = jvp(lambda y: mnf.inner(y, mnf.proj(y, vb), mnf.proj(y, vb)),
          (x,), (va,))[1]
# ... versus 2 <vb, D_va Pi(x) vb + Gamma(x; va, vb)>_g. The bracket is
# computed once and reused by the tangency check below.
D1 = jvp(lambda y: mnf.proj(y, vb), (x,), (va,))[1] + mnf.gamma(x, va, vb)
rhs = 2*mnf.inner(x, vb, D1)
print(lhs)
print(rhs)

# Rotation part R^T D1[:-1, :-1] must be antisymmetric (zero symmetric part).
tan_err = sym(x[:-1, :-1].T@D1[:-1, :-1])
print(tan_err)

assert jnp.isclose(proj_err, 0., atol=1e-8), "proj is not metric compatible"
assert jnp.isclose(lhs, rhs, rtol=1e-8), "gamma is not metric compatible"
assert jnp.allclose(tan_err, 0., atol=1e-8), "connection is not tangent"



1.4210854715202004e-14
-68.72337600288316
-68.72337600286082
[[-1.42108547e-13 -2.55795385e-13 -6.82121026e-13 -3.97903932e-13]
 [-2.55795385e-13 -1.84741111e-13 -4.54747351e-13 -8.52651283e-14]
 [-6.82121026e-13 -4.54747351e-13  0.00000000e+00  2.27373675e-13]
 [-3.97903932e-13 -8.52651283e-14  2.27373675e-13  3.41060513e-13]]


# Check that the Stratonovich and Itô drifts given in the library match the summations in the main theorem.

In [5]:
def check_v0(self):
    """Recompute the drift at the identity from the basis summation.

    The basis is the metric-normalised Lie-algebra basis: ``_mat_apply`` of
    ``self._i_sqrt_g_mat`` to (E_ij - E_ji)/sqrt(2) for j < i < p-1 and to
    E_{i,p-1} for i < p-1, where p = self.shape[0]. Returns
    -1/2 * sum(xi @ xi + Gamma(I; xi, xi)), to be compared with ``self.v0``.
    """
    dim = self.shape[0]
    isqrt_g = self._i_sqrt_g_mat
    blank = jnp.zeros((dim, dim))

    def _basis():
        # Rotation generators, scaled by 1/sqrt(2).
        for row in range(1, dim - 1):
            for col in range(row):
                gen = blank.at[row, col].set(1.).at[col, row].set(-1.)
                yield 1/jnp.sqrt(2)*self._mat_apply(isqrt_g, gen)
        # Translation generators (last column).
        for row in range(dim - 1):
            yield self._mat_apply(isqrt_g, blank.at[row, -1].set(1.))

    total = jnp.zeros((dim, dim))
    for xi in _basis():
        total += xi@xi + self.gamma(jnp.eye(dim), xi, xi)
    return -0.5*total
# Drift at the identity: basis summation vs. the library's v0.
v0_sum = check_v0(mnf)
print(v0_sum)
print(mnf.v0)
assert jnp.allclose(v0_sum, mnf.v0, atol=1e-8), "v0 differs from the basis summation"

def check_ito_drift(self, x):
    """Compare the library Ito drift with the summation of the main theorem.

    Computes ``s = -sum_ij Gamma(x; E_ij, Pi(x) g^{-1}(x) E_ij)`` over the
    standard basis ``E_ij`` of the ambient matrix space of shape
    ``self.shape`` and returns ``2*self.ito_drift(x) - s``, which should
    vanish up to round-off. The residual is returned (like ``check_v0``)
    rather than printed, so callers can display or assert on it.
    """
    drift = self.ito_drift(x)
    rows, cols = self.shape
    s = jnp.zeros((rows, cols))
    for i in range(rows):
        for j in range(cols):
            eij = jnp.zeros(self.shape).at[i, j].set(1.)
            s -= self.gamma(x, eij, self.proj(x, self.inv_g_metric(x, eij)))
    return 2*drift - s
# Residual between 2*ito_drift(x) and the main-theorem summation; should be ~0.
check_ito_drift(mnf, x)


[[-0.00000000e+00  4.10338430e-13  4.97379915e-14 -1.01252340e-13
  -3.94351218e-13]
 [-4.10338430e-13 -0.00000000e+00  1.91846539e-13  2.38031816e-13
  -1.72306613e-13]
 [-4.97379915e-14 -1.91846539e-13 -0.00000000e+00  1.91846539e-13
  -6.32383035e-13]
 [ 1.01252340e-13 -2.39808173e-13 -1.91846539e-13 -0.00000000e+00
  -2.34479103e-13]
 [-0.00000000e+00 -0.00000000e+00 -0.00000000e+00 -0.00000000e+00
  -0.00000000e+00]]
[[-0. -0. -0. -0. -0.]
 [-0. -0. -0. -0. -0.]
 [-0. -0. -0. -0. -0.]
 [-0. -0. -0. -0. -0.]
 [-0. -0. -0. -0. -0.]]
[[ 1.43085543e-12  5.60440583e-13  2.96473956e-12 -5.64881475e-13
  -4.93827201e-13]
 [ 2.09809947e-12  5.18141086e-13  9.14823772e-13 -1.45128354e-12
  -2.87769808e-13]
 [ 2.82440737e-13  9.94759830e-14 -1.88471461e-12 -2.51354493e-12
   2.26352270e-12]
 [ 7.67386155e-13  9.38360500e-13 -1.11199938e-12 -1.94777527e-12
   6.39044373e-13]
 [ 5.55111512e-17  6.10622664e-16  5.55111512e-17  5.27355937e-16
   0.00000000e+00]]


* Testing the Laplacian. We construct a scalar function of degree 4, then compute the Laplace–Beltrami operator from its definition, as the summation $\sum_i \langle \xi_i, \nabla_{\xi_i} \mathrm{rgrad}_f\rangle_{g}$ over a local orthonormal frame $\xi_i$, and compare it with the Laplace–Beltrami operator given by the library.

In [6]:
# now test Laplacian
# Random coefficients of a degree-4 polynomial test function on
# (n+1) x (n+1) matrices: a linear, a quadratic and a squared-quadratic term.
f1, key = grand(key, (n+1, n+1))
f2, key = grand(key, ((n+1)**2, (n+1)**2))
f3, key = grand(key, ((n+1)**2, (n+1)**2))


@jax.jit
def f(U):
    """Test function <f1, U> + u^T f2 u + (u^T f3 u)^2 with u = U flattened."""
    u = U.reshape(-1)
    linear = jnp.sum(f1*U)
    quad = jnp.sum(u*(f2@u))
    quartic = jnp.sum(u*(f3@u))**2
    return linear + quad + quartic


# Euclidean gradient of f, compiled once.
egradf = jax.jit(jax.grad(f))


@jax.jit
def ehessf(U, omg):
    """Euclidean Hessian of f at U applied to omg (JVP of the gradient)."""
    _, hess_omg = jvp(egradf, (U,), (omg,))
    return hess_omg

def lb_test(self, x, f):
    """Laplace-Beltrami operator of ``f`` at ``x`` computed from its definition.

    Sums <xi, nabla_xi rgrad f>_g over the left-invariant vector fields at
    ``x`` generated by (E_ij - E_ji)/sqrt(2), j < i < p-1, and E_{i,p-1},
    i < p-1, where p = self.shape[0]. The covariant derivative is evaluated
    as D_xi rgrad f + Gamma(x; xi, rgrad f), with
    rgrad f = Pi(y) g^{-1}(y) grad f(y).
    """
    dim = self.shape[0]
    scale = 1/jnp.sqrt(2)
    blank = jnp.zeros((dim, dim))
    rgradf = jax.jit(
        lambda y: self.proj(y, self.inv_g_metric(y, grad(f)(y))))

    def _frame():
        # Rotation directions first, then translations (same order as before).
        for row in range(1, dim - 1):
            for col in range(row):
                yield self.left_invariant_vector_field(
                    x, scale*blank.at[row, col].set(1.).at[col, row].set(-1))
        for row in range(dim - 1):
            yield self.left_invariant_vector_field(x, blank.at[row, -1].set(1.))

    total = 0
    for xi in _frame():
        field, dfield = jvp(rgradf, (x,), (xi,))
        total += self.inner(x, xi, dfield + self.gamma(x, xi, field))
    return total

# Laplace-Beltrami from the frame summation vs. the library implementation.
lb_frame = lb_test(mnf, x, f)
print(lb_frame)
lb_library = mnf.laplace_beltrami(x, egradf(x), ehessf)
print(lb_library)
assert jnp.isclose(lb_frame, lb_library, rtol=1e-8), "Laplace-Beltrami mismatch"





2098.291025538352
2098.29102553893
