$\newcommand{\sigmam}{\mathring{\sigma}}$
$\newcommand{\rgrad}{\mathsf{rgrad}}$
$\newcommand{\sfg}{\mathsf{g}}$
$\newcommand{\cI}{\mathcal{I}}$
$\newcommand{\R}{\mathbb{R}}$
$\newcommand{\fR}{\mathfrak{r}}$
# Simulating the Riemannian Langevin process on matrix Lie groups to sample a distribution with a specified density relative to the Riemannian measure.

The Langevin process is a solution of the equation
$$dX_t = \left(\frac{1}{2}\rgrad_{\log V}(X_t) + \mu_B(X_t)\right)dt +\sigmam(X_t) dW_t
$$
where $V$ is a smooth positive function on the manifold $M$ and
$$dB_t = \mu_B(B_t)dt +\sigmam(B_t) dW_t
$$
is the Riemannian Brownian motion of a metric $\sfg$. Under suitable smoothness and curvature conditions, the law of $X_t$ converges to the distribution on $M$ whose density relative to the Riemannian measure is proportional to $V$.
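
In the flat case $M=\R$ with the Euclidean metric, the Brownian drift $\mu_B$ vanishes and $\sigmam$ is the identity, so the equation reduces to $dX_t = \frac{1}{2}(\log V)'(X_t)dt + dW_t$. The following is a minimal NumPy sketch of an Euler-Maruyama simulation for $V(x)=e^{-x^2/2}$, whose limiting law is the standard normal. It does not use jax_rb, and the function name and parameters are illustrative only.

```python
import numpy as np


def flat_langevin_second_moment(t_final=10.0, n_div=1000, n_path=20000, seed=0):
    """Euler-Maruyama for dX = 0.5 * (log V)'(X) dt + dW with V(x) = exp(-x^2 / 2).

    Returns the Monte Carlo estimate of E[X_T^2], which should be close to 1.
    """
    rng = np.random.default_rng(seed)
    dt = t_final / n_div
    x = np.zeros(n_path)                       # all paths start at 0
    for _ in range(n_div):
        grad_log_v = -x                        # derivative of log V(x) = -x^2 / 2
        noise = rng.standard_normal(n_path)
        x = x + 0.5 * grad_log_v * dt + np.sqrt(dt) * noise
    return float(np.mean(x**2))


second_moment = flat_langevin_second_moment()
print("E[X_T^2] estimate: %.3f" % second_moment)
```

The simulations below follow the same pattern, with the Euclidean step replaced by a step that stays on the group.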

We test the groups $SO(n)$ and $SE(n)$ in this workbook. First, install jax_rb:

In [1]:
pip install git+https://github.com/dnguyend/jax-rb

Collecting git+https://github.com/dnguyend/jax-rb
  Cloning https://github.com/dnguyend/jax-rb to /tmp/pip-req-build-ph2b9mtd
  Running command git clone --filter=blob:none --quiet https://github.com/dnguyend/jax-rb /tmp/pip-req-build-ph2b9mtd
  Resolved https://github.com/dnguyend/jax-rb to commit 829c06c2301ca7671986bc311b70b19267494bf9
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: jax_rb
  Building wheel for jax_rb (pyproject.toml) ... done
  Created wheel for jax_rb: filename=jax_rb-0.1.dev53+g829c06c-py3-none-any.whl size=33510 sha256=220c81000e3e5e7e676ecce151f8832014c1b3cd9166c4c51108fd4bda63ea3b
  Stored in directory: /tmp/pip-ephem-wheel-cache-6nyleohr/wheels/0f/76/88/65e675f8bcca47be98c588d9a787a4c1c9b0a5044517ba6490
Successfully built jax_rb
Installing collected packages: jax_rb
Successfully installed

Basic imports. We also add a function that samples $SO(n)$ uniformly using the polar retraction, based on [Chikuse2003], to compare with the SDE simulation.
Given a function $V$ serving as an unnormalized density relative to the Riemannian measure, the expectation of a function $f$ with respect to the measure defined by $V$ is
 $\frac{\int fVdvol_{\sfg}}{\int Vdvol_{\sfg}}$, where we compute the two integrals by sampling with the function uniform_sample below.

*Reference*\
Y. Chikuse, Statistics on Special Manifolds, Springer, New York, NY, USA, 2003.
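
The ratio of integrals can be illustrated on the simplest rotation group, $SO(2)$, identified with the angle $\theta\in[0, 2\pi)$. The sketch below is a hedged NumPy example, not part of jax_rb: it draws uniform angles, weights them by an illustrative density $V(\theta)=e^{\kappa\cos\theta}$, and compares the self-normalized estimate of the expectation of $f(\theta)=\cos\theta$ with a grid quadrature. The names ratio_estimate and kappa are assumptions made for this example.

```python
import numpy as np


def ratio_estimate(f, v, theta):
    """Self-normalized estimate of (int f V) / (int V) from uniform points theta."""
    weights = v(theta)
    return float(np.sum(f(theta) * weights) / np.sum(weights))


kappa = 1.0


def v(theta):
    """Unnormalized density on SO(2), chosen for illustration."""
    return np.exp(kappa * np.cos(theta))


f = np.cos                                            # function to average

rng = np.random.default_rng(0)
theta = rng.uniform(0.0, 2.0 * np.pi, size=200_000)   # uniform samples on the circle
mc_value = ratio_estimate(f, v, theta)

# Reference value from a fine uniform grid (rectangle rule on a periodic function).
grid = np.linspace(0.0, 2.0 * np.pi, 4096, endpoint=False)
quad_value = ratio_estimate(f, v, grid)
print("Monte Carlo %.3f, quadrature %.3f" % (mc_value, quad_value))
```

Using the same sample for the numerator and the denominator, as the cells below do by reusing the random key, makes the normalizing constant of $V$ cancel in the ratio.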

In [2]:
""" test riemannian langevin for SO and SE
"""


from functools import partial

import jax
import jax.numpy as jnp
import jax.numpy.linalg as jla
from jax import random, vmap, jit
import jax_rb.manifolds.so_left_invariant as som
import jax_rb.manifolds.se_left_invariant as sem

from jax_rb.utils.utils import (rand_positive_definite, sym, vcat, grand)
import jax_rb.simulation.simulator as sim
import jax_rb.simulation.matrix_group_integrator as mi


jax.config.update("jax_enable_x64", True)


def sqr(x):
    return x@x


def cz(mat):
    return jnp.max(jnp.abs(mat))

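def uniform_sample(key, shape, pay_off, n_samples):
    """Average pay_off over uniform samples of the orthogonal group.

    Each sample is the orthogonal polar factor u@vt of a random matrix
    drawn with grand, so its determinant is +1 or -1. The average over
    O(n) equals the average over SO(n) when pay_off is unchanged by
    flipping the sign of a column, which holds for the integrands
    used in the SO test below.
    """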
    x_all, key = grand(key, (shape[0], shape[1], n_samples))

    def do_one_point(seq):
        # ei, ev = jla.eigh(seq.T@seq)
        # return pay_off(seq@ev@((1/jnp.sqrt(ei))[:, None]*ev.T))
        u, _, vt = jla.svd(seq)
        return pay_off(u[:, :shape[0]]@vt)

    s = jax.vmap(do_one_point, in_axes=2)(x_all)
    return jnp.nanmean(s)





# Sampling on the special orthogonal group $SO(n)$.
Besides the three main simulation methods, we also use the drift-adjust method for the polar decomposition. As in the case of the Stiefel manifold, the polar decomposition has the second-order derivative
$$\fR^{(2)}(x, 0, v,v )= -x v^Tv = x (x^Tv)^2$$
This term offsets the Ito drift of the RB gradient term in $\mu_{\fR}$ (we can see this at $x=I_n$, then translate to any point). Thus, we only have the gradient component in $\mu_{\fR}$.
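
The second-order formula can be checked numerically. The sketch below is a NumPy finite-difference check under stated assumptions, not the jax_rb implementation. It differentiates $t\mapsto\fR(x, tv)$ twice at $t=0$ for the polar retraction and for the Cayley retraction that the code cell below implements. With $v = xa$ for a skew-symmetric $a$, both should give $-xv^Tv = xa^2$. The helper names are illustrative.

```python
import numpy as np


def polar_retract(x, v):
    """Orthogonal polar factor of x + v, computed from the SVD."""
    u, _, vt = np.linalg.svd(x + v)
    return u @ vt


def cayley_retract(x, v):
    """Cayley retraction x (I - a/2)^{-1} (I + a/2) with a = x^T v."""
    a = x.T @ v
    eye = np.eye(a.shape[0])
    return x @ np.linalg.solve(eye - 0.5 * a, eye + 0.5 * a)


def second_derivative(retract, x, v, h=1e-3):
    """Central finite difference of t -> retract(x, t v) at t = 0."""
    return (retract(x, h * v) - 2.0 * x + retract(x, -h * v)) / h**2


rng = np.random.default_rng(0)
n = 4
u0, _, vt0 = np.linalg.svd(rng.standard_normal((n, n)))
x = u0 @ vt0                                  # an orthogonal base point
b = rng.standard_normal((n, n))
a = 0.5 * (b - b.T)                           # skew-symmetric Lie algebra element
v = x @ a                                     # tangent vector at x

expected = -x @ v.T @ v                       # equals x @ a @ a
err_polar = np.max(np.abs(second_derivative(polar_retract, x, v) - expected))
err_cayley = np.max(np.abs(second_derivative(cayley_retract, x, v) - expected))
print("max error polar %.2e, Cayley %.2e" % (err_polar, err_cayley))
```

Both errors should be at the level of the finite-difference truncation, which suggests the two retractions share the same second-order term on the orthogonal group.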

The density is proportional to $e^{-\mathrm{Tr}(\Lambda_0 x^T\Lambda_1 x)}$ with a diagonal $\Lambda_0$. The function whose expectation we take is $f(x) = (1+\sum |\lambda_{ij}x_{ij}|)^{\frac{1}{2}}$.
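
For $\log V(x) = -\mathrm{Tr}(\Lambda_0 x^T\Lambda_1 x)$, the trace form used by log_v in the cell below, the ambient (Euclidean) gradient is $-(\Lambda_1+\Lambda_1^T)x\Lambda_0$ when $\Lambda_0$ is symmetric. This is the matrix passed to the metric inverse and the projection in grad_log_v. A small NumPy sketch checking the formula against a directional finite difference, with the same $\Lambda_0, \Lambda_1$ as in the cell but otherwise illustrative names:

```python
import numpy as np

n = 4
lbd0 = 0.5 * np.diag(np.arange(1, n + 1))              # diagonal Lambda_0
lbd1 = 0.5 * np.arange(1, n**2 + 1).reshape(n, n)      # Lambda_1, not symmetric


def log_v(x):
    """log V(x) = -Tr(Lambda_0 x^T Lambda_1 x)."""
    return -np.trace(lbd0 @ x.T @ lbd1 @ x)


def euclid_grad_log_v(x):
    """Ambient gradient, before applying the metric inverse and the projection."""
    return -(lbd1 + lbd1.T) @ x @ lbd0


rng = np.random.default_rng(0)
x = rng.standard_normal((n, n))
eta = rng.standard_normal((n, n))

h = 1e-6
finite_diff = (log_v(x + h * eta) - log_v(x - h * eta)) / (2.0 * h)
exact = np.sum(euclid_grad_log_v(x) * eta)             # Frobenius inner product
print("finite difference %.6f, formula %.6f" % (finite_diff, exact))
```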


In [3]:
class cayley_so_retraction():
    """Cayley retraction for SO(n), written in terms of x.T@v.
    This is a generic, unoptimized implementation;
    each Lie group should have its own custom implementation.
    """
    def __init__(self, mnf):
        self.mnf = mnf

    def retract(self, x, v):
        """Cayley transform of :math:`a = x^Tv`:
        return :math:`x(I - a/2)^{-1}(I + a/2)`, a point on the manifold
        """
        ixv = x.T@v
        return x + x@jla.solve(jnp.eye(ixv.shape[0]) - 0.5*ixv, ixv)

    def inverse_retract(self, x, y):
        """return the tangent vector v at x with retract(x, v) = y
        """
        u = x.T@y
        n = self.mnf.shape[0]
        return 2*x@jla.solve(jnp.eye(n)+u, u-jnp.eye(n))

    def drift_adjust(self, x, driver_dim):
        """return the adjustment :math:`\\mu_{adj}`
        so that :math:`\\mu + \\mu_{adj} = \\mu_{\\mathfrak{r}}`
        """
        return -0.5*jnp.sum(vmap(lambda seq:
                                 x@sqr(self.mnf.sigma_id(seq.reshape(x.shape)))
                                 )(jnp.eye(driver_dim)),
                            axis=0)



def test_langevin_so():
    # test Langevin on SO(n) with vfunc = e^{-Tr(\Lambda_0 x^T\Lambda_1 x)}
    # jax.config.update('jax_default_device', jax.devices('cpu')[0])
    n = 4
    so_dim = n*(n-1)//2

    lbd = [0.5*jnp.diag(jnp.arange(1, n+1)), 0.5*jnp.arange(1, n**2+1).reshape(n, n)]

    def log_v(_, x):
        return -jnp.trace(lbd[0]@x.T@lbd[1]@x)

    def grad_log_v(mnf, x):
        return -mnf.proj(x, mnf.inv_g_metric(
            x, (lbd[1]+lbd[1].T)@x@lbd[0]))

    key = random.PRNGKey(0)

    metric_mat, key = rand_positive_definite(key, so_dim, (.1, 30.))

    print("Doing SO")

    # metric_mat = jnp.eye(se_dim)
    so = som.SOLeftInvariant(n, metric_mat)
    crtr = cayley_so_retraction(so)
    x, key = so.rand_point(key)
    eta, key = so.rand_vec(key, x)

    # print(jax.jvp(lambda x: log_v(so, x), (x,), (eta,))[1])
    # print(so.inner(x, eta, grad_log_v(so, x)))

    # x1 = crtr.retract(x, eta)
    # eta1 = crtr.inverse_retract(x, x1)
    # print(cz(eta1-eta))


    # random weights lambda_ij for the pay-off (1 + sum |lambda_ij x_ij|)^(1/2)
    lbd1, key = grand(key, (n**2,))
    pay_offs = [None, lambda x: jnp.sqrt(1+jnp.sum(jnp.abs(lbd1*x.reshape(-1))))]

    x_0 = jnp.eye(n)

    key, sk = random.split(key)
    t_final = 20.
    # t_final = 1.5
    n_path = 1000
    n_div = 1000
    d_coeff = .5

    wiener_dim = n**2

    ret_rtr = sim.simulate(x_0,
                           lambda x, unit_move, scale: crtr.retract(
                               x,
                               x@so.sigma_id(unit_move.reshape(x.shape))*scale**.5
                               + 0.5*grad_log_v(so, x)*scale),
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("SO Cayley retract %.3f" % jnp.nanmean(ret_rtr[0]))


    ret_rtr1 = sim.simulate(x_0,
                            lambda x, unit_move, scale: mi.ito_move_with_drift(
                                so, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(so, x))),
                            pay_offs[0],
                            # lambda x: x[1, -1]*x[1, -1],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("Ito Langevin %.3f" % jnp.nanmean(ret_rtr1[0]))

    ret_rtr2 = sim.simulate(x_0,
                            lambda x, unit_move, scale: mi.stratonovich_move_with_drift(
                                so, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(so, x))),
                            pay_offs[0],
                            # lambda x: x[1, -1]*x[1, -1],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("Stratonovich Langevin %.3f" % jnp.nanmean(ret_rtr2[0]))


    ret_rtr3 = sim.simulate(x_0,
                            lambda x, unit_move, scale: mi.geodesic_move_with_drift(
                                so, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(so, x))),
                            pay_offs[0],
                            # lambda x: x[1, -1]*x[1, -1],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("Geodesic 2nd order langevin %.3f" % jnp.nanmean(ret_rtr3[0]))

    # Use a known method to sample the Stiefel manifold uniformly (homogeneous measure)
    n_samples = 1000**2
    ret_denom = uniform_sample(
        key, so.shape,
        lambda x: jnp.exp(log_v(so, x)),
        n_samples)

    ret_num = uniform_sample(
        key, so.shape,
        lambda x: pay_offs[1](x)*jnp.exp(log_v(so, x)),
        n_samples)
    print("SO sampling with density %.3f" % (ret_num/ret_denom))

test_langevin_so()


Doing SO
SO Cayley retract 2.841
Ito Langevin 2.841
Stratonovich Langevin 2.841
Geodesic 2nd order langevin 2.841
SO sampling with density 2.800


# Sampling on the special Euclidean group $SE(n)$.
The group $SE(n)$ is not compact. For a left-invariant metric, the volume form is $\det(\cI)^{\frac{1}{2}}dvol_I$, where $dvol_I$ is the standard (product) measure on $SO(n)\times \R^n$. For the density, we take $V(U, v) = e^{-\frac{1}{2}v^T\Lambda v}$ for an element $(U, v)\in SE(n)$, with a diagonal matrix $\Lambda$.

For the first test we take $n=3$ and $f=\sum |x_{ij}|$, where we identify $SE(n)$ with a subgroup of $GL(n+1)$.
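
The identification with a subgroup of $GL(n+1)$ places $U$ in the top-left $n\times n$ block, $v$ in the last column, and $(0,\dots,0,1)$ in the last row. The following NumPy sketch builds such a matrix and evaluates $\log V$ and $f$ on it. The helper se_element and the numerical values are illustrative assumptions, and $\Lambda$ is taken as in the cell below.

```python
import numpy as np


def se_element(rot, trans):
    """Embed (U, v) in GL(n+1) as the block matrix [[U, v], [0, 1]]."""
    n = rot.shape[0]
    x = np.eye(n + 1)
    x[:n, :n] = rot
    x[:n, n] = trans
    return x


n = 3
lbd = 10.0 * np.arange(1, n + 1)          # diagonal of Lambda


def log_v(x):
    """log V(U, v) = -0.5 * v^T Lambda v, reading v from the last column."""
    v = x[:-1, -1]
    return -0.5 * np.sum(v * lbd * v)


def pay_off(x):
    """f(x) = sum of |x_ij| over the whole (n+1) x (n+1) matrix."""
    return np.sum(np.abs(x))


angle = 0.3
c, s = np.cos(angle), np.sin(angle)
rot = np.array([[c, -s, 0.0], [s, c, 0.0], [0.0, 0.0, 1.0]])   # rotation about z
x = se_element(rot, np.array([0.1, -0.2, 0.05]))

prod = x @ x                               # the group product stays in SE(n)
print("log V = %.4f, f = %.4f" % (log_v(x), pay_off(x)))
```

Note that $f$ includes the constant entry 1 in the bottom-right corner, and that $V$ does not depend on the rotation block.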

In [4]:

def test_langevin_se():
    # test Langevin on se(n) with vfunc = e^{-\frac{1}{2}v^T\Lambda v}
    # jax.config.update('jax_default_device', jax.devices('cpu')[0])
    n = 3
    lbd = 10.*jnp.arange(1, n+1)

    @partial(jit, static_argnums=(0,))
    def log_v(_, x):
        return -0.5*jnp.sum(x[:-1, -1]*lbd*x[:-1, -1])

    @partial(jit, static_argnums=(0,))
    def grad_log_v(mnf, x):
        return mnf.proj(x, mnf.inv_g_metric(
            x,
            jnp.zeros_like(x).at[:-1, -1].set(-lbd*x[:-1, -1])))

    key = random.PRNGKey(0)

    se_dim = n*(n+1)//2
    n1 = n+1
    metric_mat, key = rand_positive_definite(key, se_dim, (.1, 30.))

    # convergence seems to be to the same limit for different metrics, but at different rates

    # metric_mat = jnp.eye(se_dim)
    # metric_mat = metric_mat.at[0, 0].set(1.)
    se = sem.SELeftInvariant(n, metric_mat)
    # x, key = se.rand_point(key)
    # eta, key = se.rand_vec(key, x)

    # print(jax.jvp(lambda x: log_v(se, x), (x,), (eta,))[1])
    # print(se.inner(x, grad_log_v(se, x), eta))

    # pay_offs = [None, lambda x: jnp.sqrt(jnp.sum(x[:-1, -1]**2))]

    # pay_offs = [None, lambda x: jnp.sum(jnp.abs(x[:-1, -1]))]

    # pay_offs = [None, lambda x: jnp.sqrt(jnp.sum(x*x))]
    print("Test SE with n=%d expectation of sum |x|" % (n))
    pay_offs = [None, lambda x: jnp.sum(jnp.abs(x))]

    x_0 = jnp.eye(n1)
    key, sk = random.split(key)
    t_final = 20.
    n_path = 5000
    n_div = 1000
    d_coeff = .5

    wiener_dim = n1**2
    ret_rtr1 = sim.simulate(x_0,
                            lambda x, unit_move, scale: mi.ito_move_with_drift(
                                se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),
                            pay_offs[0],
                            # lambda x: x[1, -1]*x[1, -1],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("Ito Langevin %.3f" % jnp.nanmean(ret_rtr1[0]))

    ret_rtr2 = sim.simulate(x_0,
                            lambda x, unit_move, scale: mi.stratonovich_move_with_drift(
                                se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),
                            pay_offs[0],
                            # lambda x: x[1, -1]*x[1, -1],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("Stratonovich Langevin %.3f" % jnp.nanmean(ret_rtr2[0]))


    ret_rtr3 = sim.simulate(x_0,
                            lambda x, unit_move, scale: mi.geodesic_move_with_drift(
                                se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),
                            pay_offs[0],
                            # lambda x: x[1, -1]*x[1, -1],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("Geodesic 2nd order langevin %.3f" % jnp.nanmean(ret_rtr3[0]))


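    def se_sample(key, shape, pay_off, n_samples):
        """ Importance-sample SE(n) and average the weighted pay_off.
        The rotation block is the orthogonal polar factor of the random
        square block; the translation is the remaining random column.
        Assuming grand draws standard normal entries, the factor
        exp(0.5*|v|^2) cancels the normal density of the translation.
        """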
        x_all, key = grand(key, (shape[0]-1, shape[1], n_samples))

        def do_one_point(seq):
            # ei, ev = jla.eigh(seq.T@seq)
            # return pay_off(seq@ev@((1/jnp.sqrt(ei))[:, None]*ev.T))
            u, _, vt = jla.svd(seq[:, :-1])
            x = vcat(jnp.concatenate(
                [u@vt, seq[:, -1][:, None]], axis=1),
                     jnp.zeros((1, shape[1])).at[0, -1].set(1.))
            return pay_off(x)*jnp.exp(log_v(se, x)+0.5*jnp.sum(x[:-1, -1]**2))

        s = jax.vmap(do_one_point, in_axes=2)(x_all)
        return jnp.nanmean(s)

    n_samples = 1000**2
    ret_denom = se_sample(
        key, se.shape,
        lambda x: 1.,
        n_samples)
    ret_num = se_sample(
        key, se.shape,
        pay_offs[1],
        n_samples)

    print("uniform sampling with density %.3f" % (ret_num/ret_denom))


test_langevin_se()



Test SE with n=3 expectation of sum |x|
Ito Langevin 6.076
Stratonovich Langevin 6.076
Geodesic 2nd order langevin 6.076
uniform sampling with density 6.075
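
These values can be cross-checked in closed form. Since the Riemannian measure is a constant multiple of the product measure and $V$ depends only on $v$, the target law is the product of the uniform measure on $SO(3)$ and the Gaussian $N(0, \Lambda^{-1})$ on $\R^3$. For $n=3$, each entry of a uniformly distributed rotation is uniform on $[-1, 1]$ (each column is uniform on the sphere $S^2$), so $E|U_{ij}| = \frac{1}{2}$, while $E|v_i| = \sqrt{2/(\pi\lambda_i)}$. A short sketch of the arithmetic, with $\lambda = (10, 20, 30)$ as in the cell above:

```python
import math

n = 3
lbd = [10.0 * k for k in range(1, n + 1)]

# Each of the n*n rotation entries contributes E|U_ij| = 1/2 (valid for n = 3).
rotation_part = n * n * 0.5
# v_i ~ N(0, 1/lambda_i) gives E|v_i| = sqrt(2 / (pi * lambda_i)).
translation_part = sum(math.sqrt(2.0 / (math.pi * lam)) for lam in lbd)
# The last row of the embedded matrix is (0, 0, 0, 1) and contributes 1.
closed_form = rotation_part + translation_part + 1.0
print("closed form %.3f" % closed_form)
```

The result is approximately 6.076, in line with the simulated values and the importance-sampling estimate.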


For the second test, we take $n=4$, with a density of the same form, and take the expectation of the function $|vec(x)^TAvec(x)|^{\frac{1}{2}}$, where we vectorize the first $n$ rows of $x\in SE(n)$, identified with an element of $\R^{(n+1)\times(n+1)}$, for a randomly generated positive definite matrix $A$ of size $n(n+1)\times n(n+1)$.
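
The pay-off in the cell below is assembled by flattening the first $n$ rows of $x$ and applying a matrix of the form $A = GG^T$. A minimal NumPy sketch of that construction follows, with an illustrative random $G$ rather than the one generated by jax_rb. Because $A = GG^T$, the value $|vec(x)^TAvec(x)|^{\frac{1}{2}}$ coincides with the Euclidean norm of $G^Tvec(x)$.

```python
import numpy as np

n = 4
rng = np.random.default_rng(0)
g = rng.standard_normal((n * (n + 1), n * (n + 1)))
a_mat = g @ g.T                            # symmetric positive semidefinite


def pay_off(x):
    """|vec(x)^T A vec(x)|^(1/2) on the first n rows of x, flattened row by row."""
    vec = x[:-1, :].reshape(-1)
    return np.sqrt(np.abs(vec @ a_mat @ vec))


x = np.eye(n + 1)
x[:n, n] = rng.standard_normal(n)          # identity rotation, random translation

value = pay_off(x)
norm_form = np.linalg.norm(g.T @ x[:-1, :].reshape(-1))
print("pay-off %.6f, norm form %.6f" % (value, norm_form))
```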

In [5]:

def test_langevin_se2():
    # jax.config.update('jax_default_device', jax.devices('cpu')[0])
    n = 4
    se_dim = n*(n+1)//2
    n1 = n+1

    lbd = 10. + jnp.arange(n)

    @partial(jit, static_argnums=(0,))
    def log_v(_, x):
        return -0.5*jnp.sum(x[:-1, -1]*lbd*x[:-1, -1])

    @partial(jit, static_argnums=(0,))
    def grad_log_v(mnf, x):
        return mnf.proj(x, mnf.inv_g_metric(
            x,
            jnp.zeros_like(x).at[:-1, -1].set(-lbd*x[:-1, -1])))

    key = random.PRNGKey(0)

    # metric_mat, key = rand_positive_definite(key, se_dim, (.1, 30.))
    A, key = grand(key, (n*n1,n*n1))
    A = sym(A@A.T)
    # convergence seems to be to the same limit for different metrics, but at different rates

    metric_mat = jnp.eye(se_dim)
    # metric_mat = metric_mat.at[0, 0].set(1.)
    se = sem.SELeftInvariant(n, metric_mat)
    # x, key = se.rand_point(key)
    # eta, key = se.rand_vec(key, x)

    # print(jax.jvp(lambda x: log_v(se, x), (x,), (eta,))[1])
    # print(se.inner(x, grad_log_v(se, x), eta))
    print("Test SE n=%d expectation of  |x^TAx|^(1/2) for a positive definite matrix A" % (n))

    pay_offs = [None, lambda x: jnp.sqrt(jnp.abs(jnp.sum(x[:-1, :].reshape(-1)*(A@x[:-1, :].reshape(-1)))))]
    # pay_offs = [None, lambda x: jnp.sqrt(jnp.sum(x*x*jnp.arange(1, n1+1)[None, :]))]
    # pay_offs = [None, lambda x: jnp.sum(x[0, :-1]*x[:-1, -1])]

    x_0 = jnp.eye(n1)
    key, sk = random.split(key)
    t_final = 20.
    n_path = 5000
    n_div = 1000
    d_coeff = .5

    wiener_dim = n1**2

    ret_rtr1 = sim.simulate(x_0,
                            lambda x, unit_move, scale: mi.ito_move_with_drift(
                                se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),
                            pay_offs[0],
                            # lambda x: x[1, -1]*x[1, -1],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("Ito Langevin %.3f" % jnp.nanmean(ret_rtr1[0]))


    ret_rtr2 = sim.simulate(x_0,
                            lambda x, unit_move, scale: mi.stratonovich_move_with_drift(
                                se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),
                            pay_offs[0],
                            # lambda x: x[1, -1]*x[1, -1],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("Stratonovich Langevin %.3f" % jnp.nanmean(ret_rtr2[0]))


    ret_rtr3 = sim.simulate(x_0,
                            lambda x, unit_move, scale: mi.geodesic_move_with_drift(
                                se, x, unit_move, scale, 0.5*jla.solve(x, grad_log_v(se, x))),
                            pay_offs[0],
                            # lambda x: x[1, -1]*x[1, -1],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("Geodesic 2nd order langevin %.3f" % jnp.nanmean(ret_rtr3[0]))


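    def se_sample(key, shape, pay_off, n_samples):
        """ Importance-sample SE(n) and average the weighted pay_off.
        The rotation block is the orthogonal polar factor of the random
        square block; the translation is the remaining random column.
        Assuming grand draws standard normal entries, the factor
        exp(0.5*|v|^2) cancels the normal density of the translation.
        """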
        x_all, key = grand(key, (shape[0]-1, shape[1], n_samples))

        def do_one_point(seq):
            # ei, ev = jla.eigh(seq.T@seq)
            # return pay_off(seq@ev@((1/jnp.sqrt(ei))[:, None]*ev.T))
            u, _, vt = jla.svd(seq[:, :-1])
            x = vcat(jnp.concatenate(
                [u@vt, seq[:, -1][:, None]], axis=1),
                     jnp.zeros((1, shape[1])).at[0, -1].set(1.))
            return pay_off(x)*jnp.exp(log_v(se, x)+0.5*jnp.sum(x[:-1, -1]**2))
        #return jnp.sqrt(3+jnp.sum(x[:-1, -1]**2))*jnp.exp(log_v(se, x)+0.5*jnp.sum(x[:-1, -1]**2))

        s = jax.vmap(do_one_point, in_axes=2)(x_all)
        # ret = []
        # for i in range(x_all.shape[2]):
        #    ret.append(do_one_point(x_all[:, :, i]))
        # s = jnp.array(ret)
        return jnp.nanmean(s)

    n_samples = 1000**2

    ret_denom = se_sample(
        key, se.shape,
        lambda x: 1.,
        n_samples)
    """
    ret_num = se_sample(
        key, se.shape,
        lambda x: pay_offs[1](x),
        n_path*500)
    """
    ret_num = se_sample(
        key, se.shape,
        # lambda x: x[1, -1]*x[1, -1],
        pay_offs[1],
        # lambda x: pay_offs[1](x) - jnp.sqrt(3+jnp.sum(x[:-1, -1]**2)),
        n_samples)

    print("uniform sampling with density %.3f" % (ret_num/ret_denom))


test_langevin_se2()

Test SE n=4 expectation of  |x^TAx|^(1/2) for a positive definite matrix A
Ito Langevin 9.289
Stratonovich Langevin 9.295
Geodesic 2nd order langevin 9.288
uniform sampling with density 9.303
