# The Lie group exponential map and its second order approximations
  * The package is not yet on PyPI, so install it directly from GitHub with the cell below. Alternatively, download the repository in a terminal and install it locally.
  
  

In [1]:
!pip install git+https://github.com/dnguyend/jax-rb

Collecting git+https://github.com/dnguyend/jax-rb
  Cloning https://github.com/dnguyend/jax-rb to /tmp/pip-req-build-lsy3c8_z
  Running command git clone --filter=blob:none --quiet https://github.com/dnguyend/jax-rb /tmp/pip-req-build-lsy3c8_z
  Resolved https://github.com/dnguyend/jax-rb to commit 581cc9d9b79fd59e4e49f03ca352f9b35c65ae65
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: jax_rb
  Building wheel for jax_rb (pyproject.toml) ... done
  Created wheel for jax_rb: filename=jax_rb-0.1.dev57+g581cc9d-py3-none-any.whl size=33706 sha256=c4c979ea7f80ff92e40bcf9dea861d559273b4ea6131973c353c83213783be2a
  Stored in directory: /tmp/pip-ephem-wheel-cache-tsmvnpug/wheels/0f/76/88/65e675f8bcca47be98c588d9a787a4c1c9b0a5044517ba6490
Successfully built jax_rb
Installing collected packages: jax_rb
Successfully installed

$\newcommand{\R}{\mathbb{R}}$
$\newcommand{\fR}{\mathfrak{r}}$
$\newcommand{\so}{\mathfrak{so}}$
$\newcommand{\expm}{\mathsf{expm}}$
$\newcommand{\sigmam}{\mathring{\sigma}}$

## The exponential retraction and the Cayley transform retraction.
For a matrix Lie group $G$, the exponential retraction is given by
$$\fR(x, v) = x\expm(x^{-1}v)
$$
where $v\in T_xG$. We have the expansion
$$\fR(x, hv) = x + hv + \frac{h^2}{2}x(x^{-1}v)^2 + O(h^3)
$$
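The expansion can be checked numerically. The snippet below is a minimal sketch on $GL(n)$, where every matrix is a tangent vector; the helper names `expm_retract`, `second_order` and `residual` are illustrative and not part of `jax_rb`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)


def expm_retract(x, v):
    """Exponential retraction x expm(x^{-1} v)."""
    return x @ expm(jnp.linalg.solve(x, v))


def second_order(x, v, h):
    """Second order expansion x + h v + h^2/2 x (x^{-1} v)^2."""
    a = jnp.linalg.solve(x, v)
    return x + h * v + 0.5 * h**2 * x @ a @ a


key_x, key_v = jax.random.split(jax.random.PRNGKey(0))
n = 4
# A point near the identity (so it is invertible) and an arbitrary tangent vector.
x = jnp.eye(n) + 0.1 * jax.random.normal(key_x, (n, n))
v = jax.random.normal(key_v, (n, n))


def residual(h):
    """Norm of the difference between the retraction and its expansion."""
    return jnp.linalg.norm(expm_retract(x, h * v) - second_order(x, v, h))


# Halving h should divide an O(h^3) residual by roughly 2^3 = 8.
ratio = float(residual(1e-2) / residual(5e-3))
print(ratio)
```

A ratio close to 8 when $h$ is halved is what an $O(h^3)$ remainder would produce.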
Thus, the adjusted drift will be $\mu_{\fR} = \mu - \frac{1}{2}\sum_{ij}x(x^{-1}\sigmam(x; E_{ij}))^2$, where $E_{ij}$ runs over the elementary matrices of $\R^{N\times N}$, for the equation:
$$dX = \mu dt + \sigmam(X)dW_t.$$
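A short sketch of why this adjustment appears, assuming the entries of $\Delta_W = \sum_{ij}\Delta_{ij}E_{ij}$ are independent with mean zero and variance $h$ (an assumption on the normalization of the random move). Write $\xi = h\mu_{\fR} + \sigmam(x)\Delta_W$ and take the expectation of the second order expansion of $\fR(x, \xi)$:
$$\begin{aligned}
E[\fR(x, \xi) - x] &= h\mu_{\fR} + \frac{1}{2}E\left[x(x^{-1}\sigmam(x; \Delta_W))^2\right] + O(h^2)\\
&= h\left(\mu_{\fR} + \frac{1}{2}\sum_{ij}x(x^{-1}\sigmam(x; E_{ij}))^2\right) + O(h^2),
\end{aligned}$$
since $E[\Delta_{ij}\Delta_{kl}] = h\delta_{ik}\delta_{jl}$ and the terms of odd order in $\Delta_W$ have zero mean. Matching the Ito drift $h\mu$ at first order in $h$ gives $\mu_{\fR} = \mu - \frac{1}{2}\sum_{ij}x(x^{-1}\sigmam(x; E_{ij}))^2$.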
For a random move $\Delta_W \sim N(0, hI_{\R^{N\times N}})$, that is, with standard deviation $h^{\frac{1}{2}}$ per entry, the Euler-Maruyama exponential step will be
$$X_{i+1} = X_i\expm(X_i^{-1}(h\mu_{\fR}(X_i)+ \sigmam(X_i)\Delta_W ))
$$
It is a [remarkable fact](https://en.wikipedia.org/wiki/Pad%C3%A9_approximant#cite_note-wolfram-alpha-pade-exp-11) that the diagonal Padé approximant of $e^x$ is a rational function of the form $\frac{p(x)}{p(-x)}$, where the first approximant corresponds to $p(x) = 1+\frac{x}{2}$.

For the group $SO(N)$ (or, more generally, for quadratic Lie groups [Celledoni and Iserles]) and $a\in \so(N)$, the product $p(-a)^{-1}p(a)$ is in $SO(N)$ for all analytic $p$ with real coefficients such that $p(-a)$ is invertible. With $p(x) = 1+\frac{x}{2}$, we have the Cayley retraction

$$\fR_{Cayley}(x, v) = x(I-\frac{1}{2}x^{-1}v)^{-1}(I+\frac{1}{2}x^{-1}v),
$$
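Both properties of the Cayley transform used here can be checked numerically: it maps $\so(n)$ into $SO(n)$, and it agrees with $\expm$ up to second order. The snippet below is a minimal sketch; the helper names `cayley` and `gap` are illustrative and not part of `jax_rb`.

```python
import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)


def cayley(a):
    """Cayley transform (I - a/2)^{-1} (I + a/2)."""
    eye = jnp.eye(a.shape[0])
    return jnp.linalg.solve(eye - 0.5 * a, eye + 0.5 * a)


n = 5
raw = jax.random.normal(jax.random.PRNGKey(1), (n, n))
a = raw - raw.T  # a skew-symmetric matrix, i.e. an element of so(n)

# The Cayley transform of a skew-symmetric matrix is orthogonal with determinant 1.
q = cayley(a)
orth_err = float(jnp.linalg.norm(q.T @ q - jnp.eye(n)))
det_q = float(jnp.linalg.det(q))


def gap(h):
    """Distance between the Cayley transform and expm at h*a."""
    return jnp.linalg.norm(cayley(h * a) - expm(h * a))


# Agreement up to second order means the gap is O(h^3):
# halving h should divide it by roughly 8.
ratio = float(gap(1e-2) / gap(5e-3))
print(orth_err, det_q, ratio)
```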
and the Euler-Maruyama Cayley step will be
$$X_{i+1} = X_i(I - \frac{1}{2}X_i^{-1}(h\mu_{\fR}(X_i)+ \sigmam(X_i)\Delta_W ))^{-1}
(I + \frac{1}{2}X_i^{-1}(h\mu_{\fR}(X_i)+ \sigmam(X_i)\Delta_W ))
$$
We will check that these steps give the same simulation results as the geodesic, Ito and Stratonovich integrators in the paper. Beyond $SO(N)$, when the group is not compact, the error growth is difficult to control in long-term simulations.
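To illustrate that the Cayley step keeps the iterates on the group, here is a minimal sketch of a path on $SO(3)$. It assumes, for illustration only, a hypothetical noise map $\sigmam(x; w) = x\,\frac{1}{2}(w - w^T)$ and a vanishing adjusted drift $\mu_{\fR} = 0$; neither is taken from `jax_rb`, and `cayley_step` is an illustrative name.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def cayley_step(x, dw):
    """One Cayley step X_{i+1} = X_i (I - a/2)^{-1} (I + a/2).

    Here a = x^{-1} sigma(x; dw) = skew(dw) for the assumed noise map,
    and the adjusted drift is assumed to be zero.
    """
    a = 0.5 * (dw - dw.T)
    eye = jnp.eye(x.shape[0])
    return x @ jnp.linalg.solve(eye - 0.5 * a, eye + 0.5 * a), None


n, n_div, t_final = 3, 1000, 1.0
h = t_final / n_div
# Random moves with variance h per entry, one per time step.
dws = jnp.sqrt(h) * jax.random.normal(jax.random.PRNGKey(3), (n_div, n, n))

# lax.scan carries the current point through all n_div steps.
x_final, _ = jax.lax.scan(cayley_step, jnp.eye(n), dws)
orth_err = float(jnp.linalg.norm(x_final.T @ x_final - jnp.eye(n)))
print(orth_err)
```

The orthogonality error stays at the level of floating-point roundoff because every factor is exactly orthogonal in exact arithmetic, independently of the step size.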


### References

[Celledoni and Iserles] Celledoni, Elena, and Arieh Iserles. “Approximating the Exponential from a Lie Algebra to a Lie Group.” Mathematics of Computation, vol. 69, no. 232, 2000, pp. 1457–80. JSTOR, http://www.jstor.org/stable/2585076. Accessed 17 July 2024.



In [3]:
"""Test simulation with the exponential and Cayley retractions.
"""
from functools import partial
from time import perf_counter

import jax
import jax.numpy as jnp
import jax.numpy.linalg as jla
from jax import random, vmap, jit
from jax.scipy.linalg import expm
import jax_rb.manifolds.so_left_invariant as som

from jax_rb.utils.utils import (rand_positive_definite, sym, vcat)
import jax_rb.simulation.simulator as sim
import jax_rb.simulation.matrix_group_integrator as mi


jax.config.update("jax_enable_x64", True)


## A retractive step

In [4]:
@partial(jit, static_argnums=(0,2,5,6))
def matrix_retractive_move(rtr, x, t, unit_move, scale, sigma, mu):
    """Simulate the equation :math:`dX_t = \\mu(X_t, t) dt + \\sigma(X_t, t) dW_t` using the retraction rtr. The manifold is a Lie group.
    We do not assume a Riemannian metric on the manifold; :math:`\\sigma\\sigma^T` may be degenerate on :math:`T\\mathcal{M}`. However, we create subclasses for left-invariant Lie groups.

    W is a Wiener process driving the equation, defined on :math:`\\mathbb{R}^k`. W is given by unit_move.

    :math:`\\sigma(X_t, t)` maps :math:`\\mathbb{R}^k` to :math:`\\mathcal{E}`, but the image belongs
    to :math:`T_{X_t}\\mathcal{M}`.

    The retraction rtr is assumed to have the method ``drift_adjust``, which returns the drift adjustment :math:`\\text{drift_adj}`.

    The move is :math:`x_{new} = \\mathfrak{r}(x, \\Pi(x)\\sigma(x)(\\text{unit_move}(\\text{scale})^{\\frac{1}{2}}) + \\text{scale} (\\mu + \\text{drift_adj}))`.

    :param rtr: the retraction,
    :param x: a point on the manifold,
    :param t: time
    :param unit_move: a random normal draw
    :param scale: scaling
    :param sigma: a function implementing the map :math:`\\sigma`
    :param mu: a function implementing the Ito drift :math:`\\mu`
    """
    return rtr.retract(x,
                       sigma(x, t, unit_move.reshape(x.shape))*jnp.sqrt(scale)
                       + scale*(mu(x, t) + rtr.drift_adjust(sigma, x, t, unit_move.shape[0])))


A few classes for retractions on groups, with a specialized implementation for $SO(n)$. Here, we are in the case of Riemannian Brownian motion.

In [5]:
class expm_retraction():
    """The expm retraction of a matrix Lie group.
    This is the most general, but not an efficient, implementation;
    each Lie group should have a custom implementation of this.
    """
    def __init__(self, mnf):
        self.mnf = mnf

    def retract(self, x, v):
        """return :math:`x\\text{expm}(x^{-1}v)`, a point on the group
        """
        return x@expm(jla.solve(x, v))

    def drift_adjust(self, sigma, x, t, driver_dim):
        """return the adjustment :math:`\\mu_{adj}`
        so that :math:`\\mu + \\mu_{adj} = \\mu_{\\mathfrak{r}}`
        """

        return -0.5*jnp.sum(vmap(lambda seq:
                                 x@sqr(jla.solve(x, sigma(x, t, seq.reshape(x.shape)))))(jnp.eye(driver_dim)),
                            axis=0)

class cayley_so_retraction():
    """Cayley retraction for :math:`SO(n)`.
    This implementation is specialized to :math:`SO(n)`: it uses
    :math:`x^{-1} = x^T` and takes the drift adjustment from ``mnf.sigma_id``.
    """
    def __init__(self, mnf):
        self.mnf = mnf

    def retract(self, x, v):
        """return :math:`x(I-\\frac{1}{2}x^Tv)^{-1}(I+\\frac{1}{2}x^Tv)`, a point on :math:`SO(n)`
        """
        ixv = x.T@v
        return x + x@jla.solve(jnp.eye(ixv.shape[0]) - 0.5*ixv, ixv)

    def drift_adjust(self, sigma, x, t, driver_dim):
        """return the adjustment :math:`\\mu_{adj}`
        so that :math:`\\mu + \\mu_{adj} = \\mu_{\\mathfrak{r}}`
        """
        return -0.5*jnp.sum(vmap(lambda seq:
                                 x@sqr(self.mnf.sigma_id(seq.reshape(x.shape)))
                                 )(jnp.eye(driver_dim)),
                            axis=0)

def sqr(a):
    """return the matrix square :math:`a^2`"""
    return a@a
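The summation pattern used in `drift_adjust` can be sanity-checked against a closed form. The sketch below is self-contained and uses a hypothetical noise map `toy_sigma(x, w) = x @ skew(w)` on orthogonal matrices, which is an assumption for illustration and not the `sigma` of `jax_rb`. For this map, summing over the elementary matrices gives $\sum_{ij}\mathrm{skew}(E_{ij})^2 = -\frac{n-1}{2}I$, so the adjustment should equal $\frac{n-1}{4}x$.

```python
import jax
import jax.numpy as jnp
from jax import vmap

jax.config.update("jax_enable_x64", True)


def toy_sigma(x, w):
    """Hypothetical noise map: x times the skew-symmetric part of w."""
    return x @ (0.5 * (w - w.T))


def drift_adjust(sigma, x, driver_dim):
    """-1/2 sum_k x (x^{-1} sigma(x; e_k))^2 over the standard basis e_k."""
    def term(e):
        a = x.T @ sigma(x, e.reshape(x.shape))  # x^{-1} = x^T for orthogonal x
        return x @ a @ a
    return -0.5 * jnp.sum(vmap(term)(jnp.eye(driver_dim)), axis=0)


n = 4
raw = jax.random.normal(jax.random.PRNGKey(2), (n, n))
x, _ = jnp.linalg.qr(raw)  # a random orthogonal matrix

adj = drift_adjust(toy_sigma, x, n * n)
expected = 0.25 * (n - 1) * x
err = float(jnp.linalg.norm(adj - expected))
print(err)
```

For the Stratonovich equation $dX = X\circ \mathrm{skew}(dW)$ the Ito drift is $\frac{1}{2}\sum_{ij}x\,\mathrm{skew}(E_{ij})^2 = -\frac{n-1}{4}x$, so in this toy case the adjustment cancels the Ito drift and the adjusted drift is zero, which is trivially tangent.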

## Test for $SO(n)$
We test that the adjusted Ito drift is tangent, and double-check that the Ito drift is $-\frac{1}{2}$ times the sum of the `gamma` terms.

In [6]:
def test_expm_integrator_so():
    n = 5
    key = random.PRNGKey(0)
    so_dim = n*(n-1)//2
    metric_mat, key = rand_positive_definite(key, so_dim, (.1, 10.))
    mnf = som.SOLeftInvariant(n, metric_mat)
    x, key = mnf.rand_point(key)

    gsum = jnp.zeros((n, n))
    hsum = jnp.zeros((n, n))
    for i in range(n**2):
        nsg = mnf.proj(x, mnf.sigma(x, jnp.zeros(n**2).at[i].set(1.).reshape(n, n)))
        hsum += x@sqr(x.T@nsg)
        gsum += - mnf.gamma(x, nsg, nsg)
        # print(jnp.sum(mnf.grad_c(x)*(hsum-gsum)))

    print(f"test sum -gamma - ito drift={0.5*gsum - mnf.ito_drift(x)}")
    print(f"test adjusted ito is tangent={sym(x.T@(-0.5*hsum+mnf.ito_drift(x)))}")

    # now test the equation.
    # test Brownian motion
    def new_sigma(x, _, dw):
        return mnf.proj(x, mnf.sigma(x, dw))

    def mu(x, _):
        return mnf.ito_drift(x)

    pay_offs = [lambda x, t: t*jnp.sum(x*jnp.arange(n)),
                lambda x: jnp.sqrt(jnp.sum(x*x))]

    key, sk = random.split(key)
    t_final = 1.
    n_path = 1000
    n_div = 1000
    d_coeff = .5
    wiener_dim = n**2
    x_0 = jnp.eye(n)

    ret_geo = sim.simulate(x_0,
                           lambda x, unit_move, scale: mi.geodesic_move(
                               mnf, x, unit_move, scale),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    ret_ito = sim.simulate(x_0,
                           lambda x, unit_move, scale: mi.rbrownian_ito_move(
                               mnf, x, unit_move, scale),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    ret_str = sim.simulate(x_0,
                           lambda x, unit_move, scale: mi.rbrownian_stratonovich_move(
                               mnf, x, unit_move, scale),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    rtr = expm_retraction(mnf)
    # a warm up run
    ret_rtr = sim.simulate(x_0,
                           lambda x, unit_move, scale: matrix_retractive_move(
                               rtr, x, 1., unit_move, scale, new_sigma, mu),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, 5, 5, d_coeff, wiener_dim])

    t0 = perf_counter()
    ret_rtr = sim.simulate(x_0,
                           lambda x, unit_move, scale: matrix_retractive_move(
                               rtr, x, 1., unit_move, scale, new_sigma, mu),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
    t1 = perf_counter()
    print('Time rtr %f' % (t1-t0))

    crtr = cayley_so_retraction(mnf)
    t4 = perf_counter()
    ret_crtr = sim.simulate(x_0,
                           lambda x, unit_move, scale: matrix_retractive_move(
                               crtr, x, 1., unit_move, scale, new_sigma, mu),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
    t5 = perf_counter()
    print('Time crtr %f' % (t5-t4))

    print(f"geo second order = {jnp.nanmean(ret_geo[0])}")
    print(f"Ito              = {jnp.nanmean(ret_ito[0])}")
    print(f"Stratonovich     = {jnp.nanmean(ret_str[0])}")
    print(f"Retractive       = {jnp.nanmean(ret_rtr[0])}")
    print(f"Cayley Retractive = {jnp.nanmean(ret_crtr[0])}")

test_expm_integrator_so()

test sum -gamma - ito drift=[[ 3.46944695e-16 -6.93889390e-17 -5.55111512e-17  2.77555756e-17
  -1.04083409e-17]
 [-5.72458747e-17  3.60822483e-16  0.00000000e+00  2.77555756e-17
   2.22044605e-16]
 [-5.20417043e-17  2.22044605e-16 -5.55111512e-17  9.71445147e-17
   1.94289029e-16]
 [ 6.93889390e-17 -1.38777878e-17 -4.16333634e-17  1.94289029e-16
   1.11022302e-16]
 [-5.55111512e-17  2.77555756e-17  1.66533454e-16  4.16333634e-17
  -3.46944695e-17]]
test adjusted ito is tangent=[[-4.77257607e-17  6.50674636e-17  1.89911247e-16  1.44996565e-16
  -2.18644708e-17]
 [ 6.50674636e-17 -3.20764925e-17  8.53113921e-18 -1.27781329e-16
  -1.46461781e-16]
 [ 1.89911247e-16  8.53113921e-18 -1.96166997e-16 -1.96229575e-17
   5.03222025e-17]
 [ 1.44996565e-16 -1.27781329e-16 -1.96229575e-17 -5.22337541e-17
  -1.14971926e-16]
 [-2.18644708e-17 -1.46461781e-16  5.03222025e-17 -1.14971926e-16
  -3.93447870e-16]]
Time rtr 70.728604
Time crtr 14.665379
geo second order = 6.056695178129377
Ito            