<a href="https://colab.research.google.com/github/dnguyend/jax-rb/blob/main/tests/notebooks/LangevinStiefel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

$\newcommand{\sigmam}{\mathring{\sigma}}$
$\newcommand{\egrad}{\mathsf{egrad}}$
$\newcommand{\rgrad}{\mathsf{rgrad}}$
$\newcommand{\sfg}{\mathsf{g}}$
$\newcommand{\cI}{\mathcal{I}}$
$\newcommand{\R}{\mathbb{R}}$
$\newcommand{\fR}{\mathfrak{r}}$

# Simulating Riemannian Langevin equations on Stiefel manifolds
The long-time limit of the Riemannian Langevin process allows us to sample a distribution on a manifold with a specified density relative to the Riemannian measure.

The Langevin process is a solution of the equation
$$dX_t = \left(\frac{1}{2}\rgrad_{\log V}(X_t) + \mu_B(X_t)\right)dt +\sigmam(X_t)\, dW_t
$$
where $V$ is a smooth positive function on $M$ and
$$dB_t = \mu_B(B_t)dt +\sigmam(B_t) dW_t
$$
is the Riemannian Brownian motion of a metric $\sfg$. Under suitable smoothness and curvature conditions, this process converges to a distribution on the manifold whose density relative to the Riemannian measure is proportional to $V$.

We test the von Mises-Fisher and Bingham distributions in this workbook. First, install jax_rb:

In [1]:
pip install git+https://github.com/dnguyend/jax-rb

Collecting git+https://github.com/dnguyend/jax-rb
  Cloning https://github.com/dnguyend/jax-rb to /tmp/pip-req-build-lo4kwtj9
  Running command git clone --filter=blob:none --quiet https://github.com/dnguyend/jax-rb /tmp/pip-req-build-lo4kwtj9
  Resolved https://github.com/dnguyend/jax-rb to commit 829c06c2301ca7671986bc311b70b19267494bf9
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: jax_rb
  Building wheel for jax_rb (pyproject.toml) ... done
  Created wheel for jax_rb: filename=jax_rb-0.1.dev53+g829c06c-py3-none-any.whl size=33510 sha256=bb66af84c0cce2991645eabc97adfd8a2e9d37e3f94cecc2577e8daf64fa696f
  Stored in directory: /tmp/pip-ephem-wheel-cache-rl611tqh/wheels/0f/76/88/65e675f8bcca47be98c588d9a787a4c1c9b0a5044517ba6490
Successfully built jax_rb
Installing collected packages: jax_rb
Successfully installed

# Basic imports and helper functions.

We can sample the Stiefel manifold uniformly, with respect to its homogeneous measure, by taking the orthonormal polar factor of a random Gaussian matrix.
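A NumPy sketch of this sampler, outside jax_rb (`sample_stiefel_uniform` is a hypothetical name). The orthonormal polar factor $UV^T$ of an $n\times d$ matrix with i.i.d. standard normal entries is uniform on the Stiefel manifold: left-multiplying the Gaussian matrix by an orthogonal $Q$ leaves its distribution unchanged and maps its polar factor to $Q$ times the polar factor. Uniformity implies $E[xx^T] = \frac{d}{n}I_n$, which the script checks.

```python
import numpy as np


def sample_stiefel_uniform(rng, n, d, n_samples):
    """Draw n_samples points on St(n, d) as polar factors of Gaussian matrices."""
    g = rng.standard_normal((n_samples, n, d))
    # Thin SVD g = U S V^T (batched); the orthonormal polar factor is U V^T.
    u, _, vt = np.linalg.svd(g, full_matrices=False)
    return u @ vt


rng = np.random.default_rng(0)
n, d = 5, 3
xs = sample_stiefel_uniform(rng, n, d, 20000)
# Every sample has orthonormal columns: x^T x = I_d.
print(np.abs(np.swapaxes(xs, 1, 2) @ xs - np.eye(d)).max())
# Uniform samples satisfy E[x x^T] = (d / n) I_n.
print(np.abs(np.mean(xs @ np.swapaxes(xs, 1, 2), axis=0) - d / n * np.eye(n)).max())
```

The jax_rb helper below does the same per sample with `jla.svd` and `vmap`.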

On another note, the polar decomposition can also be used as a retraction. To simulate an Ito process with a retraction, we apply an adjustment to the drift. The class stiefel_polar_retraction implements the polar retraction and provides this drift adjustment.
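The closed form used in `drift_adjust`, $\frac{1}{2}(n-d+\frac{d-1}{2\alpha_1})x$, can be cross-checked numerically. Using the second-order term $\fR^{(2)}(x,0,v,v) = -xv^Tv$ of the polar retraction, the adjustment is $\frac{1}{2}x\sum_{ij}v_{ij}^Tv_{ij}$, where $v_{ij}$ is the diffusion direction for the basis matrix $E_{ij}$ of $\R^{n\times d}$. The NumPy sketch below is not part of jax_rb; it assumes $\alpha_0 = 1$ and that the diffusion coefficient acts as $\omega \mapsto x\,\mathrm{skew}(x^T\omega)/\sqrt{\alpha_1} + (I - xx^T)\omega$. The helper `sigma_alpha` is a hypothetical stand-in for the library's `sigma`, not its implementation.

```python
import numpy as np


def sigma_alpha(x, omega, alp1):
    """Hypothetical diffusion coefficient for alpha = (1, alp1):
    scale the x-component skew(x^T omega) by 1/sqrt(alp1), keep the normal part."""
    a = x.T @ omega
    skew = 0.5 * (a - a.T)
    return x @ skew / np.sqrt(alp1) + omega - x @ a


def drift_adjust_numeric(x, alp1):
    """0.5 * x * sum_ij v_ij^T v_ij with v_ij = sigma_alpha(x, E_ij)."""
    n, d = x.shape
    total = np.zeros((d, d))
    for k in range(n * d):
        e = np.zeros(n * d)
        e[k] = 1.0  # standard basis matrix E_ij, flattened
        v = sigma_alpha(x, e.reshape(n, d), alp1)
        total += v.T @ v
    return 0.5 * x @ total


rng = np.random.default_rng(1)
n, d, alp1 = 7, 3, 0.6
x, _ = np.linalg.qr(rng.standard_normal((n, d)))  # a point on St(7, 3)
closed_form = 0.5 * (n - d + 0.5 * (d - 1) / alp1) * x
print(np.abs(drift_adjust_numeric(x, alp1) - closed_form).max())
```

The printed discrepancy should be at rounding level: the normal part contributes $(n-d)I_d$ and the skew part contributes $\frac{d-1}{2\alpha_1}I_d$ to the sum.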

In [2]:
""" test riemannian langevin for stiefel manifolds
"""

from functools import partial
from time import perf_counter

import jax
import jax.numpy as jnp
import jax.numpy.linalg as jla
from jax import random, vmap, jit
"""
"""
from jax.scipy.linalg import expm
import jax_rb.manifolds.stiefel as stm

from jax_rb.utils.utils import (sym, grand)
import jax_rb.simulation.simulator as sim
import jax_rb.simulation.global_manifold_integrator as gmi


jax.config.update("jax_enable_x64", True)


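def sqr(x):
    """Matrix square :math:`x x`."""
    return x@x


def cz(mat):
    """Largest absolute entry of ``mat``, used to check that a matrix is close to zero."""
    return jnp.max(jnp.abs(mat))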


class stiefel_polar_retraction():
    def __init__(self, mnf):
        self.mnf = mnf

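    def retract(self, x, v):
        """Polar retraction: the orthonormal polar factor :math:`UV^T` of :math:`x+v`,
        computed from the thin SVD :math:`x+v = U S V^T`.
        """
        u, _, vt = jla.svd(x+v, full_matrices=False)
        return u@vt

    def drift_adjust(self, x):
        """Second-order drift adjustment for the polar retraction,
        :math:`\\frac{1}{2}(n-d+\\frac{d-1}{2\\alpha_1})x`.
        """
        n, d, alp1 = self.mnf.shape[0], self.mnf.shape[1], self.mnf.alpha[1]
        return 0.5*(n-d+0.5*(d-1)/alp1)*x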


def uniform_sampling(key, shape, pay_off, n_samples):
    """Estimate the mean of ``pay_off`` under the uniform distribution on the
    Stiefel manifold, sampling each point as the polar factor of a random matrix.
    """
    x_all, key = grand(key, (shape[0], shape[1], n_samples))

    def do_one_point(seq):
        # Polar factor U V^T of seq from its thin SVD
        # (equivalently seq @ (seq.T @ seq)^{-1/2}, e.g. via jla.eigh).
        u, _, vt = jla.svd(seq, full_matrices=False)
        return pay_off(u@vt)

    # vmap over the last axis: one Stiefel point per sample.
    s = jax.vmap(do_one_point, in_axes=2)(x_all)
    return jnp.nanmean(s)

def gen_sym_traceless(key, n):
    """ Generating a traceless symmetric matrix
    """
    A, key = grand(key, (n, n))
    return sym(A) - jnp.trace(A)/n*jnp.eye(n), key




# Test von Mises-Fisher
For the von Mises-Fisher distribution, the density is proportional to $V(x) = e^{\kappa \operatorname{Tr}(M^Tx)}$. In this case $\egrad_{\log V}(x)=\kappa M$ and $\rgrad_{\log V}(x) = \kappa\Pi(x)\sfg^{-1}M$. Half of this gradient is the additional drift passed to the `_with_drift` methods of the library.
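To make the gradient formula concrete, here is a NumPy sketch for the special case $\sfg = I$, assuming that metric parameters $\alpha = (1, 1)$ give the embedded Euclidean metric. Then $\Pi(x)\omega = \omega - x\,\mathrm{sym}(x^T\omega)$ is the orthogonal projection, and the sketch checks that $\kappa\Pi(x)M$ is tangent at $x$ and that its inner product with a tangent vector $\eta$ equals the directional derivative $\kappa\operatorname{Tr}(M^T\eta)$ of $\log V$. The helpers `proj` and `rgrad_log_vmf` are illustrative, not the jax_rb methods; the general-$\alpha$ case also needs $\sfg^{-1}$.

```python
import numpy as np


def proj(x, omega):
    """Orthogonal projection onto the tangent space of St(n, d) at x (Euclidean metric)."""
    a = x.T @ omega
    return omega - x @ (0.5 * (a + a.T))


def rgrad_log_vmf(x, kappa, m):
    """rgrad log V for V(x) = exp(kappa Tr(M^T x)) when g = I: kappa * Pi(x) M."""
    return kappa * proj(x, m)


rng = np.random.default_rng(2)
n, d, kappa = 5, 3, 1.2
x, _ = np.linalg.qr(rng.standard_normal((n, d)))  # a point on St(5, 3)
m, _ = np.linalg.qr(rng.standard_normal((n, d)))
eta = proj(x, rng.standard_normal((n, d)))        # a random tangent vector at x

rg = rgrad_log_vmf(x, kappa, m)
# Tangency: x^T rg is antisymmetric.
print(np.abs(x.T @ rg + rg.T @ x).max())
# Gradient property: kappa Tr(M^T eta) equals <rg, eta>.
print(kappa * np.trace(m.T @ eta) - np.trace(rg.T @ eta))
```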

For each test function $f$, we run three integration methods: Ito, Stratonovich, and the geodesic method, which applies the geodesic (a second-order retraction) to the projected drift. In the long run, all three should sample the von Mises-Fisher distribution, so the average of the simulated values $f(X_t)$ should converge to $E(f)$. We check that the three methods agree with each other and with $\frac{\int fV\,dvol}{\int V\,dvol}$, where each integral is estimated by uniform sampling.

In some cases we test two different functions $f$. In all examples the four estimates agree to within about 0.01.
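The reference value $\frac{\int fV\,dvol}{\int V\,dvol}$ is a self-normalized ratio of two uniform-sampling averages. A minimal NumPy sketch for the sphere $S^2$, the case $(n, d) = (3, 1)$ with $\kappa = 1$ and $f(x) = \sqrt{|1 - (M^Tx)^2|}$ as in the first test below. `ratio_estimate` is a hypothetical helper; it normalizes Gaussian vectors, which is the $d = 1$ case of the polar factor.

```python
import numpy as np


def ratio_estimate(rng, f, log_v, n, n_samples):
    """Estimate int f V dvol / int V dvol on S^{n-1} from uniform samples."""
    g = rng.standard_normal((n_samples, n))
    x = g / np.linalg.norm(g, axis=1, keepdims=True)  # uniform on S^{n-1}
    w = np.exp(log_v(x))                              # unnormalized density V
    # Numerator and denominator share the same samples.
    return np.sum(f(x) * w) / np.sum(w)


rng = np.random.default_rng(3)
n, kappa = 3, 1.0
m = np.array([0.0, 0.0, 1.0])
est = ratio_estimate(
    rng,
    f=lambda x: np.sqrt(np.abs(1.0 - (x @ m) ** 2)),
    log_v=lambda x: kappa * (x @ m),
    n=n,
    n_samples=200_000,
)
print(est)
```

With 200,000 samples the estimate should land close to the exact value 0.755402 printed for this case; the result does not depend on which unit vector $M$ is chosen.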

In [3]:

def test_stiefel_langevin_von_mises_fisher(key, stf, kp, M, func, t_final, n_path, n_div, n_samples=1000**2):
    # test Langevin on Stiefel with density V(x) = e^{kp Tr(M^T x)}
    # jax.config.update('jax_default_device', jax.devices('cpu')[0])
    print("Doing Stiefel von Mises Fisher (n, d)=%s alpha=%s" % (str(stf.shape), str(stf.alpha)))

    @partial(jit, static_argnums=(0,))
    def log_v(_, x):
        return kp*jnp.trace(M.T@x)

    @partial(jit, static_argnums=(0,))
    def grad_log_v(mnf, x):
        return kp*mnf.proj(x, mnf.inv_g_metric(x, M))

    x, key = stf.rand_point(key)
    eta, key = stf.rand_vec(key, x)

    # print(jax.jvp(lambda x: log_v(stf, x), (x,), (eta,))[1])
    # print(stf.inner(x, grad_log_v(stf, x), eta))

    pay_offs = [None, func]

    x_0, key = stf.rand_point(key)
    key, sk = random.split(key)
    # t_final = 5.
    # n_path = 10000
    # n_div = 500
    d_coeff = .5

    wiener_dim = stf.shape[0]*stf.shape[1]
    # crtr = cayley_se_retraction(se)

    # rbrownian_ito_langevin_move(mnf, x, unit_move, scale, grad_log_v)
    ret_rtr1 = sim.simulate(x_0,
                            lambda x, unit_move, scale: gmi.ito_move_with_drift(
                                stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),
                            pay_offs[0],
                            # lambda x: x[1, -1]*x[1, -1],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("ito langevin %.3f" % jnp.nanmean(ret_rtr1[0]))

    ret_rtr2 = sim.simulate(x_0,
                            lambda x, unit_move, scale: gmi.stratonovich_move_with_drift(
                                stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),
                            pay_offs[0],
                            # lambda x: x[1, -1]*x[1, -1],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("stratonovich langevin %.3f" % jnp.nanmean(ret_rtr2[0]))

    ret_rtr3 = sim.simulate(x_0,
                            lambda x, unit_move, scale: gmi.geodesic_move_with_drift(
                                stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),
                            pay_offs[0],
                            # lambda x: x[1, -1]*x[1, -1],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("geodesic langevin %.3f" % jnp.nanmean(ret_rtr3[0]))


    # Both estimates reuse the same key, hence the same uniform samples,
    # so their ratio estimates int f V dvol / int V dvol.
    ret_spl = uniform_sampling(key, stf.shape,
                               lambda x: pay_offs[1](x)*jnp.exp(log_v(None, x)),
                               n_samples)

    ret_spl_0 = uniform_sampling(key, stf.shape,
                                 lambda x: jnp.exp(log_v(None, x)),
                                 n_samples)

    print("stiefel uniform sampling with density %.3f" % (ret_spl/ret_spl_0))
    # import scipy.special as ss
    # print(jnp.sqrt(2)*ss.iv(1, 1)/ss.iv(.5, 1)*ss.gamma(1.5))


def test_all_stiefel_von_mises_fisher():
    n = 3
    d = 1

    alp = jnp.array([1, 1.])
    key = random.PRNGKey(0)

    stf = stm.RealStiefelAlpha((n, d), alp)

    # F, key = stf.rand_point(key)
    kp = 1.
    M, key = stf.rand_point(key)
    test_stiefel_langevin_von_mises_fisher(
        key, stf, kp, M,
        lambda x: jnp.sqrt(jnp.abs(1-jnp.trace(M.T@x)**2)), t_final=5., n_path=10000, n_div=500, n_samples=1000**2)
    # print(jnp.sqrt(2)*ss.iv(1, 1)/ss.iv(.5, 1)*ss.gamma(1.5))


    n = 5
    d = 3
    alp = jnp.array([1, .6])
    key = random.PRNGKey(0)

    stf = stm.RealStiefelAlpha((n, d), alp)

    # F, key = stf.rand_point(key)
    kp = 1.2
    M, key = stf.rand_point(key)

    test_stiefel_langevin_von_mises_fisher(key, stf, kp, M, lambda x: jnp.sqrt(jnp.abs(1-jnp.trace(M.T@x)**2)),
                                          t_final=5., n_path=10000, n_div=500, n_samples=1000**2)

    test_stiefel_langevin_von_mises_fisher(key, stf, kp, M, lambda x: jnp.sum(jnp.abs(x)), t_final=5.,
                                           n_path=10000, n_div=500,n_samples=1000**2)

    n = 5
    d = 3
    alp = jnp.array([1, 1.])
    key = random.PRNGKey(0)

    stf = stm.RealStiefelAlpha((n, d), alp)

    # F, key = stf.rand_point(key)
    kp = 1.2
    M, key = stf.rand_point(key)

    test_stiefel_langevin_von_mises_fisher(key, stf, kp, M, lambda x: jnp.sqrt(jnp.abs(1-jnp.trace(M.T@x)**2)),
                                          t_final=5., n_path=10000, n_div=500, n_samples=1000**2)

    test_stiefel_langevin_von_mises_fisher(key, stf, kp, M, lambda x: jnp.sum(jnp.abs(x)), t_final=5.,
                                           n_path=10000, n_div=500,n_samples=1000**2)


test_all_stiefel_von_mises_fisher()
import scipy.special as ss
print("Exact value for m=3 d=1 val=%f" % (jnp.sqrt(2)*ss.iv(1, 1.)/ss.iv(.5, 1.)*ss.gamma(1.5)))

Doing Stiefel von Mises Fisher (n, d)=(3, 1) alpha=[1. 1.]
ito langevin 0.756
stratonovich langevin 0.756
geodesic langevin 0.755
stiefel uniform sampling with density 0.754
Doing Stiefel von Mises Fisher (n, d)=(5, 3) alpha=[1.  0.6]
ito langevin 0.885
stratonovich langevin 0.885
geodesic langevin 0.888
stiefel uniform sampling with density 0.882
Doing Stiefel von Mises Fisher (n, d)=(5, 3) alpha=[1.  0.6]
ito langevin 5.629
stratonovich langevin 5.627
geodesic langevin 5.629
stiefel uniform sampling with density 5.624
Doing Stiefel von Mises Fisher (n, d)=(5, 3) alpha=[1. 1.]
ito langevin 0.883
stratonovich langevin 0.881
geodesic langevin 0.885
stiefel uniform sampling with density 0.882
Doing Stiefel von Mises Fisher (n, d)=(5, 3) alpha=[1. 1.]
ito langevin 5.633
stratonovich langevin 5.631
geodesic langevin 5.633
stiefel uniform sampling with density 5.624
Exact value for m=3 d=1 val=0.755402
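The exact value printed above has a closed form. For $(n, d) = (3, 1)$ the Stiefel manifold is the unit sphere $S^2$, and $t = M^Tx$ has marginal density proportional to $e^{\kappa t}(1-t^2)^{(n-3)/2} = e^{\kappa t}$ on $[-1, 1]$. Hence
$E\sqrt{1-t^2} = \frac{\int_{-1}^1\sqrt{1-t^2}\,e^{\kappa t}dt}{\int_{-1}^1 e^{\kappa t}dt} = \frac{\pi I_1(\kappa)}{2\sinh\kappa}$,
which agrees with the printed expression $\sqrt{2}\,I_1(1)\Gamma(\frac{3}{2})/I_{1/2}(1)$ because $I_{1/2}(\kappa) = \sqrt{2/(\pi\kappa)}\sinh\kappa$ and $\Gamma(\frac{3}{2}) = \sqrt{\pi}/2$. A short SciPy check (the helper name `vmf_mean_sine` is illustrative):

```python
import numpy as np
from scipy.integrate import quad
from scipy.special import iv


def vmf_mean_sine(kappa, n=3):
    """E[sqrt(1 - t^2)] for t = M^T x, x ~ von Mises-Fisher(kappa) on S^{n-1}."""
    # Marginal density of t on [-1, 1] is proportional to exp(kappa t) (1 - t^2)^((n - 3) / 2).
    def weight(t):
        return np.exp(kappa * t) * (1.0 - t * t) ** ((n - 3) / 2)

    num, _ = quad(lambda t: np.sqrt(1.0 - t * t) * weight(t), -1.0, 1.0)
    den, _ = quad(weight, -1.0, 1.0)
    return num / den


kappa = 1.0
closed_form = np.pi * iv(1, kappa) / (2.0 * np.sinh(kappa))
print(vmf_mean_sine(kappa), closed_form)  # both approximately 0.755402
```

The quadrature helper also handles other $n$ through the $(1-t^2)^{(n-3)/2}$ weight; the closed form is specific to $n = 3$.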


A taller example, with $(n, d)=(27,2)$. Here the uniform-sampling reference uses `n_samples=10000` instead of $1000^2$.

In [4]:
def test_tall_stiefel_von_mises_fisher():
    n = 27
    d = 2
    alp = jnp.array([1, .6])
    key = random.PRNGKey(0)

    stf = stm.RealStiefelAlpha((n, d), alp)

    # F, key = stf.rand_point(key)
    kp = 1.2
    M, key = stf.rand_point(key)

    test_stiefel_langevin_von_mises_fisher(key, stf, kp, M, lambda x: jnp.sqrt(jnp.abs(1-jnp.trace(M.T@x)**2)),
                                           t_final=5., n_path=10000, n_div=500, n_samples=10000)

    test_stiefel_langevin_von_mises_fisher(key, stf, kp, M,
                                           lambda x: jnp.sum(jnp.abs(x)),
                                           t_final=5., n_path=10000, n_div=500, n_samples=10000)

test_tall_stiefel_von_mises_fisher()


Doing Stiefel von Mises Fisher (n, d)=(27, 2) alpha=[1.  0.6]
ito langevin 0.957
stratonovich langevin 0.956
geodesic langevin 0.956
stiefel uniform sampling with density 0.957
Doing Stiefel von Mises Fisher (n, d)=(27, 2) alpha=[1.  0.6]
ito langevin 8.372
stratonovich langevin 8.372
geodesic langevin 8.370
stiefel uniform sampling with density 8.363


# Test Bingham
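For the Bingham distribution, the density is proportional to $V(x) = e^{\operatorname{Tr}(x^TAx)}$, where $A$ is a symmetric traceless matrix. Since $\egrad_{\log V}(x) = 2Ax$, the Riemannian gradient is $\rgrad_{\log V}(x) = 2\Pi(x)\sfg^{-1}Ax$.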
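The Bingham gradient formula can be checked numerically. The sketch also shows why restricting to traceless $A$ loses nothing: on the Stiefel manifold $x^Tx = I_d$, so replacing $A$ by $A + cI$ shifts $\log V$ by the constant $cd$. It builds $A$ the same way as `gen_sym_traceless`, but in NumPy; all names are illustrative.

```python
import numpy as np

rng = np.random.default_rng(4)
n, d = 5, 3
b = rng.standard_normal((n, n))
a = 0.5 * (b + b.T)
a -= np.trace(a) / n * np.eye(n)                  # symmetric and traceless
x, _ = np.linalg.qr(rng.standard_normal((n, d)))  # a point on St(5, 3)


def log_v(y, mat):
    """log V(y) = Tr(y^T A y) for the Bingham density."""
    return np.trace(y.T @ mat @ y)


# Euclidean gradient: central difference of log V along w equals <2 A x, w>
# (exact up to rounding, since log V is quadratic).
w = rng.standard_normal((n, d))
h = 1e-6
fd = (log_v(x + h * w, a) - log_v(x - h * w, a)) / (2 * h)
print(fd - np.trace((2 * a @ x).T @ w))

# Shifting A by c I changes log V by the constant c * d on the manifold.
c = 0.7
print(log_v(x, a + c * np.eye(n)) - log_v(x, a) - c * d)
```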

In [5]:
def test_stiefel_langevin_bingham(key, stf, A, func):
    # test Langevin on Stiefel with Bingham density V(x) = e^{Tr(x^T A x)}
    # jax.config.update('jax_default_device', jax.devices('cpu')[0])
    @partial(jit, static_argnums=(0,))
    def log_v(_, x):
        return jnp.trace(x.T@A@x)

    @partial(jit, static_argnums=(0,))
    def grad_log_v(mnf, x):
        return mnf.proj(x, mnf.inv_g_metric(x, 2*A@x))

    print("Doing Bingham (n, d)=%s alpha=%s" % (str(stf.shape), str(stf.alpha)))

    # x, key = stf.rand_point(key)
    # eta, key = stf.rand_vec(key, x)

    # print(jax.jvp(lambda x: log_v(stf, x), (x,), (eta,))[1])
    # print(stf.inner(x, grad_log_v(stf, x), eta))

    pay_offs = [None, func]

    x_0, key = stf.rand_point(key)
    key, sk = random.split(key)
    t_final = 5.
    n_path = 10000
    n_div = 500
    d_coeff = .5

    wiener_dim = stf.shape[0]*stf.shape[1]
    # crtr = cayley_se_retraction(se)

    # rbrownian_ito_langevin_move(mnf, x, unit_move, scale, grad_log_v)
    ret_rtr1 = sim.simulate(x_0,
                            lambda x, unit_move, scale: gmi.ito_move_with_drift(
                                stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),
                            pay_offs[0],
                            # lambda x: x[1, -1]*x[1, -1],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("ito langevin %.3f" % jnp.nanmean(ret_rtr1[0]))

    ret_rtr2 = sim.simulate(x_0,
                            lambda x, unit_move, scale: gmi.stratonovich_move_with_drift(
                                stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),
                            pay_offs[0],
                            # lambda x: x[1, -1]*x[1, -1],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("stratonovich langevin %.3f" % jnp.nanmean(ret_rtr2[0]))

    ret_rtr3 = sim.simulate(x_0,
                            lambda x, unit_move, scale: gmi.geodesic_move_with_drift(
                                stf, x, unit_move, scale, 0.5*grad_log_v(stf, x)),
                            pay_offs[0],
                            # lambda x: x[1, -1]*x[1, -1],
                            pay_offs[1],
                            [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("geodesic langevin %.3f" % jnp.nanmean(ret_rtr3[0]))

    n_samples = 1000**2
    # Same key for both estimates, so the ratio uses the same uniform samples.
    ret_spl = uniform_sampling(key, stf.shape,
                               lambda x: pay_offs[1](x)*jnp.exp(log_v(None, x)),
                               n_samples)

    ret_spl_0 = uniform_sampling(key, stf.shape,
                                 lambda x: jnp.exp(log_v(None, x)),
                                 n_samples)

    print("stiefel uniform sampling with density %.3f" % (ret_spl/ret_spl_0))

def test_all_bingham():
    n = 3
    d = 1
    alp = jnp.array([1, .6])
    key = random.PRNGKey(0)

    stf = stm.RealStiefelAlpha((n, d), alp)

    A, key = gen_sym_traceless(key, n)
    test_stiefel_langevin_bingham(
        key, stf, A,
        lambda x: jnp.sum(jnp.abs(x)))
    # print(jnp.sqrt(2)*ss.iv(1, 1)/ss.iv(.5, 1)*ss.gamma(1.5))

    n = 5
    d = 3
    alp = jnp.array([1, .6])
    key = random.PRNGKey(0)

    stf = stm.RealStiefelAlpha((n, d), alp)
    A, key = gen_sym_traceless(key, n)
    test_stiefel_langevin_bingham(
        key, stf, A,
        lambda x: jnp.sum(jnp.abs(x)))

    test_stiefel_langevin_bingham(
        key, stf, A,
        lambda x: jnp.sum(jnp.abs(x)*(A@jnp.abs(x))))

    n = 7
    d = 3
    alp = jnp.array([1, .6])
    key = random.PRNGKey(0)

    stf = stm.RealStiefelAlpha((n, d), alp)
    A, key = gen_sym_traceless(key, n)
    test_stiefel_langevin_bingham(
        key, stf, A,
        lambda x: jnp.sum(jnp.abs(x)))

    test_stiefel_langevin_bingham(
        key, stf, A,
        lambda x: jnp.sum(jnp.abs(x)*(A@jnp.abs(x))))


test_all_bingham()

Doing Bingham (n, d)=(3, 1) alpha=[1.  0.6]
ito langevin 1.494
stratonovich langevin 1.495
geodesic langevin 1.493
stiefel uniform sampling with density 1.497
Doing Bingham (n, d)=(5, 3) alpha=[1.  0.6]
ito langevin 5.613
stratonovich langevin 5.617
geodesic langevin 5.613
stiefel uniform sampling with density 5.616
Doing Bingham (n, d)=(5, 3) alpha=[1.  0.6]
ito langevin -0.559
stratonovich langevin -0.571
geodesic langevin -0.555
stiefel uniform sampling with density -0.565
Doing Bingham (n, d)=(7, 3) alpha=[1.  0.6]
ito langevin 6.551
stratonovich langevin 6.551
geodesic langevin 6.550
stiefel uniform sampling with density 6.553
Doing Bingham (n, d)=(7, 3) alpha=[1.  0.6]
ito langevin -0.563
stratonovich langevin -0.567
geodesic langevin -0.548
stiefel uniform sampling with density -0.579


We also run the drift-adjust method for the polar decomposition. The second-order expansion of the polar retraction is
$$\fR(x, hv) = x+ hv-\frac{h^2}{2}xv^{T}v + O(h^3).
$$
This is because, if $x + hv = U\Sigma$ is a polar decomposition, then since $v$ is tangent at $x$, $x^Tv$ is antisymmetric, so
$$(x + hv)^T(x + hv)=
I_d + h^2v^Tv = \Sigma^2,$$ and therefore
$$U = (x + hv)(I_d + h^2v^Tv)^{-\frac{1}{2}} = x + hv -\frac{h^2}{2}xv^Tv+O(h^3).$$
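The expansion can be checked numerically: the gap between the polar factor of $x + hv$ and $x + hv - \frac{h^2}{2}xv^Tv$ is dominated by the $-\frac{h^3}{2}vv^Tv$ term, so it should shrink by about a factor of 8 each time $h$ is halved. Below is a NumPy sketch; `polar_retract` mirrors `stiefel_polar_retraction.retract`, and $v$ is a random unit-norm tangent vector.

```python
import numpy as np


def polar_retract(x, v):
    """Polar retraction: orthonormal polar factor U V^T of x + v (thin SVD)."""
    u, _, vt = np.linalg.svd(x + v, full_matrices=False)
    return u @ vt


rng = np.random.default_rng(5)
n, d = 7, 3
x, _ = np.linalg.qr(rng.standard_normal((n, d)))
w = rng.standard_normal((n, d))
xtw = x.T @ w
v = w - x @ (0.5 * (xtw + xtw.T))  # tangent at x: x^T v is antisymmetric
v /= np.linalg.norm(v)             # unit Frobenius norm keeps h * v small


def remainder(h):
    """Max-abs gap between R(x, h v) and its second-order expansion."""
    approx = x + h * v - 0.5 * h**2 * x @ (v.T @ v)
    return np.abs(polar_retract(x, h * v) - approx).max()


for h in (0.04, 0.02, 0.01):
    print(h, remainder(h))  # roughly 8x smaller per halving of h
```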
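We can check that, in this case, the second-order adjustment
$$-\frac{1}{2}\sum_{ij}\fR^{(2)}(x,0, \sigmam(x)E_{ij}, \sigmam(x)E_{ij}), $$
where $E_{ij}$ runs over the standard basis of $\R^{n\times d}$ and $\fR^{(2)}(x, 0, v, v) = -xv^Tv$ by the expansion above, cancels the Ito drift of the Riemannian Brownian motion, leaving only the drift from the density.
The results below show that the polar-retraction simulation is consistent with the other methods.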
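For $d = 1$ with the round metric, the polar retraction is normalization, $\fR(x, v) = (x+v)/\lVert x+v\rVert = x + v - \frac{1}{2}\lVert v\rVert^2x + O(\lVert v\rVert^3)$. Averaging $\lVert\Pi(x)\xi\rVert^2 = n-1$ over the noise shows that the retraction itself supplies the Brownian Ito drift $-\frac{n-1}{2}x$. The NumPy sketch below is not the jax_rb simulator: it takes plain Euler steps on $S^2$ that add only the density drift $\frac{1}{2}\kappa\Pi(x)M$ to the projected noise, then normalizes. If the cancellation holds, the long-run mean of $M^Tx$ should match the von Mises-Fisher value $\coth\kappa - 1/\kappa$ on $S^2$. The function name and step parameters are illustrative.

```python
import numpy as np


def sphere_langevin_polar(rng, m, kappa, n_paths, t_final, n_steps):
    """Langevin for V(x) = exp(kappa m^T x) on S^{n-1} using the polar retraction.

    Only the density drift 0.5 * kappa * Pi(x) m is added explicitly; the
    second-order term of the normalization supplies the Brownian Ito drift.
    """
    n = m.shape[0]
    h = t_final / n_steps
    x = rng.standard_normal((n_paths, n))
    x /= np.linalg.norm(x, axis=1, keepdims=True)  # uniform start on the sphere
    for _ in range(n_steps):
        step = np.sqrt(h) * rng.standard_normal((n_paths, n)) + 0.5 * h * kappa * m
        step -= np.sum(x * step, axis=1, keepdims=True) * x  # project onto T_x S^{n-1}
        x = x + step
        x /= np.linalg.norm(x, axis=1, keepdims=True)        # polar retraction for d = 1
    return x


rng = np.random.default_rng(6)
kappa = 1.0
m = np.array([0.0, 0.0, 1.0])
xs = sphere_langevin_polar(rng, m, kappa, n_paths=4000, t_final=5.0, n_steps=500)
exact = 1.0 / np.tanh(kappa) - 1.0 / kappa
print(np.mean(xs @ m), exact)  # simulated vs exact E[M^T x] on S^2
```

Adding the Ito drift $-\frac{n-1}{2}x$ explicitly on top of the normalization would count it twice; that is the role of `drift_adjust` cancelling `ito_drift` in the cell below.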

In [6]:
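def drift_adjust_verify(mnf, x, sigma, wiener_dim):
    """return the adjustment :math:`\\mu_{adj}`
    so that :math:`\\mu + \\mu_{adj} = \\mu_{\\mathfrak{r}}`,
    computed as :math:`\\frac{1}{2}x\\sum_{ij} v_{ij}^Tv_{ij}` with
    :math:`v_{ij}` the projected diffusion direction for the basis matrix :math:`E_{ij}`.
    """
    def sqt(a):
        return a.T@a

    # Each row of the identity, reshaped to x.shape, is one basis matrix E_ij.
    return -0.5*x@jnp.sum(vmap(lambda seq:
                               -sqt(mnf.proj(x, sigma(x, seq.reshape(x.shape)))))(jnp.eye(wiener_dim)),
                          axis=0)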


def test_polar_retract_adjust():
    n = 7
    d = 3
    alp = jnp.array([1, .6])
    key = random.PRNGKey(0)
    stf = stm.RealStiefelAlpha((n, d), alp)
    print("Doing Stiefel Polar retract for Bingham (n, d)=%s alpha=%s" % (str(stf.shape), str(stf.alpha)))
    @partial(jit, static_argnums=(0,))
    def log_v(_, x):
        return jnp.trace(x.T@A@x)

    @partial(jit, static_argnums=(0,))
    def grad_log_v(mnf, x):
        return mnf.proj(x, mnf.inv_g_metric(x, 2*A@x))

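    x, key = stf.rand_point(key)

    prtr = stiefel_polar_retraction(stf)

    # mu1: adjustment computed numerically from the second-order term of the retraction;
    # mu2: the closed form in stiefel_polar_retraction.drift_adjust.
    mu1 = drift_adjust_verify(stf, x, stf.sigma, n*d)
    mu2 = prtr.drift_adjust(x)
    # print("compare drift adjust %s" % str(mu2 - mu1))
    # mu1 should cancel the Ito drift, so the printed matrix should be ~0.
    print("compare drift adjust %s" % str(mu1 + stf.ito_drift(x)))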

    A, key = gen_sym_traceless(key, n)

    x_0, key = stf.rand_point(key)
    pay_offs = [None, lambda x: jnp.sum(jnp.abs(x))]

    key, sk = random.split(key)
    t_final = 5.
    n_path = 10000
    n_div = 500
    d_coeff = .5

    wiener_dim = stf.shape[0]*stf.shape[1]

    test_stiefel_langevin_bingham(key, stf, A, pay_offs[1])

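    @jax.jit
    def polar_adj(x, unit_move, scale):
        # One polar-retraction step. The retraction's second-order term supplies the
        # Ito drift (prtr.drift_adjust(x) and stf.ito_drift(x) cancel), so only the
        # density drift 0.5*grad_log_v is added, outside the diffusion coefficient.
        noise = stf.proj(x, stf.sigma(x, unit_move.reshape(x.shape)*scale**.5))
        return prtr.retract(x, noise + scale*0.5*grad_log_v(stf, x))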
    ret_rtr = sim.simulate(x_0,
                           polar_adj,
                           pay_offs[0],
                           pay_offs[1],
                           [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

    print("Polar adjust %.3f" % jnp.nanmean(ret_rtr[0]))


test_polar_retract_adjust()


Doing Stiefel Polar retract for Bingham (n, d)=(7, 3) alpha=[1.  0.6]
compare drift adjust [[ 4.44089210e-16 -3.33066907e-16 -5.55111512e-17]
 [ 0.00000000e+00  1.66533454e-16  2.22044605e-16]
 [-5.55111512e-17  4.44089210e-16  2.22044605e-16]
 [ 1.11022302e-16 -3.33066907e-16  0.00000000e+00]
 [ 0.00000000e+00 -6.66133815e-16 -2.22044605e-16]
 [ 0.00000000e+00  5.55111512e-17 -1.11022302e-16]
 [-2.22044605e-16 -1.11022302e-15  1.11022302e-16]]
Doing Bingham (n, d)=(7, 3) alpha=[1.  0.6]
ito langevin 6.562
stratonovich langevin 6.564
geodesic langevin 6.563
stiefel uniform sampling with density 6.564
Polar adjust 6.557
