
$\newcommand{\R}{\mathbb{R}}$
# Solving SDEs on Stiefel manifolds using the polar retraction

In this note, we illustrate the result of [NS24] for an SDE on a Stiefel manifold. The idea is feasibility enhancement: we simulate the SDE with a method on the ambient space, then enhance each step so that the next iterate lies on the manifold. For a Stiefel manifold, this can be done with different retractions. We illustrate the result with the polar decomposition; the QR decomposition or the Cayley transform could also be used.
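The feasibility-enhancement step can be sketched in plain NumPy. This is our own illustration, not jax_rb code: take an ambient step $x + v$, then map the result back to the manifold with a retraction. The hypothetical helpers `polar_retract` and `qr_retract` implement the polar and QR retractions (the column signs of the QR factor are fixed so that the diagonal of $R$ is positive); the Cayley transform is omitted.

```python
import numpy as np

def polar_retract(x, v):
    """Polar retraction: the orthonormal polar factor of x + v (thin SVD)."""
    u, _, vt = np.linalg.svd(x + v, full_matrices=False)
    return u @ vt

def qr_retract(x, v):
    """QR retraction: the Q factor of x + v, columns signed so that diag(R) > 0."""
    q, r = np.linalg.qr(x + v)       # reduced QR: q is n x p, r is p x p
    return q * np.sign(np.diag(r))   # flip column signs to make the factor unique

rng = np.random.default_rng(0)
n, p = 5, 3
x, _ = np.linalg.qr(rng.standard_normal((n, p)))   # a point on the Stiefel manifold
w = rng.standard_normal((n, p))
v = w - x @ (0.5 * (x.T @ w + w.T @ x))            # tangent at x: x^T v is skew

for retract in (polar_retract, qr_retract):
    y = retract(x, 0.1 * v)                        # ambient step, then retract
    print(retract.__name__, np.linalg.norm(y.T @ y - np.eye(p)))
```

Both results satisfy $y^Ty = I$ up to rounding. For a small step $tv$ the two retractions differ only at order $t^2$, since both equal $x + tv + O(t^2)$.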

We demonstrate that the invariant condition of [MPS16] is satisfied for Riemannian Brownian motions on Stiefel manifolds.

We use [jax_rb](https://github.com/dnguyend/jax-rb), which has Riemannian Brownian motions on Stiefel manifolds built in, but we make a few constructions explicit for illustration.

[MPS16] Goran Marjanovic, Marc J. Piggott, and Victor Solo. Numerical methods for stochastic differential equations in the Stiefel manifold made simple. In 2016 IEEE 55th Conference on Decision and Control (CDC), pages 2853–2860, 2016.

[NS24] D. Nguyen and S. Sommer. Second-order differential operators, stochastic differential equations and Brownian motions on embedded manifolds. arXiv:2406.02879, 2024.

Install the required libraries

In [1]:
pip install git+https://github.com/dnguyend/jax-rb

Collecting git+https://github.com/dnguyend/jax-rb
  Cloning https://github.com/dnguyend/jax-rb to /tmp/pip-req-build-5uo2jm_o
  Running command git clone --filter=blob:none --quiet https://github.com/dnguyend/jax-rb /tmp/pip-req-build-5uo2jm_o
  Resolved https://github.com/dnguyend/jax-rb to commit 20efd03c04d80b3438f32dcbf48cd917036675b4
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Installing backend dependencies ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: jax_rb
  Building wheel for jax_rb (pyproject.toml) ... done
  Created wheel for jax_rb: filename=jax_rb-0.1.dev50+g20efd03-py3-none-any.whl size=33136 sha256=e492ec1117fd5f7aea647df31745c1641e5fd6a88002be634dc5ea0095d04df4
  Stored in directory: /tmp/pip-ephem-wheel-cache-u6juldpw/wheels/0f/76/88/65e675f8bcca47be98c588d9a787a4c1c9b0a5044517ba6490
Successfully built jax_rb

In [2]:
import jax
import jax.numpy as jnp

import jax.numpy.linalg as jla

from jax import random
from jax_rb.manifolds import RealStiefelAlpha
from jax_rb.utils.utils import (sym, asym)
import jax_rb.simulation.simulator as sim
import jax_rb.simulation.global_manifold_integrator as gmi


jax.config.update("jax_enable_x64", True)

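* Define a Stiefel manifold and run some sanity checks: for a tangent vector $v$ at $x$, `sym(x.T@v)` should vanish, and `x.T@x` should be the identity.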

In [3]:
key = random.PRNGKey(0)

n = 5
p = 3

al = jnp.array([1, .7])
stf = RealStiefelAlpha((n, p), al)
x, key = stf.rand_point(key)
v, key = stf.rand_vec(key, x)
print(sym(x.T@v))  # ~0: v is tangent at x
print(x.T@x)       # ~identity: x is on the Stiefel manifold


[[-2.50590117e-17  5.55111512e-17  8.32667268e-17]
 [ 5.55111512e-17  1.56473573e-17 -1.94289029e-16]
 [ 8.32667268e-17 -1.94289029e-16  4.09287851e-16]]
[[1.00000000e+00 5.20502955e-18 8.33607223e-17]
 [5.20502955e-18 1.00000000e+00 6.92237803e-17]
 [8.33607223e-17 6.92237803e-17 1.00000000e+00]]


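* Define the polar retraction and test its second-order Taylor expansion $\mathrm{polar}(x + tv) = x + tv - \frac{t^2}{2}xv^Tv + O(t^3)$. With $t = 10^{-3}$, the residual below is of order $t^3 = 10^{-9}$.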

In [4]:
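def polar(q):
  # orthonormal polar factor of q, from the thin SVD q = u s vt
  u, _, vt = jla.svd(q, full_matrices=False)
  return u@vt

t = 1e-3
# residual after the second-order Taylor polynomial; should be O(t**3)
polar(x + t*v) - x - t*v + t**2/2*x@v.T@v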

Array([[-3.76551720e-10, -9.67271525e-10,  2.06162951e-09],
       [ 4.15987350e-10, -4.85275453e-10,  1.75741327e-09],
       [ 2.47405992e-09, -1.33228770e-09,  4.21068089e-09],
       [-2.68620603e-09, -1.81381986e-10,  7.56495339e-10],
       [ 2.87122156e-09,  3.51414721e-10, -1.12513947e-09]],      dtype=float64)
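Why the residual is of order $t^3$: since $x^Tv$ is skew-symmetric for a tangent vector $v$, $(x+tv)^T(x+tv) = I + t^2v^Tv$, so

$$\mathrm{polar}(x+tv) = (x+tv)(I+t^2v^Tv)^{-1/2} = x + tv - \tfrac{t^2}{2}xv^Tv - \tfrac{t^3}{2}vv^Tv + O(t^4).$$

The NumPy sketch below (an illustration with its own random point, not the notebook's `x`) checks this by halving $t$: the residual should shrink by a factor close to $2^3 = 8$.

```python
import numpy as np

def polar(q):
    # orthonormal polar factor of q from the thin SVD (same construction as above)
    u, _, vt = np.linalg.svd(q, full_matrices=False)
    return u @ vt

rng = np.random.default_rng(1)
n, p = 5, 3
x, _ = np.linalg.qr(rng.standard_normal((n, p)))
w = rng.standard_normal((n, p))
v = w - x @ (0.5 * (x.T @ w + w.T @ x))   # tangent vector at x
v /= np.linalg.norm(v)                    # unit Frobenius norm

def residual(t):
    # polar(x + t v) minus its second-order Taylor polynomial
    return polar(x + t * v) - x - t * v + t**2 / 2 * x @ v.T @ v

t = 1e-2
ratio = np.linalg.norm(residual(t)) / np.linalg.norm(residual(t / 2))
print(ratio)   # close to 8: the remainder is dominated by -(t^3/2) v v^T v
```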

$\newcommand{\fR}{\mathfrak{r}}$
Compute the adjustment $\frac{1}{2}\sum_j\fR^{(2)}(x, \sigma w_j, \sigma w_j)$, where $w_j$ runs over a basis of the ambient space $\R^{n\times p}$ and $\fR^{(2)}(x, v, v) = -xv^Tv$ is the second-order term of the polar retraction. We show that it is exactly the Itô drift. This is special to Stiefel manifolds with the polar retraction: in general, the difference between the two is tangent at $x$, but not necessarily zero.

In [5]:
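# explicit noise map sigma(x): keep the part of dw orthogonal to the columns of x,
# and scale the x asym(x^T dw) part by 1/sqrt(al[1])
def sig(x, dw):
  return dw - x@(x.T@dw) + 1/jnp.sqrt(al[1])*x@asym(x.T@dw)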


zr = jnp.zeros(n*p)

def calc_adj(x):
  # (1/2) sum_j r2(x, sig w_j, sig w_j), with r2(x, v, v) = -x v^T v
  s = jnp.zeros((n, p))
  for i in range(n*p):
      sigw = sig(x, zr.at[i].set(1.).reshape(stf.shape))
      s += -0.5*x@sigw.T@sigw
  return s

def calc_ito(x):
  # -(1/2) sum_j gamma(x, sig w_j, sig w_j), using the manifold's stf.gamma
  s = jnp.zeros((n, p))
  for i in range(n*p):
      sigw = sig(x, zr.at[i].set(1.).reshape(stf.shape))
      s += -0.5*stf.gamma(x, sigw, sigw)
  return s

print(stf.ito_drift(x))

print(calc_adj(x))
print(calc_ito(x))

[[ 1.38410769  0.0091622   0.63139119]
 [ 0.03131142  1.56713737 -0.55783334]
 [-0.58695641 -0.03351799  0.09550478]
 [ 0.14888281 -0.69348529 -1.28610466]
 [ 0.80954129  0.02695806 -0.75216773]]
[[ 1.38410769  0.0091622   0.63139119]
 [ 0.03131142  1.56713737 -0.55783334]
 [-0.58695641 -0.03351799  0.09550478]
 [ 0.14888281 -0.69348529 -1.28610466]
 [ 0.80954129  0.02695806 -0.75216773]]
[[ 1.38410769  0.0091622   0.63139119]
 [ 0.03131142  1.56713737 -0.55783334]
 [-0.58695641 -0.03351799  0.09550478]
 [ 0.14888281 -0.69348529 -1.28610466]
 [ 0.80954129  0.02695806 -0.75216773]]
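For this particular `sig`, the sum can also be done by hand. This is our own calculation, assuming `asym(a)` $= (a - a^T)/2$ as in jax_rb, and writing $\alpha$ for `al[1]`. The cross terms vanish because $x^T(I - xx^T) = 0$, and summing over the standard basis gives $\sum_j \sigma(w_j)^T\sigma(w_j) = \bigl(n - p + \frac{p-1}{2\alpha}\bigr)I_p$, hence

$$\frac{1}{2}\sum_j \fR^{(2)}(x, \sigma w_j, \sigma w_j) = -\frac{1}{2}\Bigl(n - p + \frac{p-1}{2\alpha}\Bigr)x.$$

With $n = 5$, $p = 3$ the factor is about $-1.714$ for $\alpha = 0.7$ and $-6$ for $\alpha = 0.1$. Their ratio $3.5$ matches the printed drifts (compare the first entry $1.384$ here with $4.844$ from `stf_small.ito_drift(x)` further down). The NumPy sketch below, with hypothetical helper names, compares the basis loop with the closed form.

```python
import numpy as np

def asym(a):
    # skew-symmetric part; assumed to match the convention of jax_rb's asym
    return 0.5 * (a - a.T)

def sig(x, dw, alpha):
    # same formula as sig(x, dw) above, with al[1] passed in as alpha
    return dw - x @ (x.T @ dw) + 1 / np.sqrt(alpha) * x @ asym(x.T @ dw)

def adjustment_loop(x, alpha):
    # (1/2) sum_j r2(x, sig w_j, sig w_j) with r2(x, v, v) = -x v^T v,
    # where w_j runs over the standard basis of R^{n x p}
    n, p = x.shape
    s = np.zeros((n, p))
    for j in range(n * p):
        w = np.zeros(n * p)
        w[j] = 1.0
        sw = sig(x, w.reshape(n, p), alpha)
        s += -0.5 * x @ sw.T @ sw
    return s

def adjustment_closed_form(x, alpha):
    # -(1/2) (n - p + (p - 1) / (2 alpha)) x
    n, p = x.shape
    return -0.5 * (n - p + (p - 1) / (2 * alpha)) * x

rng = np.random.default_rng(2)
x, _ = np.linalg.qr(rng.standard_normal((5, 3)))
for alpha in (0.7, 0.1, 20.0):
    gap = np.max(np.abs(adjustment_loop(x, alpha) - adjustment_closed_form(x, alpha)))
    print(alpha, gap)   # rounding-level differences only
```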


We now consider two values of $\alpha$ (the second entry of `al`): a small value $\alpha = 0.1$ and a large value $\alpha = 20$. We demonstrate that the stochastic projection method (`ret_ito`), the geodesic approximation (`ret_geo`, using Theorem 4) and the retraction method with adjusted drift (`ret_rtr`, with $\mu_{\fR} = \mu - \mu_{\alpha} = 0$, as just shown) are consistent with each other for each value of $\alpha$. The printed numbers are averages over the simulated paths of the payoff $|x_{00}|$ defined in `pay_offs`.

In [9]:
pay_offs = [None,
            lambda x: jnp.sum(jnp.abs(x[0, 0]))]

al_small = jnp.array([1, .1])
stf_small = RealStiefelAlpha((n, p), al_small)

key, sk = random.split(key)
t_final = .5
n_path = 1000
n_div = 1000
d_coeff = .5
wiener_dim = n*p
x_0 = jnp.zeros((n, p)).at[:p, :].set(jnp.eye(p))

ret_ito_small = sim.simulate(x_0,
                        lambda x, unit_move, scale: gmi.rbrownian_ito_move(
                            stf_small, x, unit_move, scale),
                        pay_offs[0],
                        pay_offs[1],
                        [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
print(jnp.mean(ret_ito_small[0]))

ret_geo_small = sim.simulate(x_0,
                        lambda x, unit_move, scale: gmi.geodesic_move(
                            stf_small, x, unit_move, scale),
                        pay_offs[0],
                        pay_offs[1],
                        [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
print(jnp.mean(ret_geo_small[0]))

0.3831278524121285
0.3845773931666296
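To judge whether two Monte Carlo estimates such as the ones above agree, it helps to attach a standard error. A minimal sketch, assuming that `ret_ito_small[0]` (and its siblings) holds one payoff value per path, which is what `jnp.mean(ret_ito_small[0])` suggests but which we have not checked against jax_rb; `mc_summary` and `paired_diff_summary` are hypothetical helpers.

```python
import numpy as np

def mc_summary(samples, z=1.96):
    """Sample mean and z * standard error (about a 95% interval for z = 1.96)."""
    samples = np.asarray(samples, dtype=float)
    mean = samples.mean()
    stderr = samples.std(ddof=1) / np.sqrt(samples.size)
    return mean, z * stderr

def paired_diff_summary(a, b, z=1.96):
    """Same summary for per-path differences a - b (useful if a and b share noise)."""
    return mc_summary(np.asarray(a, dtype=float) - np.asarray(b, dtype=float), z)

# Usage on synthetic payoffs in [0, 1] (not the notebook's results):
rng = np.random.default_rng(3)
fake_payoffs = rng.uniform(size=1000)
print(mc_summary(fake_payoffs))
```

In the notebook one could call, e.g., `mc_summary(np.asarray(ret_ito_small[0]))`. If two runs share their Brownian increments (both calls above reuse the key `sk`, which suggests they may), `paired_diff_summary` on the per-path payoffs usually gives a tighter interval for their difference than two separate intervals.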


In [14]:
class polar_retraction():
    """The polar retraction on a Stiefel manifold.
    """
    def __init__(self, mnf):
        self.mnf = mnf

    def retract(self, x, v):
        """Map :math:`x+v` to the Stiefel manifold by taking its polar factor.
        """
        return polar(x+v)

    def hess(self, x, v):
        r"""Second-order term :math:`\mathfrak{r}^{(2)}(x, v, v) = -xv^Tv` of the retraction.
        """
        return -x@v.T@v

    def drift_adjust(self, sigma, x, t, driver_dim):
        r"""Return the adjustment :math:`\mu_{adj}`
        so that :math:`\mu + \mu_{adj} = \mu_{\mathfrak{r}}`.
        """
        return -0.5*jnp.sum(jax.vmap(lambda seq:
                                 self.hess(x, sigma(x, t, seq.reshape(self.mnf.shape))))(jnp.eye(driver_dim)),
                            axis=0)


po = polar_retraction(stf_small)
print(po.drift_adjust(lambda x, _, dw: stf_small.proj(x, stf_small.sigma(x, dw.reshape(stf_small.shape))),
                      x, None, n*p))
print(stf_small.ito_drift(x))

[[-4.8443769  -0.03206771 -2.20986916]
 [-0.10958997 -5.48498079  1.9524167 ]
 [ 2.05434744  0.11731296 -0.33426672]
 [-0.52108982  2.4271985   4.5013663 ]
 [-2.83339452 -0.09435322  2.63258704]]
[[ 4.8443769   0.03206771  2.20986916]
 [ 0.10958997  5.48498079 -1.9524167 ]
 [-2.05434744 -0.11731296  0.33426672]
 [ 0.52108982 -2.4271985  -4.5013663 ]
 [ 2.83339452  0.09435322 -2.63258704]]


We could use `retractive_move` in `simulation.retractive_integrator`, but summing $\fR^{(2)}$ over a basis of $\R^{n\times p}$ is slow. The output above shows that the adjustment is exactly the negative of the Itô drift, so $\mu_{\fR} = 0$; we therefore drop the drift and retract only the stochastic component. The result is consistent with the Itô and geodesic (second-order retraction) simulations above.

In [17]:
ret_rtr_small = sim.simulate(x_0,
                        lambda x, unit_move, scale: po.retract(
                            x, stf_small.proj(x, stf_small.sigma(x, (unit_move*jnp.sqrt(scale)).reshape(stf_small.shape)))),
                        pay_offs[0],
                        pay_offs[1],
                        [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

jnp.mean(ret_rtr_small[0])

Array(0.38316787, dtype=float64)
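The retract-only scheme in the cell above can be sketched in NumPy without jax_rb: at each step, draw an ambient Gaussian increment, map it through the notebook's `sig`, add it to $x$ and take the polar factor, with no drift term. This is an illustration under our own assumptions: the increments have i.i.d. $N(0, h)$ entries with $h$ = `t_final / n_div`, `sig` takes $\alpha$ as an explicit argument, and the rescaling jax_rb applies through `d_coeff` and `scale` is not reproduced, so the printed average is not expected to match the numbers above. The name `simulate_polar` is hypothetical.

```python
import numpy as np

def asym(a):
    # skew-symmetric part, assumed convention (a - a^T) / 2
    return 0.5 * (a - a.T)

def polar(q):
    # orthonormal polar factor of q from the thin SVD
    u, _, vt = np.linalg.svd(q, full_matrices=False)
    return u @ vt

def sig(x, dw, alpha):
    # the notebook's sig(x, dw), with al[1] passed in as alpha
    return dw - x @ (x.T @ dw) + 1 / np.sqrt(alpha) * x @ asym(x.T @ dw)

def simulate_polar(x0, alpha, t_final, n_div, n_path, seed=0):
    """Retract-only scheme x <- polar(x + sig(x, dW)), with no drift term.

    Assumption: dW has i.i.d. N(0, h) entries, h = t_final / n_div.
    """
    rng = np.random.default_rng(seed)
    h = t_final / n_div
    finals = []
    for _ in range(n_path):
        x = x0.copy()
        for _ in range(n_div):
            dw = np.sqrt(h) * rng.standard_normal(x.shape)
            x = polar(x + sig(x, dw, alpha))
        finals.append(x)
    return np.stack(finals)   # shape (n_path, n, p)

n, p = 5, 3
x0 = np.zeros((n, p))
x0[:p, :] = np.eye(p)
paths = simulate_polar(x0, alpha=0.1, t_final=0.5, n_div=100, n_path=50)
print(np.mean(np.abs(paths[:, 0, 0])))   # average payoff |x[0, 0]| over the paths
```

Because the polar factor is recomputed at every step, rounding errors in $x^Tx = I$ do not accumulate along a path.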

* Now for the large value $\alpha = 20$. Again, the three methods are consistent.

In [19]:
al_large = jnp.array([1, 20.])
stf_large = RealStiefelAlpha((n, p), al_large)

ret_ito_large = sim.simulate(x_0,
                        lambda x, unit_move, scale: gmi.rbrownian_ito_move(
                            stf_large, x, unit_move, scale),
                        pay_offs[0],
                        pay_offs[1],
                        [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
print(jnp.mean(ret_ito_large[0]))


ret_geo_large = sim.simulate(x_0,
                        lambda x, unit_move, scale: gmi.geodesic_move(
                            stf_large, x, unit_move, scale),
                        pay_offs[0],
                        pay_offs[1],
                        [sk, t_final, n_path, n_div, d_coeff, wiener_dim])
print(jnp.mean(ret_geo_large[0]))

0.6150533614932605
0.6112091477472982


In [21]:
ret_rtr_large = sim.simulate(x_0,
                        lambda x, unit_move, scale: po.retract(
                            x, stf_large.proj(x, stf_large.sigma(x, (unit_move*jnp.sqrt(scale)).reshape(stf_large.shape)))),
                        pay_offs[0],
                        pay_offs[1],
                        [sk, t_final, n_path, n_div, d_coeff, wiener_dim])

print(jnp.mean(ret_rtr_large[0]))

0.6153144701990799
