<a href="https://colab.research.google.com/github/dnguyend/jax-rb/blob/main/tests/notebooks/TestExpmAndCayleyIntegrator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# The Lie group exponential map and its second order approximations
  * Since the package is not yet on PyPI, install it from GitHub with the cell below. Alternatively, clone the repository in a terminal and install it locally.
  
  

In [None]:
!pip install git+https://github.com/dnguyend/jax-rb

Collecting git+https://github.com/dnguyend/jax-rb
  Cloning https://github.com/dnguyend/jax-rb to /tmp/pip-req-build-oa2a9hi1
  Running command git clone --filter=blob:none --quiet https://github.com/dnguyend/jax-rb /tmp/pip-req-build-oa2a9hi1
  Resolved https://github.com/dnguyend/jax-rb to commit 20efd03c04d80b3438f32dcbf48cd917036675b4
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: jax_rb
  Building wheel for jax_rb (pyproject.toml) ... done
  Created wheel for jax_rb: filename=jax_rb-0.1.dev50+g20efd03-py3-none-any.whl size=33135 sha256=f24800b7a2d206c9e978d3abd86bf29f7eca9b81150400954a41b4ac54b236ec
  Stored in directory: /tmp/pip-ephem-wheel-cache-lgolyt48/wheels/0f/76/88/65e675f8bcca47be98c588d9a787a4c1c9b0a5044517ba6490
Successfully built jax_rb
Installing collected packages: jax_rb
Successfully installed

$\newcommand{\R}{\mathbb{R}}$
$\newcommand{\fR}{\mathfrak{r}}$
$\newcommand{\so}{\mathfrak{so}}$
$\newcommand{\expm}{\mathsf{expm}}$
$\newcommand{\sigmam}{\mathring{\sigma}}$

## The exponential retraction and the Cayley transform retraction.
For a matrix Lie group $G$, the exponential retraction is given by
$$\fR(x, v) = x\expm(x^{-1}v)
$$
where $v\in T_xG$. We have the expansion
$$\fR(x, hv) = x + hv + \frac{h^2}{2}x(x^{-1}v)^2 + O(h^3)
$$
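A quick numerical check of this expansion, written as a sketch in NumPy/SciPy rather than JAX so it runs standalone (the names `expm_retraction_np` and `second_order_approx` are illustrative, not part of the package). Since the remainder is $O(h^3)$, the gap between $x\expm(x^{-1}hv)$ and the truncated expansion should shrink by a factor of about $2^3 = 8$ each time $h$ is halved.

```python
import numpy as np
from scipy.linalg import expm


def expm_retraction_np(x, v):
    """Exponential retraction x expm(x^{-1} v) on a matrix Lie group."""
    return x @ expm(np.linalg.solve(x, v))


def second_order_approx(x, v):
    """Truncated expansion x + v + 1/2 x (x^{-1} v)^2."""
    a = np.linalg.solve(x, v)
    return x + v + 0.5 * x @ a @ a


rng = np.random.default_rng(0)
n = 4
x = expm(0.3 * rng.standard_normal((n, n)))  # an invertible matrix, i.e. a point of GL(n)
v = x @ rng.standard_normal((n, n))          # a tangent vector at x

steps = [1e-2, 5e-3, 2.5e-3]
errors = [np.linalg.norm(expm_retraction_np(x, h * v) - second_order_approx(x, h * v))
          for h in steps]
# The remainder is O(h^3), so halving h divides the error by roughly 8.
ratios = [errors[k] / errors[k + 1] for k in range(len(errors) - 1)]
print(ratios)
```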
Thus, the adjusted drift is $\mu_{\fR} = \mu - \frac{1}{2}\sum_{i,j}x(x^{-1}\sigmam(x; E_{ij}))^2$, where $E_{ij}$ runs over the standard basis of $\R^{N\times N}$, for the equation:
$$dX = \mu dt + \sigmam(X)dW_t.$$
For a random move $\Delta_W = h^{\frac{1}{2}}Z$ with $Z \sim N(0, I_{\R^{N\times N}})$, the Euler-Maruyama exponential step is
$$X_{i+1} = X_i\expm(X_i^{-1}(h\mu_{\fR}(X_i)+ \sigmam(X_i)\Delta_W ))
$$
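A minimal NumPy/SciPy sketch of this step on $SO(3)$, under an assumed toy diffusion that is not the package's $\sigmam$: $\sigmam(x)E = x\,\mathrm{skew}(E)$, i.e. the Stratonovich equation $dX = X\circ\mathrm{skew}(dW)$, whose Itô drift is $\mu(x) = \frac{1}{2}\sum_{ij}x\,\mathrm{skew}(E_{ij})^2$. For this left-invariant choice $x^{-1}\sigmam(x)E_{ij} = \mathrm{skew}(E_{ij})$, so the adjusted drift $\mu_{\fR}$ vanishes and the step reduces to $X_{i+1} = X_i\expm(\mathrm{skew}(\Delta_W))$. All helper names are hypothetical.

```python
import numpy as np
from scipy.linalg import expm

n = 3
BASIS = [np.eye(n * n)[k].reshape(n, n) for k in range(n * n)]  # the E_ij


def skew(a):
    return 0.5 * (a - a.T)


def sigma(x, e):
    """Hypothetical toy diffusion coefficient: sigma(x)E = x skew(E)."""
    return x @ skew(e)


def ito_drift(x):
    """Ito drift of the Stratonovich SDE dX = X o skew(dW): 1/2 sum_ij x skew(E_ij)^2."""
    return 0.5 * sum(x @ skew(e) @ skew(e) for e in BASIS)


def adjusted_drift(x):
    """mu_R = mu - 1/2 sum_ij x (x^{-1} sigma(x) E_ij)^2."""
    corr = 0.0
    for e in BASIS:
        a = np.linalg.solve(x, sigma(x, e))
        corr = corr + x @ a @ a
    return ito_drift(x) - 0.5 * corr


def em_expm_step(x, h, z):
    """X_{i+1} = X_i expm(X_i^{-1}(h mu_R(X_i) + sigma(X_i) Delta_W)), Delta_W = sqrt(h) z."""
    v = h * adjusted_drift(x) + sigma(x, np.sqrt(h) * z)
    return x @ expm(np.linalg.solve(x, v))


rng = np.random.default_rng(1)
x = np.eye(n)
for _ in range(100):
    x = em_expm_step(x, 1e-2, rng.standard_normal((n, n)))

# x should remain in SO(3): orthogonal with determinant 1.
print(np.abs(x.T @ x - np.eye(n)).max(), np.linalg.det(x))
```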
It is a [remarkable fact](https://en.wikipedia.org/wiki/Pad%C3%A9_approximant#cite_note-wolfram-alpha-pade-exp-11) that the diagonal Padé approximants of $e^x$ are rational functions of the form $\frac{p(x)}{p(-x)}$; the lowest-order one, the $[1/1]$ approximant, corresponds to $p(x) = 1+\frac{x}{2}$.
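For instance, with $p(x) = 1+\frac{x}{2}$ one gets $\frac{1+x/2}{1-x/2} = 1 + x + \frac{x^2}{2} + \frac{x^3}{4} + O(x^4)$, which matches $e^x$ through the quadratic term and differs from it by about $\frac{x^3}{12}$. A short scalar check in plain Python (the function name `pade11` is illustrative):

```python
import math


def pade11(x):
    """[1/1] diagonal Pade approximant of e^x: p(x)/p(-x) with p(x) = 1 + x/2."""
    def p(s):
        return 1.0 + 0.5 * s
    return p(x) / p(-x)


x = 1e-2
err = pade11(x) - math.exp(x)
print(err, x**3 / 12)  # the two numbers agree to within a few percent
```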

For the group $SO(N)$ (and, more generally, for quadratic Lie groups [Celledoni and Iserle]), if $a\in \so(N)$ then $p(-a)^{-1}p(a)$ is in $SO(N)$ for every analytic $p$ with real coefficients for which $p(-a)$ is invertible. Taking $p(x) = 1+\frac{x}{2}$ gives the Cayley retraction

$$\fR_{Cayley}(x, v) = x(I-\frac{1}{2}x^{-1}v)^{-1}(I+\frac{1}{2}x^{-1}v),
$$
and the Euler-Maruyama Cayley step is
$$X_{i+1} = X_i(I - \frac{1}{2}X_i^{-1}(h\mu_{\fR}(X_i)+ \sigmam(X_i)\Delta_W ))^{-1}
(I + \frac{1}{2}X_i^{-1}(h\mu_{\fR}(X_i)+ \sigmam(X_i)\Delta_W ))
$$
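The claim that $p(-a)^{-1}p(a)\in SO(N)$ for $a\in\so(N)$ can be spot-checked numerically. The reason it holds: $p(-a) = p(a^T) = p(a)^T$ commutes with $p(a)$, so $Q = p(a)^{-T}p(a)$ satisfies $Q^TQ = I$ and $\det Q = 1$. The NumPy sketch below (the helper `rational_map` is hypothetical) uses the Cayley polynomial $p(x) = 1+\frac{x}{2}$ and the numerator $p(x) = 1+\frac{x}{2}+\frac{x^2}{12}$ of the $[2/2]$ Padé approximant of $e^x$.

```python
import numpy as np


def rational_map(a, coeffs):
    """Return p(-a)^{-1} p(a) for the polynomial p(x) = sum_k coeffs[k] x^k."""
    n = a.shape[0]

    def p(m):
        out, power = np.zeros((n, n)), np.eye(n)
        for c in coeffs:
            out = out + c * power
            power = power @ m
        return out

    return np.linalg.solve(p(-a), p(a))


rng = np.random.default_rng(2)
g = rng.standard_normal((5, 5))
a = 0.5 * (g - g.T)  # a random element of so(5)

results = {}
for name, coeffs in [("cayley", [1.0, 0.5]), ("pade22", [1.0, 0.5, 1.0 / 12])]:
    q = rational_map(a, coeffs)
    # (orthogonality defect, determinant) for each choice of p
    results[name] = (np.abs(q.T @ q - np.eye(5)).max(), np.linalg.det(q))
print(results)
```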
We check below that these steps give the same simulation results as the geodesic, Ito and Stratonovich integrators in the paper.
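One reason to expect the Cayley and exponential steps to agree: $(I-\frac{a}{2})^{-1}(I+\frac{a}{2}) = I + a + \frac{a^2}{2} + \frac{a^3}{4} + O(a^4)$, which matches $\expm(a)$ through $a^2$, so both retractions share the expansion $x + hv + \frac{h^2}{2}x(x^{-1}v)^2$ and hence the same adjusted drift $\mu_{\fR}$. Their difference starts at $\frac{a^3}{12}$, which the following NumPy/SciPy sketch (helper names hypothetical) checks for a small skew-symmetric $a$.

```python
import numpy as np
from scipy.linalg import expm


def cayley(a):
    """Cayley transform (I - a/2)^{-1} (I + a/2)."""
    i = np.eye(a.shape[0])
    return np.linalg.solve(i - 0.5 * a, i + 0.5 * a)


rng = np.random.default_rng(3)
g = rng.standard_normal((4, 4))
a = 1e-2 * 0.5 * (g - g.T)  # a small element of so(4)

diff = cayley(a) - expm(a)
leading = np.linalg.matrix_power(a, 3) / 12  # predicted leading term a^3/12
rel_residual = np.linalg.norm(diff - leading) / np.linalg.norm(leading)
print(rel_residual)  # small: the next term, a^4/12, is O(|a|) relative to a^3/12
```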


### References

[Celledoni and Iserle] Celledoni, Elena, and Arieh Iserles. “Approximating the Exponential from a Lie Algebra to a Lie Group.” Mathematics of Computation, vol. 69, no. 232, 2000, pp. 1457–80. JSTOR, http://www.jstor.org/stable/2585076. Accessed 17 July 2024.



In [None]:
"""Test simulation with the expm and Cayley retraction integrators.
"""
from functools import partial
from time import perf_counter

import jax
import jax.numpy as jnp
import jax.numpy.linalg as jla
from jax import random, vmap, jit
from jax.scipy.linalg import expm
import jax_rb.manifolds.so_left_invariant as som
import jax_rb.manifolds.se_left_invariant as sem
import jax_rb.manifolds.affine_left_invariant as afm

from jax_rb.utils.utils import (rand_positive_definite, sym, vcat)
import jax_rb.simulation.simulator as sim
import jax_rb.simulation.matrix_group_integrator as mi


jax.config.update("jax_enable_x64", True)


## A retractive step

In [None]:
@partial(jit, static_argnums=(0,2,5,6))
def matrix_retractive_move(rtr, x, t, unit_move, scale, sigma, mu):
    """ Simulate the equation :math:`dX_t = \\mu(X_t, t) dt + \\sigma(X_t, t) dW_t` using the retraction rtr. The manifold is a Lie group.
    We do not assume a Riemannian metric on the manifold; :math:`\\sigma\\sigma^T` could be degenerate on :math:`T\\mathcal{M}`. However, we create subclasses for left-invariant Lie groups.

    W is a Wiener process driving the equation, defined on :math:`\\mathbb{R}^k`. W is given by unit_move.

    :math:`\\sigma(X_t, t)` maps :math:`\\mathbb{R}^k` to :math:`\\mathcal{E}`, but the image belongs
    to :math:`T_{X_t}\\mathcal{M}`.

    The retraction rtr is assumed to have a method ``drift_adjust`` returning the drift adjustment :math:`\\mu_{adj}`.

    The move is :math:`x_{new} = \\mathfrak{r}(x, \\Pi(x)\\sigma(x)(\\text{unit\\_move}\\,\\text{scale}^{\\frac{1}{2}}) + \\text{scale} (\\mu + \\mu_{adj}))`.

    :param rtr: the retraction,
    :param x: a point on the manifold,
    :param t: time
    :param unit_move: a random normal draw
    :param scale: scaling
    :param sigma: a function implementing the map :math:`\\sigma`
    :param mu: a function implementing the Ito drift :math:`\\mu`
    """
    return rtr.retract(x,
                       sigma(x, t, unit_move.reshape(x.shape))*jnp.sqrt(scale)
                       + scale*(mu(x, t) + rtr.drift_adjust(sigma, x, t, unit_move.shape[0])))


A few retraction classes for matrix Lie groups, with specialized Cayley implementations for $SO(n)$ and $SE(n)$. Here we simulate Riemannian Brownian motion.

In [None]:
class expm_retraction():
    """The exponential retraction of a matrix Lie group.
    This is the most general but not an efficient implementation;
    each Lie group should ideally have a custom implementation.
    """
    def __init__(self, mnf):
        self.mnf = mnf

    def retract(self, x, v):
        """rescaling :math:`x+v` to be on the manifold
        """
        return x@expm(jla.solve(x, v))

    def drift_adjust(self, sigma, x, t, driver_dim):
        """return the adjustment :math:`\\mu_{adj}`
        so that :math:`\\mu + \\mu_{adj} = \\mu_{\\mathfrak{r}}`
        """

        return -0.5*jnp.sum(vmap(lambda seq:
                                 x@sqr(jla.solve(x, sigma(x, t, seq.reshape(x.shape)))))(jnp.eye(driver_dim)),
                            axis=0)

class cayley_so_retraction():
    """Cayley retraction on SO(n), using :math:`x^{-1} = x^T`:
    :math:`x(I - \\frac{1}{2}x^Tv)^{-1}(I + \\frac{1}{2}x^Tv)`.
    """
    def __init__(self, mnf):
        self.mnf = mnf

    def retract(self, x, v):
        """rescaling :math:`x+v` to be on the manifold
        """
        ixv = x.T@v
        return x + x@jla.solve(jnp.eye(ixv.shape[0]) - 0.5*ixv, ixv)

    def drift_adjust(self, sigma, x, t, driver_dim):
        """return the adjustment :math:`\\mu_{adj}`
        so that :math:`\\mu + \\mu_{adj} = \\mu_{\\mathfrak{r}}`
        """
        return -0.5*jnp.sum(vmap(lambda seq:
                                 x@sqr(self.mnf.sigma_id(seq.reshape(x.shape)))
                                 )(jnp.eye(driver_dim)),
                            axis=0)

class cayley_se_retraction():
    """Cayley retraction on SE(n), computed blockwise for
    :math:`x = [[R, t], [0, 1]]` so that only n x n systems are solved.
    """
    def __init__(self, mnf):
        self.mnf = mnf

    def retract(self, x, v):
        """Cayley retraction :math:`x(I-\\frac{1}{2}x^{-1}v)^{-1}(I+\\frac{1}{2}x^{-1}v)` on SE(n)
        """
        n = x.shape[0] - 1
        rot = x[:-1, :-1]
        # x^{-1}v = [[a, b], [0, 0]] with a = R^T v_A and b = R^T v_b
        ixva = rot.T@v[:-1, :-1]
        ixvb = rot.T@v[:-1, n:]
        lhs = jnp.eye(n) - 0.5*ixva
        return vcat(jnp.concatenate([rot + rot@jla.solve(lhs, ixva),
                                     x[:-1, n:] + rot@jla.solve(lhs, ixvb)], axis=1),
                    jnp.zeros(x.shape[0]).at[-1].set(1.).reshape(1, -1))

    def drift_adjust(self, sigma, x, t, driver_dim):
        """return the adjustment :math:`\\mu_{adj}`
        so that :math:`\\mu + \\mu_{adj} = \\mu_{\\mathfrak{r}}`
        """
        return -0.5*jnp.sum(vmap(lambda seq:
                                 x@sqr(self.mnf.sigma_id(seq.reshape(x.shape)))
                                 )(jnp.eye(driver_dim)),
                            axis=0)

def sqr(a):
    """Return the matrix square a@a."""
    return a@a
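For $SE(n)$, write $x = \begin{bmatrix}R & t\\ 0 & 1\end{bmatrix}$ and $x^{-1}v = \begin{bmatrix}a & b\\ 0 & 0\end{bmatrix}$ with $a = R^Tv_A$ and $b = R^Tv_b$. The Cayley retraction is then $\begin{bmatrix}R(I + (I-\frac{a}{2})^{-1}a) & t + R(I-\frac{a}{2})^{-1}b\\ 0 & 1\end{bmatrix}$, which is what a block-wise implementation should compute. A NumPy sketch (function names hypothetical) comparing this block formula with the full-matrix definition $x(I-\frac{1}{2}x^{-1}v)^{-1}(I+\frac{1}{2}x^{-1}v)$:

```python
import numpy as np


def cayley_full(x, v):
    """x (I - x^{-1}v/2)^{-1} (I + x^{-1}v/2), using full matrices."""
    i = np.eye(x.shape[0])
    w = np.linalg.solve(x, v)
    return x @ np.linalg.solve(i - 0.5 * w, i + 0.5 * w)


def cayley_se_blocks(x, v):
    """The same retraction on SE(n) via n x n blocks, with x = [[R, t], [0, 1]]."""
    n = x.shape[0] - 1
    rot, trans = x[:n, :n], x[:n, n:]
    a = rot.T @ v[:n, :n]  # rotation block of x^{-1} v
    b = rot.T @ v[:n, n:]  # translation block of x^{-1} v
    lhs = np.eye(n) - 0.5 * a
    out = np.eye(n + 1)
    out[:n, :n] = rot + rot @ np.linalg.solve(lhs, a)
    out[:n, n:] = trans + rot @ np.linalg.solve(lhs, b)
    return out


rng = np.random.default_rng(4)
n = 3
s = rng.standard_normal((n, n))
s = s - s.T
x = np.eye(n + 1)
x[:n, :n] = np.linalg.solve(np.eye(n) - 0.5 * s, np.eye(n) + 0.5 * s)  # a rotation
x[:n, n] = rng.standard_normal(n)                                     # a translation

omega = np.zeros((n + 1, n + 1))  # an element of se(n)
k = rng.standard_normal((n, n))
omega[:n, :n] = 0.1 * (k - k.T)
omega[:n, n] = 0.1 * rng.standard_normal(n)
v = x @ omega  # a tangent vector at x

print(np.abs(cayley_full(x, v) - cayley_se_blocks(x, v)).max())
```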


## Test for $SO(n)$
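We check that the Ito drift equals $-\frac{1}{2}$ times the sum of the `mnf.gamma` terms $\Gamma(\sigmam e_k, \sigmam e_k)$ over the basis $e_k$ of the driver space, and that the adjusted Ito drift $\mu - \frac{1}{2}\sum_k x(x^{-1}\sigmam e_k)^2$ is tangent, i.e. $x^T$ times it is skew-symmetric. We then compare the simulations.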
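For $SO(n)$, a matrix $w$ lies in $T_xSO(n)$ exactly when $x^Tw$ is skew-symmetric, which is why the test prints $\mathrm{sym}(x^T(\cdot))$ of the adjusted drift and expects zeros. A minimal NumPy illustration of that criterion (`is_tangent_so` is a hypothetical helper):

```python
import numpy as np


def sym(a):
    return 0.5 * (a + a.T)


def is_tangent_so(x, w, tol=1e-12):
    """w is in T_x SO(n) iff sym(x^T w) = 0."""
    return bool(np.abs(sym(x.T @ w)).max() < tol)


rng = np.random.default_rng(5)
n = 4
q, _ = np.linalg.qr(rng.standard_normal((n, n)))
# Flip the sign of one column if needed so that det(x) = +1.
x = q if np.linalg.det(q) > 0 else q @ np.diag([-1.0] + [1.0] * (n - 1))
g = rng.standard_normal((n, n))

print(is_tangent_so(x, x @ (g - g.T)), is_tangent_so(x, x @ (g + g.T)))  # True False
```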

In [None]:
def test_expm_integrator_so():
    n = 5
    key = random.PRNGKey(0)
    so_dim = n*(n-1)//2
    metric_mat, key = rand_positive_definite(key, so_dim, (.1, 10.))
    mnf = som.SOLeftInvariant(n, metric_mat)
    x, key = mnf.rand_point(key)

    gsum = jnp.zeros((n, n))
    hsum = jnp.zeros((n, n))
    for i in range(n**2):
        nsg = mnf.proj(x, mnf.sigma(x, jnp.zeros(n**2).at[i].set(1.).reshape(n, n)))
        hsum += x@sqr(x.T@nsg)
        gsum += - mnf.gamma(x, nsg, nsg)
        # print(jnp.sum(mnf.grad_c(x)*(hsum-gsum)))

    print(f"test sum -gamma - ito drift={0.5*gsum - mnf.ito_drift(x)}")
    print(f"test adjusted ito is tangent={sym(x.T@(-0.5*hsum+mnf.ito_drift(x)))}")

    # now test the equation.
    # test Brownian motion
    def new_sigma(x, _, dw):
        return mnf.proj(x, mnf.sigma(x, dw))

    def mu(x, _):
        return mnf.ito_drift(x)

    pay_offs = [lambda x, t: t*jnp.sum(x*jnp.arange(n)),
                lambda x: jnp.sqrt(jnp.sum(x*x))]

    key, sk = random.split(key)
    t_final = 1.
    n_path = 1000
    n_div = 1000
    d_coeff = .5
    wiener_dim = n**2
    x_0 = jnp.eye(n)

    ret_geo = sim.simulate(x_0,
                           lambda x, unit_move, scale: mi.geodesic_move(
                               mnf, x, unit_move, scale),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    ret_ito = sim.simulate(x_0,
                           lambda x, unit_move, scale: mi.rbrownian_ito_move(
                               mnf, x, unit_move, scale),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    ret_str = sim.simulate(x_0,
                           lambda x, unit_move, scale: mi.rbrownian_stratonovich_move(
                               mnf, x, unit_move, scale),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    rtr = expm_retraction(mnf)
    # a warm up run
    ret_rtr = sim.simulate(x_0,
                           lambda x, unit_move, scale: matrix_retractive_move(
                               rtr, x, 1., unit_move, scale, new_sigma, mu),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, 5, 5, d_coeff, wiener_dim])

    t0 = perf_counter()
    ret_rtr = sim.simulate(x_0,
                           lambda x, unit_move, scale: matrix_retractive_move(
                               rtr, x, 1., unit_move, scale, new_sigma, mu),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
    t1 = perf_counter()
    print('Time rtr %f' % (t1-t0))

    crtr = cayley_so_retraction(mnf)
    t4 = perf_counter()
    ret_crtr = sim.simulate(x_0,
                           lambda x, unit_move, scale: matrix_retractive_move(
                               crtr, x, 1., unit_move, scale, new_sigma, mu),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
    t5 = perf_counter()
    print('Time crtr %f' % (t5-t4))

    print(f"geo second order = {jnp.nanmean(ret_geo[0])}")
    print(f"Ito              = {jnp.nanmean(ret_ito[0])}")
    print(f"Stratonovich     = {jnp.nanmean(ret_str[0])}")
    print(f"Retractive       = {jnp.nanmean(ret_rtr[0])}")
    print(f"Cayley Retractive       = {jnp.nanmean(ret_crtr[0])}")

test_expm_integrator_so()

test sum -gamma - ito drift=[[-3.26128013e-16  2.08166817e-16  1.66533454e-16  8.32667268e-17
   6.24500451e-17]
 [-7.97972799e-17 -8.32667268e-17 -1.11022302e-16 -1.80411242e-16
  -3.60822483e-16]
 [ 3.12250226e-17 -3.60822483e-16  6.93889390e-17 -2.77555756e-17
  -2.22044605e-16]
 [-8.32667268e-17 -2.08166817e-16  1.24900090e-16 -2.77555756e-17
  -2.49800181e-16]
 [ 0.00000000e+00 -5.55111512e-17 -1.52655666e-16  4.85722573e-17
   2.08166817e-16]]
test adjusted ito is tangent=[[-2.37661530e-16 -2.22460559e-17 -1.90646870e-16 -7.17538281e-17
   1.26459054e-16]
 [-2.22460559e-17  3.04724698e-16  3.08194244e-17  1.62580828e-17
   1.29192211e-16]
 [-1.90646870e-16  3.08194244e-17  5.47579078e-17  1.45219969e-16
   6.05938757e-17]
 [-7.17538281e-17  1.62580828e-17  1.45219969e-16 -1.71952758e-17
   5.51827559e-17]
 [ 1.26459054e-16  1.29192211e-16  6.05938757e-17  5.51827559e-17
   4.28030441e-16]]
Time rtr 73.857657
Time crtr 20.281238
geo second order = 6.056695178129371
Ito            

## Test for $SE(n)$
The same checks and simulations for the special Euclidean group $SE(n)$.

In [None]:
def test_expm_integrator_se():
    n = 3
    key = random.PRNGKey(0)
    se_dim = n*(n+1)//2
    metric_mat, key = rand_positive_definite(key, se_dim, (.1, 30.))
    mnf = sem.SELeftInvariant(n, metric_mat)
    x, key = mnf.rand_point(key)
    n1 = n+1

    gsum = jnp.zeros((n1, n1))
    hsum = jnp.zeros((n1, n1))
    for i in range(n1**2):
        nsg = mnf.proj(x, mnf.sigma(x, jnp.zeros(n1**2).at[i].set(1.).reshape(n1, n1)))
        hsum += x@sqr(jla.solve(x, nsg))
        gsum += - mnf.gamma(x, nsg, nsg)
        # print(jnp.sum(mnf.grad_c(x)*(hsum-gsum)))

    print(f"test sum -gamma - ito drift={0.5*gsum - mnf.ito_drift(x)}")
    print(f"test adjusted ito is tangent={sym(x.T@(-0.5*hsum+mnf.ito_drift(x)))}")

    # now test the equation.
    # test Brownian motion

    def new_sigma(x, _, dw):
        return mnf.proj(x, mnf.sigma(x, dw))

    def mu(x, _):
        return mnf.ito_drift(x)

    pay_offs = [lambda x, t: t*jnp.maximum(x[0, 0]-.5, 0),
                lambda x: x[0, 0]**2]

    key, sk = random.split(key)
    t_final = 1.
    n_path = 1000
    n_div = 1000
    d_coeff = .5
    wiener_dim = n1**2
    x_0 = jnp.eye(n1)

    ret_geo = sim.simulate(x_0,
                           lambda x, unit_move, scale: mi.geodesic_move(
                               mnf, x, unit_move, scale),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    ret_ito = sim.simulate(x_0,
                           lambda x, unit_move, scale: mi.rbrownian_ito_move(
                               mnf, x, unit_move, scale),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    ret_str = sim.simulate(x_0,
                           lambda x, unit_move, scale: mi.rbrownian_stratonovich_move(
                               mnf, x, unit_move, scale),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
    rtr = expm_retraction(mnf)
    ret_rtr = sim.simulate(x_0,
                           lambda x, unit_move, scale: matrix_retractive_move(
                               rtr, x, 1., unit_move, scale, new_sigma, mu),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, 5, 5, d_coeff, wiener_dim])

    t0 = perf_counter()
    ret_rtr = sim.simulate(x_0,
                           lambda x, unit_move, scale: matrix_retractive_move(
                               rtr, x, 1., unit_move, scale, new_sigma, mu),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
    t1 = perf_counter()
    print('Time rtr %f' % (t1-t0))

    crtr = cayley_se_retraction(mnf)
    t4 = perf_counter()
    ret_crtr = sim.simulate(x_0,
                           lambda x, unit_move, scale: matrix_retractive_move(
                               crtr, x, 1., unit_move, scale, new_sigma, mu),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
    t5 = perf_counter()
    print('Time crtr %f' % (t5-t4))

    print(f"geo second order = {jnp.nanmean(ret_geo[0])}")
    print(f"Ito              = {jnp.nanmean(ret_ito[0])}")
    print(f"Stratonovich     = {jnp.nanmean(ret_str[0])}")
    print(f"Retractive       = {jnp.nanmean(ret_rtr[0])}")
    print(f"Cayley Retractive       = {jnp.nanmean(ret_crtr[0])}")
test_expm_integrator_se()


test sum -gamma - ito drift=[[ 1.17961196e-16  9.02056208e-17  7.63278329e-17  2.77555756e-17]
 [-5.55111512e-17 -4.94396191e-17 -6.93889390e-17 -3.64291930e-17]
 [ 9.02056208e-17  7.28583860e-17 -1.38777878e-16  2.77555756e-17]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]
test adjusted ito is tangent=[[ 1.28024894e-16  9.54199437e-17 -2.23648658e-17 -2.70499180e-17]
 [ 9.54199437e-17  5.49101560e-17  2.59732528e-17 -1.82677302e-17]
 [-2.23648658e-17  2.59732528e-17  1.49963480e-16 -7.62769047e-18]
 [-2.70499180e-17 -1.82677302e-17 -7.62769047e-18 -1.42003740e-16]]
Time rtr 42.457243
Time crtr 14.400974
geo second order = 1.1363965505646139
Ito              = 1.136390973605797
Stratonovich     = 1.1363671560250506
Retractive       = 1.1363802045207878
Cayley Retractive       = 1.1363867289743548


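## A simple approximation of expm
The two-term Taylor expansion $\fR(x, v) = x + v + \frac{1}{2}x(x^{-1}v)^2$. For most groups this is not a retraction, since the result generally leaves the group, but it works for the affine group and the general linear group: these groups are open subsets of an affine space of matrices, so the expansion stays in the group for small $v$.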
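A small NumPy sketch of why the two-term expansion fails on $SO(n)$ but not on the affine group (helper names hypothetical). At $x = I$ with $v$ skew-symmetric, $(I+v+\frac{v^2}{2})^T(I+v+\frac{v^2}{2}) = (I-v+\frac{v^2}{2})(I+v+\frac{v^2}{2}) = I + \frac{v^4}{4}$, so the result leaves $SO(n)$. For the affine group, $x^{-1}v$ has a zero last row, so the last row $[0, \dots, 0, 1]$ is preserved and the linear block stays invertible for small $v$.

```python
import numpy as np


def taylor2_retraction(x, v):
    """x + v + 1/2 x (x^{-1} v)^2, the two-term Taylor expansion of x expm(x^{-1} v)."""
    a = np.linalg.solve(x, v)
    return x + v + 0.5 * x @ a @ a


rng = np.random.default_rng(6)
n = 3

# SO(3) at the identity with a skew-symmetric tangent vector.
g = rng.standard_normal((n, n))
v_so = 0.1 * (g - g.T)
y_so = taylor2_retraction(np.eye(n), v_so)
defect = np.abs(y_so.T @ y_so - np.eye(n)).max()
print("SO(3) defect:", defect, "predicted:", np.abs(np.linalg.matrix_power(v_so, 4) / 4).max())

# Affine group: x = [[A, b], [0, 1]] with A invertible; tangent vectors are x @ [[M, c], [0, 0]].
x_aff = np.eye(n + 1)
x_aff[:n, :n] += 0.1 * rng.standard_normal((n, n))
x_aff[:n, n] = rng.standard_normal(n)
w = np.zeros((n + 1, n + 1))
w[:n, :] = 0.1 * rng.standard_normal((n, n + 1))
y_aff = taylor2_retraction(x_aff, x_aff @ w)
print("last row:", y_aff[n], "det of linear block:", np.linalg.det(y_aff[:n, :n]))
```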

In [None]:
class expm_apprx_retraction():
    """A retraction approximating expm by its two-term Taylor expansion.
    It works for the affine group and GL(n), but in general the two-term
    Taylor expansion is not a retraction; another option is a Pade approximant.
    """
    def __init__(self, mnf):
        self.mnf = mnf

    def retract(self, x, v):
        """rescaling :math:`x+v` to be on the manifold
        """
        return x + v + 0.5*x@sqr(jla.solve(x, v))

    def drift_adjust(self, _, x, t, driver_dim):
        """return the adjustment :math:`\\mu_{adj}`
        so that :math:`\\mu + \\mu_{adj} = \\mu_{\\mathfrak{r}}`
        """

        return -0.5*jnp.sum(vmap(lambda seq:
                                 x@sqr(self.mnf.sigma_id(seq.reshape(x.shape))))(jnp.eye(driver_dim)),
                            axis=0)


Again, we check that the adjusted Ito drift is tangent (here, $x^{-1}$ times it has a zero last row), then show that the expm and the two-term Taylor simulations give the same results as the other three simulations.

In [None]:
def test_expm_integrator_affine():
    n = 3
    aff_dim = n*(n+1)
    n1 = n + 1

    key = random.PRNGKey(0)
    metric_mat, key = rand_positive_definite(key, aff_dim, (.1, 10.))
    mnf = afm.AffineLeftInvariant(n, metric_mat)

    x, key = mnf.rand_point(key)

    gsum = jnp.zeros((n1, n1))
    hsum = jnp.zeros((n1, n1))
    for i in range(n1**2):
        nsg = mnf.proj(x, mnf.sigma(x, jnp.zeros(n1**2).at[i].set(1.).reshape(n1, n1)))
        hsum += x@sqr(jla.solve(x, nsg))
        gsum += - mnf.gamma(x, nsg, nsg)
        # print(jnp.sum(mnf.grad_c(x)*(hsum-gsum)))

    print(f"test sum -gamma - ito drift={0.5*gsum - mnf.ito_drift(x)}")
    print(f"test adjusted ito is tangent={jla.solve(x, (-0.5*hsum+mnf.ito_drift(x)))}")

    # now test the equation.
    # test Brownian motion

    def new_sigma(x, _, dw):
        return mnf.proj(x, mnf.sigma(x, dw))

    def mu(x, _):
        return mnf.ito_drift(x)

    pay_offs = [lambda x, t: t*jnp.maximum(x[0, 0]-.5, 0),
                lambda x: (1+jnp.abs(x[0, 0]))**(-.5)
                ]


    key, sk = random.split(key)
    t_final = 1.
    n_path = 1000
    n_div = 200
    d_coeff = .5
    wiener_dim = n1**2
    x_0 = jnp.eye(n1)

    ret_geo = sim.simulate(x_0,
                           lambda x, unit_move, scale: mi.geodesic_move(
                               mnf, x, unit_move, scale),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    ret_ito = sim.simulate(x_0,
                           lambda x, unit_move, scale: mi.rbrownian_ito_move(
                               mnf, x, unit_move, scale),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    ret_str = sim.simulate(x_0,
                           lambda x, unit_move, scale: mi.rbrownian_stratonovich_move(
                               mnf, x, unit_move, scale),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    rtr = expm_retraction(mnf)
    t0 = perf_counter()
    ret_rtr = sim.simulate(x_0,
                           lambda x, unit_move, scale: matrix_retractive_move(
                               rtr, x, None, unit_move, scale, new_sigma, mu),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
    t1 = perf_counter()
    print('Time rtr %f' % (t1-t0))

    artr = expm_apprx_retraction(mnf)
    t2 = perf_counter()
    ret_artr = sim.simulate(x_0,
                            lambda x, unit_move, scale: matrix_retractive_move(
                                artr, x, None, unit_move, scale, new_sigma, mu),
                            pay_offs[0],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
    t3 = perf_counter()
    print('Time artr %f' % (t3-t2))

    print(f"geo second order = {jnp.nanmean(ret_geo[0])}")
    print(f"Ito              = {jnp.nanmean(ret_ito[0])}")
    print(f"Stratonovich     = {jnp.nanmean(ret_str[0])}")
    print(f"Retractive       = {jnp.nanmean(ret_rtr[0])}")
    print(f"Appx Exp Retractive       = {jnp.nanmean(ret_artr[0])}")

test_expm_integrator_affine()


test sum -gamma - ito drift=[[ 3.57353036e-16 -8.88178420e-16  1.04083409e-16 -1.56125113e-16]
 [-2.08166817e-16 -3.46944695e-18  6.93889390e-17  1.87350135e-16]
 [ 3.95516953e-16 -3.74700271e-16 -3.98986399e-16 -2.84494650e-16]
 [ 0.00000000e+00  0.00000000e+00  0.00000000e+00  0.00000000e+00]]
test adjusted ito is tangent=[[-0.05771071  0.03133447 -0.01165553 -0.03045669]
 [ 0.03785821 -0.20563126  0.03272205  0.04153271]
 [ 0.0002595   0.04121875 -0.13468639  0.03533443]
 [ 0.          0.          0.          0.        ]]
Time rtr 11.436188
Time artr 2.475534
geo second order = 0.985876600067447
Ito              = 0.985415106027261
Stratonovich     = 0.9857234685704925
Retractive       = 0.9857389429799134
Appx Exp Retractive       = 0.9856962692286884
