<a href="https://colab.research.google.com/github/dnguyend/jax-rb/blob/main/examples/JAX_RB_Cookbook.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Jax_RB




## Riemannian Brownian motion using JAX
This notebook introduces the basic structure of JAX_RB, a package for Brownian motion on Riemannian manifolds.

1. We implement the basic framework for embedded geometry on several manifolds through a base class, *GlobalManifold*. If the user provides the basic differential-geometric structures, namely the *metric compatible projection*, *the Levi-Civita connection*, and a *tubular retraction* from the ambient-space embedding, then *the base class provides the Laplace-Beltrami operator*, and a simple adaptation of the framework provides several methods *to simulate the Riemannian Brownian motion*.

2. An important subclass is the matrix Lie group with a (left-)invariant metric, implemented through the derived class *MatrixLeftInvariant*, which includes $GL(n)$, $SL(n)$, $Aff(n)$ (the affine Lie group), $SO(n)$ and $SE(n)$. The differential-geometric structure is implemented in a general framework: the user can choose an arbitrary metric at the identity, and it is then straightforward to construct the differential-geometric structures, including the *metric compatible projection* and *the Levi-Civita connection*. The Stratonovich and Ito drifts are easily computed, with or without the choice of an orthogonal basis.
3. In addition to the matrix groups, we implement the differential-geometric framework for the sphere, the Stiefel, Grassmann and symmetric positive definite manifolds, and the hyperbolic manifold (realized on $\mathbb{R}^n$).
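To illustrate how a Laplace-Beltrami operator can be assembled from embedded data, here is a minimal standalone sketch for the unit sphere $S^{n-1}\subset\mathbb{R}^n$. It does not use the *GlobalManifold* API: the function `sphere_laplace_beltrami` is hypothetical, and it relies on the sphere-specific formula $\Delta f = \operatorname{tr}(P\,\nabla^2 f) - (n-1)\,x^\top\nabla f$, where $P = I - xx^\top$ is the tangent projection and the second term comes from the mean curvature of the sphere.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def sphere_laplace_beltrami(f, x):
    """Hypothetical sketch: Laplace-Beltrami of f restricted to the unit sphere, at a unit vector x.

    Embedded formula: Delta f = tr(P Hess f) + <H, grad f>, with P = I - x x^T the
    orthogonal projection onto the tangent space and H = -(n-1) x the mean-curvature
    vector of the unit sphere S^{n-1} in R^n.
    """
    n = x.shape[0]
    proj = jnp.eye(n) - jnp.outer(x, x)   # tangent-space projection at x
    hess = jax.hessian(f)(x)              # ambient (Euclidean) Hessian
    g = jax.grad(f)(x)                    # ambient (Euclidean) gradient
    return jnp.trace(proj @ hess) - (n - 1) * jnp.dot(x, g)


# Evaluate at a random point of S^2
x = jax.random.normal(jax.random.PRNGKey(1), (3,), dtype=jnp.float64)
x = x / jnp.linalg.norm(x)
lap_x1 = sphere_laplace_beltrami(lambda y: y[0], x)
```

As a sanity check, the coordinate function $x_1$ is an eigenfunction of the Laplace-Beltrami operator on $S^{n-1}$ with eigenvalue $-(n-1)$, so `lap_x1` should equal `-2 * x[0]` on $S^2$.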

We will demonstrate how to use the implemented manifolds to simulate Brownian motions.

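For now, the repository is private, so credentials must be entered as below. GitHub no longer accepts account passwords for Git operations over HTTPS, so enter a personal access token in the Password field; for security, consider generating a token just for this repository.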



In [4]:
#@title Imports & Utils
import ipywidgets as widgets
from IPython.display import display
import subprocess


class credentials_input():
    """Collect GitHub credentials through widgets and pip-install a package
    from a private repository. Include this snippet in Colab when the
    repository is private.
    """
    def __init__(self, repo_name):
        self.repo_name = repo_name
        self.username = widgets.Text(description='Username', value='')
        self.pwd = widgets.Password(
            description='Password', placeholder='password here')

        self.username.on_submit(self.handle_submit_username)
        self.pwd.on_submit(self.handle_submit_pwd)
        display("Use %40 for @ in email address:")
        display(self.username)

    def handle_submit_username(self, text):
        display(self.pwd)

    def handle_submit_pwd(self, text):
        username = self.username.value.replace('@', '%40')
        #  cmd = f'git clone https://{username}:{self.pwd.value}@{self.repo_name}'
        cmd = f'pip install git+https://{username}:{self.pwd.value}@{self.repo_name}'
        process = subprocess.Popen(
            cmd.split(), stdout=subprocess.PIPE, stderr=subprocess.PIPE)
        output, error = process.communicate()
        print(output, error)
        self.username.value, self.pwd.value = '', ''

credentials_input('github.com/dnguyend/jax-rb.git')









In [5]:
import jax
import jax.numpy as jnp
from jax import random, jvp, grad

from jax_rb.manifolds.se_left_invariant import SELeftInvariant

import jax_rb.simulation.simulator as sim
import jax_rb.simulation.matrix_group_integrator as mi
from jax_rb.utils.utils import (grand, sym, rand_positive_definite)
jax.config.update("jax_enable_x64", True)

## Animation of Riemannian Brownian motion on Special Euclidean manifold

Import the animation libraries, then create the animation.

In [6]:
import itertools
import matplotlib
from matplotlib import animation
import matplotlib.pyplot as plt
from matplotlib import rc
from IPython.display import HTML
rc('animation', html='jshtml')
# matplotlib.use("AGG")

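We start with the 2D version. Here, we create a series of 2D objects $x_i$, each consisting of $n_f=4$ points stored as the columns of a $2\times 4$ matrix ($f$ for feature; in dimension $n=2$ we need at least $n+1$ points). At each step we generate a small random move $g_i\in SE(2)$ by calling the function `geodesic_move` at the identity, then apply it to the current configuration:
$$ x_{i+1}\leftarrow g_ix_i,$$
where $g_i=\begin{bmatrix}R_i & v_i\\ 0 & 1\end{bmatrix}$ acts on each point in homogeneous coordinates, $x\mapsto R_ix+v_i$.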

The $n_f$ points move rigidly, so distances and angles between them are preserved. We choose a diagonal metric equal to one everywhere, except for the entry corresponding to rotation, which is set to 2.5 (`plot_size/4`).
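A single step of this update can be sketched in plain JAX. As a simplifying assumption, the random group element here is the matrix exponential of a Gaussian Lie-algebra element (a one-parameter-subgroup step). This is *not* what `mi.geodesic_move` computes: as its name suggests, that function moves along geodesics of the left-invariant metric, which differ from one-parameter subgroups when the metric is not bi-invariant. The helper names and the basis order (rotation first) are hypothetical. Dividing the Gaussian coefficients by $\sqrt{d_k}$ expresses them in a basis that is orthonormal for the diagonal metric $\mathrm{diag}(d_k)$, so a heavier weight gives smaller moves in that direction.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)


def se2_hat(xi):
    """Coefficients (omega, v1, v2) -> 3x3 matrix in se(2).
    Assumed (hypothetical) basis order: rotation first, then the two translations."""
    omega, v1, v2 = xi
    return jnp.array([[0.0, -omega, v1],
                      [omega, 0.0, v2],
                      [0.0, 0.0, 0.0]])


def random_se2_step(key, metric_diag, scale):
    """Hypothetical one-step sampler: exponential of a random se(2) element whose
    coefficients are scaled by 1/sqrt(metric weight)."""
    z = jax.random.normal(key, (3,), dtype=jnp.float64)
    return expm(se2_hat(scale * z / jnp.sqrt(metric_diag)))


def apply_homogeneous(g, x):
    """Apply g = [[R, v], [0, 1]] to the points stored as columns of x: x -> R x + v."""
    return g[:-1, :-1] @ x + g[:-1, -1][:, None]


def pairwise_distances(x):
    """Euclidean distances between all pairs of columns of x."""
    return jnp.linalg.norm(x[:, :, None] - x[:, None, :], axis=0)


key_g, key_x = jax.random.split(jax.random.PRNGKey(0))
x = 10 * jax.random.normal(key_x, (2, 4), dtype=jnp.float64)        # 4 points in the plane
g = random_se2_step(key_g, jnp.array([2.5, 1.0, 1.0]), scale=0.4)  # rotation weight 2.5
y = apply_homogeneous(g, x)
```

After the step, `pairwise_distances(y)` agrees with `pairwise_distances(x)` up to rounding, because the linear part of `g` is a rotation.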


In [7]:
n_dim = 2
n_f = 4
plot_size = 10
colors = ["r", "b", "g", "m"]

def func(num, x_arr, lines, dots):
    # ANIMATION FUNCTION
    for i in range(n_f):
        lines[i].set_data(x_arr[:num, :, i].T)
        if num > 1:
            dots[i].set_data(x_arr[num-2:num-1, :, i].T)
    return lines

key = random.PRNGKey(0)
se_dim = n_dim*(n_dim+1)//2
diag = jnp.ones(se_dim).at[0].set(plot_size/4)
se = SELeftInvariant(n_dim, jnp.diag(diag))
scale = .4
def make_brownian(key):
    # x_arr = jnp.empty((N, n_dim, n_f))

    x_0, key = grand(key, (n_dim, n_f))
    x_arr = [10*x_0]

    for i in range(1, N):
        driver_move, key = grand(key, ((n_dim+1)**2,))
        g = mi.geodesic_move(se, jnp.eye(n_dim+1), driver_move, scale)
        x_arr.append(g[:-1, :-1]@x_arr[i-1] + g[:-1, -1][:, None])
    return jnp.array(x_arr)

N = 200
x_arr = make_brownian(key)

# GET SOME MATPLOTLIB OBJECTS
fig = plt.figure()
ax = plt.axes()
lines = [plt.plot(x_arr[:, 0, i], x_arr[:, 1, i],  c=colors[i])[0]
          for i in range(n_f)]  # For line plot

dots = [plt.plot(x_arr[:1, 0, i], x_arr[:1, 1, i],  c=colors[i], marker='o')[0]
          for i in range(n_f)]
ax.set_xlabel('X(t)')
ax.set_ylabel('Y(t)')
ax.set_title('Trajectory of a Riemannian Brownian motion on SE(2)')

# Creating the Animation object
line_ani = animation.FuncAnimation(
    fig, func, frames=N, fargs=(x_arr, lines, dots,), interval=200, blit=False)
line_ani.save('se2_animation.mp4', fps=30, extra_args=['-vcodec', 'libx264'])
plt.close()

HTML(line_ani.to_html5_video())

Here is a 3D version. The left-invariant metric is again diagonal, with entries $3, 6, \dots, \frac{3n(n+1)}{2}$ (step $3$), where $\frac{n(n+1)}{2}$ is the dimension of $SE(n)$.

In [8]:
n_dim = 3
n_f = 4
plot_size = 10
colors = ["r", "b", "g", "m"]

def func(num, x_arr, lines, dots):
    # ANIMATION FUNCTION
    for i in range(n_f):
        lines[i].set_data(x_arr[:num, :2, i].T)  # cannot set 3d data, break to two commands
        lines[i].set_3d_properties(x_arr[:num, 2, i])
        if num > 1:
            dots[i].set_data(x_arr[num-2:num-1, :2, i].T)
            dots[i].set_3d_properties(x_arr[num-2:num-1, 2, i].T)
    return lines

# THE DATA POINTS
key = random.PRNGKey(0)
se_dim = n_dim*(n_dim+1)//2
# diag = jnp.ones(se_dim).at[2].set(plot_size/4).at[4].set(plot_size/4).at[5].set(plot_size/4)*10
diag = jnp.arange(1, se_dim+1)*3

se = SELeftInvariant(n_dim, jnp.diag(diag))
scale = .5

def make_brownian(key):
    x_0, key = grand(key, (n_dim, n_f))
    x_arr = [10*x_0]

    for i in range(1, N):
        driver_move, key = grand(key, ((n_dim+1)**2,))
        g = mi.geodesic_move(se, jnp.eye(n_dim+1), driver_move, scale)
        x_arr.append(g[:-1, :-1]@x_arr[i-1] + g[:-1, -1][:, None])
    return jnp.array(x_arr)
N = 200
x_arr = make_brownian(key)

# GET SOME MATPLOTLIB OBJECTS
fig = plt.figure()
ax = plt.axes(projection='3d')
lines = [plt.plot(x_arr[:, 0, i], x_arr[:, 1, i],
                  x_arr[:, 2, i], c=colors[i])[0]
          for i in range(n_f)]  # For line plot

dots = [plt.plot(x_arr[:1, 0, i], x_arr[:1, 1, i],
                  x_arr[:1, 2, i],
                  c=colors[i], marker='o')[0]
          for i in range(n_f)]

ax.set_xlabel('X(t)')
ax.set_ylabel('Y(t)')
ax.set_zlabel('Z(t)')
ax.set_title('Trajectory of a Riemannian Brownian motion on SE(3)')

# Creating the Animation object
line_ani = animation.FuncAnimation(
    fig, func, frames=N, fargs=(x_arr, lines, dots), interval=50, blit=False)
line_ani.save('se3_animation.mp4', fps=30, extra_args=['-vcodec', 'libx264'])
plt.close()

HTML(line_ani.to_html5_video())

* A similar simulation for the affine group. To avoid distraction, we work with a 2D example. The green point below is the midpoint of the red and blue points; since affine maps preserve midpoints, it remains the midpoint along the whole trajectory.
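Why the midpoint is preserved: an affine map $x\mapsto Ax+b$ satisfies $A\big(\sum_i\lambda_i x_i\big)+b=\sum_i\lambda_i(Ax_i+b)$ whenever $\sum_i\lambda_i=1$, so every affine combination is carried along, in particular the midpoint with $\lambda=(\tfrac12,\tfrac12,0)$, which matches `lin_comb` in the cell below. The standalone check here uses a hypothetical random affine matrix rather than one produced by jax_rb.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def apply_homogeneous(g, x):
    """Apply g = [[A, b], [0, 1]] to the points stored as columns of x: x -> A x + b."""
    return g[:-1, :-1] @ x + g[:-1, -1][:, None]


n_dim = 2
key_a, key_b, key_x = jax.random.split(jax.random.PRNGKey(0), 3)
# Hypothetical random affine map near the identity (A = I + noise is almost surely invertible)
a = jnp.eye(n_dim) + 0.3 * jax.random.normal(key_a, (n_dim, n_dim), dtype=jnp.float64)
b = jax.random.normal(key_b, (n_dim,), dtype=jnp.float64)
g = jnp.eye(n_dim + 1).at[:n_dim, :n_dim].set(a).at[:n_dim, n_dim].set(b)

# Three free points plus a fourth point at the midpoint of points 0 and 1
x = jax.random.normal(key_x, (n_dim, 3), dtype=jnp.float64)
lin_comb = jnp.array([[0.5], [0.5], [0.0]])
x = jnp.concatenate([x, x @ lin_comb], axis=1)

y = apply_homogeneous(g, x)
midpoint_error = jnp.max(jnp.abs(y[:, 3] - 0.5 * (y[:, 0] + y[:, 1])))
```

`midpoint_error` is at the level of floating-point rounding. Unlike the $SE(2)$ case, distances between the points are generally not preserved.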

In [11]:
from jax_rb.manifolds import AffineLeftInvariant

rc('animation', html='jshtml')
n_dim = 2
n_f = 4
colors = ["r", "b", "m", "g"]

def func(num, x_arr, lines, dots):
    # ANIMATION FUNCTION
    for i in range(n_f):
        lines[i].set_data(x_arr[:num, :, i].T)
        if num > 0:
            dots[i].set_data(x_arr[num-1:num, :, i].T)  # marker at the latest point drawn
    return lines

# THE DATA POINTS
key = random.PRNGKey(0)
af_dim = n_dim*(n_dim+1)
# diag = jnp.ones(se_dim).at[2].set(plot_size/4).at[4].set(plot_size/4).at[5].set(plot_size/4)
diag = jnp.arange(af_dim)*70. + 10

aff = AffineLeftInvariant(n_dim, jnp.diag(diag))
scale = .5

# lin_comb = jnp.array([[.25, .25, 0.25, .25], [1/2, .5, 0., 0.]]).T
lin_comb = jnp.array([[1/2, .5, 0.]]).T

def make_brownian(key):
    x_0, key = grand(key, (n_dim, n_f))
    x_0 = x_0.at[:, n_dim+1:n_f].set(x_0[:, :n_dim+1]@lin_comb)
    x_arr = [10*x_0]

    for i in range(1, N):
        driver_move, key = grand(key, ((n_dim+1)**2,))
        g = mi.geodesic_move(aff, jnp.eye(n_dim+1), driver_move, scale)
        x_arr.append(g[:-1, :-1]@x_arr[i-1] + g[:-1, -1][:, None])
    return jnp.array(x_arr)
N = 200
x_arr = make_brownian(key)

# GET SOME MATPLOTLIB OBJECTS
fig = plt.figure()
ax = plt.axes()
lines = [plt.plot(x_arr[:, 0, i], x_arr[:, 1, i],
                  c=colors[i])[0]
          for i in range(n_f)]  # For line plot

dots = [plt.plot(x_arr[:1, 0, i], x_arr[:1, 1, i],
                  c=colors[i], marker='o')[0]
          for i in range(n_f)]

ax.set_xlabel('X(t)')
ax.set_ylabel('Y(t)')
ax.set_title('Trajectory of a Riemannian Brownian motion on Aff(2)')

# Creating the Animation object
line_ani = animation.FuncAnimation(
    fig, func, frames=N, fargs=(x_arr, lines, dots), interval=50, blit=False)
line_ani.save('af2_animation.mp4', fps=30, extra_args=['-vcodec', 'libx264'])
plt.close()

HTML(line_ani.to_html5_video())

## Computing expectation

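The main simulation loop is in `jax_rb.simulation.simulate`. A wrapper class, `Simulator` (imported above as `sim.Simulator` from `jax_rb.simulation.simulator`), makes it easier to run several simulations and save the results.

Extensive test results are in the folder `tests/run`.

Here is an example running the simulator on $SE(3)$, with the payoff $(x_T)_{11}^2$ at the final time and the accumulated cost $\int_0^T t\max((x_t)_{11}-\frac{1}{2}, 0)\,dt$ along the path. The same random key is used for the geodesic, Ito and Stratonovich integrators, so the three estimates can be compared directly.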
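The quantity being estimated can be sketched in a few lines of JAX. The function `mc_expectation` below is hypothetical and is not the library's `simulate`. It assumes the estimate is the sample mean over paths of $F(x_T) + \sum_k c(x_{t_k}, t_k)\,\Delta t$ (a left Riemann sum for the path integral), that `step(x, dw, scale)` takes a standard normal increment `dw` and `scale` $=\sqrt{\Delta t}$, mirroring the `(x, unit_move, scale)` signature of the move functions in this notebook, and that any diffusion coefficient such as `d_coeff` is already folded into `step`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def mc_expectation(step, x0, key, t_final, n_path, n_div,
                   final_pay_off, path_pay_off):
    """Hypothetical estimator of E[F(x_T) + int_0^T c(x_t, t) dt].

    step(x, dw, scale) advances one path by one time step, given a standard
    normal increment dw and scale = sqrt(dt). The path integral is a left
    Riemann sum over the grid points t_k = k * dt, k = 0, ..., n_div - 1.
    """
    dt = t_final / n_div
    x0 = jnp.asarray(x0, dtype=jnp.float64)
    x_init = jnp.broadcast_to(x0, (n_path,) + x0.shape)
    noise = jax.random.normal(key, (n_div, n_path) + x0.shape, dtype=jnp.float64)
    batched_step = jax.vmap(step, in_axes=(0, 0, None))
    batched_cost = jax.vmap(path_pay_off, in_axes=(0, None))

    def body(carry, inputs):
        x, acc = carry
        k, dw = inputs
        acc = acc + batched_cost(x, k * dt) * dt  # cost at the left end point
        x = batched_step(x, dw, jnp.sqrt(dt))      # move every path by one step
        return (x, acc), None

    init = (x_init, jnp.zeros(n_path, dtype=jnp.float64))
    (x_final, acc), _ = jax.lax.scan(body, init, (jnp.arange(n_div), noise))
    return jnp.mean(jax.vmap(final_pay_off)(x_final) + acc)


# Usage on 1D Euclidean Brownian motion, where E[W_T^2] = T:
# final payoff x^2 plus a constant path cost 1 should give about T + T = 2 for T = 1.
est = mc_expectation(lambda x, dw, s: x + s * dw, 0.0, jax.random.PRNGKey(0),
                     1.0, 20000, 50, lambda x: x ** 2, lambda x, t: jnp.ones_like(x))
```

For a matrix-valued state, the path cost of this section would read `lambda x, t: t * jnp.maximum(x[0, 0] - 0.5, 0.0)` and the final payoff `lambda x: x[0, 0] ** 2`.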

In [12]:
n = 3
t_final = 3
n_path = 1000
n_div = 700
d_coeff = .5
dimsc = 1.
se_dim = n*(n+1)//2  # dimension of SE(n), i.e. the size of the metric matrix
metric_mat, key = rand_positive_definite(key, se_dim, (.1, 30.))
mnf = SELeftInvariant(n, metric_mat)

stor = sim.Simulator(path_pay_off=lambda x, t: t*jnp.maximum(x[0, 0]-.5, 0),
                     final_pay_off=lambda x: x[0, 0]**2)

stor.run(lambda x, unit_move, scale: mi.geodesic_move(mnf, x, unit_move, scale),
         sim.RunParams(
            jnp.eye(n+1), key, t_final, n_path,
            n_div, d_coeff*dimsc,
            (n+1)**2, mnf.name(), False, f'geodesic_move'))

stor.run(lambda x, unit_move, scale: mi.rbrownian_ito_move(mnf, x, unit_move, scale),
         sim.RunParams(
            jnp.eye(n+1), key, t_final, n_path,
            n_div, d_coeff*dimsc,
            (n+1)**2, mnf.name(), False, f'ito_move'))

stor.run(lambda x, unit_move, scale: mi.rbrownian_stratonovich_move(mnf, x, unit_move, scale),
         sim.RunParams(
            jnp.eye(n+1), key, t_final, n_path,
            n_div, d_coeff*dimsc,
            (n+1)**2, mnf.name(), False, f'stratonovich_move'))


In [13]:
[(jnp.nanmean(a[1]), a[0].run_type) for a in stor.runs]

[(Array(2.5737494, dtype=float64), 'geodesic_move'),
 (Array(2.57373077, dtype=float64), 'ito_move'),
 (Array(2.57314791, dtype=float64), 'stratonovich_move')]

## Sampling of Uniform distributions on compact manifolds

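For a compact manifold, the distribution of the Brownian motion converges to the uniform distribution (the normalized Riemannian volume) as `t_final` grows.

On a compact Lie group, the Riemannian volume of any left-invariant metric is a left Haar measure, so after normalization every left-invariant metric gives the same Haar probability measure. We can therefore sample the uniform (Haar) probability measure using the long-time limit of Brownian motion for any left-invariant metric. The example below runs on $SE(3)$, which is not compact; however, the payoff $(x_T)_{11}^2$ depends only on the rotation block of $x_T$, whose distribution is expected to approach the Haar measure on the compact group $SO(3)$.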
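A short symmetry argument shows which value the estimates should approach. If $R$ is Haar-distributed on $SO(n)$, its first column $u = Re_1$ is uniformly distributed on the unit sphere $S^{n-1}$ (since $SO(n)$ acts transitively on $S^{n-1}$ for $n\ge 2$), so the coordinates of $u$ are exchangeable and

$$\mathbb{E}\big[R_{11}^2\big]=\mathbb{E}\big[u_1^2\big]=\frac{1}{n}\sum_{i=1}^{n}\mathbb{E}\big[u_i^2\big]=\frac{1}{n}\,\mathbb{E}\big[\lVert u\rVert^2\big]=\frac{1}{n}.$$

For $x_T\in SE(3)$ we have $(x_T)_{11}=R_{11}$ with $n=3$, so the target value of the payoff is $1/3$.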

In [14]:
n = 3
t_final = 400
n_path = 1000
n_div = 1000
d_coeff = .5
dimsc = 1.
se_dim = n*(n+1)//2  # dimension of SE(n)
metric_mat, key = rand_positive_definite(key, se_dim, (.1, 30.))
mnf = SELeftInvariant(n, metric_mat)

ustor = sim.Simulator(path_pay_off=None,
                      final_pay_off=lambda x: x[0, 0]**2)

ustor.run(lambda x, unit_move, scale: mi.geodesic_move(mnf, x, unit_move, scale),
         sim.RunParams(
            jnp.eye(n+1), key, t_final, n_path,
            n_div, d_coeff*dimsc,
            (n+1)**2, mnf.name(), False, f'geodesic_move'))

ustor.run(lambda x, unit_move, scale: mi.rbrownian_ito_move(mnf, x, unit_move, scale),
         sim.RunParams(
            jnp.eye(n+1), key, t_final, n_path,
            n_div, d_coeff*dimsc,
            (n+1)**2, mnf.name(), False, f'ito_move'))

ustor.run(lambda x, unit_move, scale: mi.rbrownian_stratonovich_move(mnf, x, unit_move, scale),
         sim.RunParams(
            jnp.eye(n+1), key, t_final, n_path,
            n_div, d_coeff*dimsc,
            (n+1)**2, mnf.name(), False, f'stratonovich_move'))
[(jnp.nanmean(a[1]), a[0].run_type) for a in ustor.runs]

[(Array(0.34024765, dtype=float64), 'geodesic_move'),
 (Array(0.33749215, dtype=float64), 'ito_move'),
 (Array(0.36096234, dtype=float64), 'stratonovich_move')]

In [15]:
def uniform_sample(key, shape, pay_off, n_samples):
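    """Estimate E[pay_off] under the uniform distribution, using the orthogonal
    polar factor of Gaussian matrices. For a square shape this samples the Haar
    measure on O(n); for shape (n, k) with k < n it samples the Stiefel
    manifold, a quotient of SO(n).
    """
    x_all, key = grand(key, (shape[0], shape[1], n_samples))

    def do_one_point(seq):
        # Polar factor seq @ (seq^T seq)^{-1/2}, computed from the eigendecomposition
        ei, ev = jnp.linalg.eigh(seq.T@seq)
        return pay_off(seq@ev@((1/jnp.sqrt(ei))[:, None]*ev.T))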

    s = jax.vmap(do_one_point, in_axes=2)(x_all)
    return jnp.nanmean(s)

uniform_sample(key, (n, n), ustor.final_pay_off, n_path*1000)

Array(0.33327652, dtype=float64)

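The long-time Brownian motion estimates above (between 0.337 and 0.361 with 1000 paths) agree reasonably well with the direct uniform sample (0.333), so long-time Brownian motion can be used to sample the uniform distribution on a compact Riemannian manifold.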
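As an independent cross-check of `uniform_sample`, a Haar-distributed rotation can also be drawn from the QR decomposition of a Gaussian matrix, with the signs of the columns of $Q$ fixed by the diagonal of $R$ (the standard correction that makes the result exactly Haar on $O(n)$). This is a swapped-in technique, not part of jax_rb, and the function names are hypothetical. Flipping the first column when $\det Q=-1$ is a right multiplication by a fixed reflection on that component, so it maps Haar on $O(n)$ to Haar on $SO(n)$ without changing $Q_{11}^2$.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def haar_orthogonal(key, n):
    """Haar-distributed n x n orthogonal matrix via sign-corrected QR of a Gaussian matrix."""
    z = jax.random.normal(key, (n, n), dtype=jnp.float64)
    q, r = jnp.linalg.qr(z)
    # Multiply each column of q by the sign of the matching diagonal entry of r
    return q * jnp.sign(jnp.diag(r))[None, :]


def haar_special_orthogonal(key, n):
    """Haar-distributed rotation: flip the first column when det = -1."""
    q = haar_orthogonal(key, n)
    det_sign = jnp.sign(jnp.linalg.det(q))
    return q.at[:, 0].set(q[:, 0] * det_sign)


keys = jax.random.split(jax.random.PRNGKey(0), 20000)
rotations = jax.vmap(lambda k: haar_special_orthogonal(k, 3))(keys)
estimate = jnp.mean(rotations[:, 0, 0] ** 2)  # should be close to 1/3
```

Compared with the polar-factor approach in `uniform_sample`, QR avoids the inverse square root of the eigenvalues, at the cost of an explicit sign correction.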

The rank-one modified metrics form a tractable family of metrics that are in general not homogeneous; this method could potentially be applied to them as well.

## Conclusion
We expect to add more manifolds to our library, including manifolds appearing in statistics, dynamical systems, optimization and control theory, as well as generative AI.

As with Euclidean Brownian motion, we expect Riemannian Brownian motion to be important in many areas of science. We hope the library will facilitate applications of Riemannian Brownian motion to real-world problems.
