<a href="https://colab.research.google.com/github/dnguyend/jax-rb/blob/main/tests/notebooks/TestRetractiveIntegrator.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TESTING THE RETRACTIVE INTEGRATOR
  * Since the package is not yet on PyPI, install it with the dialog box below. Alternatively, download the repository in a terminal and install it locally.
  

  We show the steps in detail here; for the other groups, we run a single Python script from the `tests` folder (e.g. `python test_so.py`).


In [1]:
#@title Imports & Utils
import ipywidgets as widgets
from IPython.display import display
import subprocess


class credentials_input():
    """Widget pair that pip-installs a (possibly private) git repository.

    Include this snippet in Colab to install from a private repository.
    A username box is displayed; submitting it reveals a password box.
    Submitting the password runs
    ``pip install git+https://<user>:<pwd>@<repo>`` with the kernel's own
    interpreter, prints pip's output (with the password masked) and clears
    both fields.

    Parameters
    ----------
    repo_name : str
        Repository location without the scheme, e.g.
        ``'github.com/owner/repo.git'``.
    """
    def __init__(self, repo_name):
        self.repo_name = repo_name
        self.username = widgets.Text(description='Username', value='')
        self.pwd = widgets.Password(
            description='Password', placeholder='password here')
        # Track whether the password box is already on screen, so pressing
        # Enter in the username box again does not display a second copy.
        self._pwd_shown = False

        self.username.on_submit(self.handle_submit_username)
        self.pwd.on_submit(self.handle_submit_pwd)
        display("Use %40 for @ in email address:")
        display(self.username)

    @staticmethod
    def _install_url(username, password, repo_name):
        """Build the ``git+https`` pip URL with percent-encoded credentials.

        The username is unquoted first, so input that already follows the
        on-screen hint (``%40`` for ``@``) is not double-encoded. Both
        fields are then fully quoted, so characters such as ``@``, ``:``,
        ``/`` or spaces cannot corrupt the URL.
        """
        from urllib.parse import quote, unquote
        user = quote(unquote(username), safe='')
        secret = quote(password, safe='')
        return f'git+https://{user}:{secret}@{repo_name}'

    @staticmethod
    def _redact(message, password):
        """Replace ``password`` (raw and percent-encoded) in ``message`` with ``****``."""
        from urllib.parse import quote
        if not password:
            return message
        for token in (password, quote(password, safe='')):
            message = message.replace(token, '****')
        return message

    def handle_submit_username(self, text):
        """Reveal the password box the first time the username is submitted."""
        if not getattr(self, '_pwd_shown', False):
            display(self.pwd)
            self._pwd_shown = True

    def handle_submit_pwd(self, text):
        """Install the repository with the entered credentials.

        The command is built as an argument list rather than with
        ``str.split``, so credentials containing spaces stay intact.
        ``sys.executable -m pip`` targets the running kernel's environment
        instead of whichever ``pip`` is first on PATH. Both fields are
        cleared afterwards, even if the install fails to launch.
        """
        import subprocess
        import sys
        password = self.pwd.value
        try:
            url = self._install_url(self.username.value, password,
                                    self.repo_name)
            result = subprocess.run(
                [sys.executable, '-m', 'pip', 'install', url],
                capture_output=True, text=True, check=False)
            print(self._redact(result.stdout, password),
                  self._redact(result.stderr, password))
        finally:
            self.username.value, self.pwd.value = '', ''

# Display the username box. Submitting it reveals the password box, and
# submitting the password pip-installs this repository.
credentials_input('github.com/dnguyend/jax-rb.git')




'Use %40 for @ in email address:'

Text(value='', description='Username')

<__main__.credentials_input at 0x7dc0702dfd90>

Password(description='Password', placeholder='password here')



## The manifold
The manifold is defined by one equation of the form $C(x) = \sum_i d_ix_i^p=1$ with the embedded metric.
* Brownian motion can therefore be simulated with the retractive integrator using the nearest-point retraction.
* We show that it can also be simulated with the integrator using the rescaling retraction
$$\mathfrak{r}(x, v) = C(x+v)^{-1/p}(x+v)
$$

* Basic functionality of the class DiagHypersurface is tested in the test folder (tests/test_diag_hypersurface.py). We test the integrator here.
* For $x=(x_i)_{i=1}^n, v=(v_i)_{i=1}^n$ we can verify that
$$\mathfrak{r}(x, tv) = x + tv + \frac{(1-p)t^2(\sum_i d_ix_i^{p-2}v_i^2)}{2}x+O(t^3)
$$
The Ito adjustment, implemented in `rescale_retraction.drift_adjust` below, equals
$$-\frac{(1-p)(\sum_{ij} d_ix_i^{p-2}(\sigma e_j)_i^2)}{2}x
$$

In [2]:
import jax
# Enable float64 before anything else touches JAX, including the jax_rb
# imports below (in case those modules create arrays at import time).
# This ensures every array in the notebook is built in double precision.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
from jax import random, jvp, vmap

from jax_rb.manifolds.diag_hypersurface import DiagHypersurface
from jax_rb.utils.utils import grand
import jax_rb.simulation.simulator as sim
import jax_rb.simulation.global_manifold_integrator as gmi
import jax_rb.simulation.retractive_integrator as rmi

## The rescaling retraction

In [3]:
class rescale_retraction():
    """Rescaling retraction on a diagonal constrained hypersurface.

    A step ``x + v`` is mapped back onto the hypersurface by dividing it by
    ``C(x + v)**(1/p)``. ``mnf`` must provide ``cfunc`` (the constraint
    function ``C``), the exponent ``p`` and the coefficient vector ``dvec``.
    """
    def __init__(self, mnf):
        self.mnf = mnf

    def retract(self, x, v):
        """Rescale :math:`x+v` onto the hypersurface: :math:`(x+v)/C(x+v)^{1/p}`."""
        shifted = x + v
        scale = self.mnf.cfunc(shifted)**(1/self.mnf.p)
        return shifted/scale

    def hess(self, x, v):
        """Second-order term of the rescaling retraction.

        Returns :math:`(1-p)\\, x \\sum_i d_i x_i^{p-2} v_i^2`.
        """
        expo = self.mnf.p
        weights = self.mnf.dvec*x**(expo - 2)
        return (1 - expo)*x*jnp.sum(weights*v*v)

    def drift_adjust(self, sigma, x, t, driver_dim):
        """Drift correction :math:`\\mu_{adj}` for this retraction.

        It is chosen so that :math:`\\mu + \\mu_{adj} = \\mu_{\\mathfrak{r}}`.
        Concretely, it is minus one half of ``hess(x, sigma(x, t, e_j))``
        summed over the ``driver_dim`` standard basis vectors ``e_j``.
        """
        def hess_along(e_j):
            return self.hess(x, sigma(x, t, e_j))

        per_direction = vmap(hess_along)(jnp.eye(driver_dim))
        return -0.5*per_direction.sum(axis=0)


## Test the retraction has the required properties

In [4]:
n = 5
p = 2
key = random.PRNGKey(0)
dvec, key = grand(key, (n,))
# With d_n = 1, the basis vector e_n satisfies C(e_n) = 1.
dvec = dvec.at[-1].set(1.)

mnf = DiagHypersurface(dvec, p)
x, key = mnf.rand_point(key)

# Nearest-point map: draw ambient points until C(q) > 0, then project
# (presumably approx_nearest needs C(q) > 0 to rescale).
while True:
    q, key = mnf.rand_ambient(key)
    if mnf.cfunc(q) > 0:
        xq = mnf.approx_nearest(q)
        break
print(f"test approx nearest C(q)={mnf.cfunc(q)}, C(xq)={mnf.cfunc(xq)}")
assert jnp.allclose(mnf.cfunc(xq), 1.)

# Rescaling retraction of a small tangent step stays on the hypersurface.
xi, key = mnf.rand_vec(key, x)
rtr = rescale_retraction(mnf)
v = .01*xi
x1 = rtr.retract(x, v)
print(f"test retract C(rtr.retract(x, v))={mnf.cfunc(x1)}")
assert jnp.allclose(mnf.cfunc(x1), 1.)


def rt(t):
    """Retraction curve t -> r(x, t v)."""
    return rtr.retract(x, t*v)


def dr(t):
    """Closed-form derivative d/dt r(x, t v) of the rescaling retraction."""
    p = rtr.mnf.p
    cft = rtr.mnf.cfunc(x+t*v)
    return -1/p*cft**(-1-1/p)*jnp.sum(rtr.mnf.grad_c(x+t*v)*v)*(x+t*v) \
        + cft**(-1/p)*v


# Compare autodiff against the closed forms: first derivative at t=0.1,
# second derivative at t=0 against rtr.hess.
print("test deriv and hess of retract")
deriv_ad = jvp(rt, (.1,), (1.,))[1]
deriv_cf = dr(.1)
hess_ad = jvp(dr, (0.,), (1.,))[1]
hess_cf = rtr.hess(x, v)
print(deriv_ad)
print(deriv_cf)
print(hess_ad)
print(hess_cf)
assert jnp.allclose(deriv_ad, deriv_cf)
assert jnp.allclose(hess_ad, hess_cf)

gsum = jnp.zeros(n)
hsum = jnp.zeros(n)
for i in range(n):
    # sigma(x) applied to the i-th basis vector, then mnf.proj
    # (presumably the tangent projection at x).
    nsg = mnf.proj(x, mnf.sigma(x, jnp.zeros(n).at[i].set(1.)))
    hsum += -rtr.hess(x, nsg)
    gsum += - mnf.gamma(x, nsg, nsg)

gamma_gap = 0.5*gsum - mnf.ito_drift(x)
tangent_gap = jnp.sum(mnf.grad_c(x)*(0.5*hsum+mnf.ito_drift(x)))
print(f"test sum -gamma - ito drift={gamma_gap}")
print(f"test adjusted ito is tangent={tangent_gap}")
assert jnp.allclose(gamma_gap, 0., atol=1e-10)
assert jnp.allclose(tangent_gap, 0., atol=1e-10)



test apprx nearest C(q)=0.11321789016690081, C(x)=1.0000000000000002
test retract C(rtr.retract(x, v)=1.0000000000000002
test deriv and hess of retract
[ 0.00024948  0.02365371 -0.00926304  0.01827491  0.00802626]
[ 0.00024948  0.02365371 -0.00926304  0.01827491  0.00802626]
[-1.21630899e-05  2.07281857e-04 -5.59875159e-05  1.11632152e-04
  1.39241311e-04]
[-1.21630899e-05  2.07281857e-04 -5.59875159e-05  1.11632152e-04
  1.39241311e-04]
test sum -gamma - ito drift=[-1.73472348e-18 -3.46944695e-18  0.00000000e+00  0.00000000e+00
  0.00000000e+00]
test adjusted ito is tangent=-5.967448757360216e-16


# Check that the Stratonovich and Ito drifts given in the library agree with the summation in the main theorem.

In [5]:
def new_sigma(x, _, dw):
    """Diffusion for the retractive integrator: mnf.proj applied to mnf.sigma(x, dw).

    The time argument is ignored.
    """
    return mnf.proj(x, mnf.sigma(x, dw))


def mu(x, _):
    """Drift for the retractive integrator: the library's Ito drift at x.

    The time argument is ignored.
    """
    return mnf.ito_drift(x)


# The two payoff functions passed to sim.simulate, with signatures (x, t) and (x).
pay_offs = [lambda x, t: t*jnp.sum(x*jnp.arange(n)),
            lambda x: jnp.sum(x*x)]

key, sk = random.split(key)
t_final = 1.
n_path = 1000
n_div = 1000
d_coeff = .5
wiener_dim = n
# e_n lies on the hypersurface because d_n was set to 1.
x_0 = jnp.zeros(n).at[-1].set(1)


def run_sim(move):
    """Run sim.simulate from x_0 with the given step function ``move``.

    Every scheme reuses the same subkey ``sk``, presumably so that the
    estimates below differ only by the integrator, not by the random draws.
    """
    return sim.simulate(x_0, move, pay_offs[0], pay_offs[1],
                        [sk, t_final, n_path, n_div, d_coeff, wiener_dim])


ret_geo = run_sim(lambda x, unit_move, scale: gmi.geodesic_move(
    mnf, x, unit_move, scale))
ret_ito = run_sim(lambda x, unit_move, scale: gmi.rbrownian_ito_move(
    mnf, x, unit_move, scale))
ret_str = run_sim(lambda x, unit_move, scale: gmi.rbrownian_stratonovich_move(
    mnf, x, unit_move, scale))
ret_rtr = run_sim(lambda x, unit_move, scale: rmi.retractive_move(
    rtr, x, None, unit_move, scale, new_sigma, mu))
ret_nrtr = run_sim(lambda x, unit_move, scale: rmi.retractive_move_normalized(
    rtr, x, None, unit_move, scale, new_sigma, mu))

estimates = {
    "geo second order": jnp.nanmean(ret_geo[0]),
    "Ito": jnp.nanmean(ret_ito[0]),
    "Stratonovich": jnp.nanmean(ret_str[0]),
    "Retractive": jnp.nanmean(ret_rtr[0]),
    "Retractive Norm.": jnp.nanmean(ret_nrtr[0]),
}
for label, value in estimates.items():
    print(f"{label:<16} = {value}")

# All schemes estimate the same expectation from n_path Monte Carlo paths.
# The observed spread is below 1%, so a 5% band catches a broken integrator
# (or an all-NaN result) without being fragile.
ref = float(estimates["geo second order"])
assert all(abs(float(val) - ref) < 0.05*abs(ref)
           for val in estimates.values()), estimates



geo second order = 8.616377076786248
Ito              = 8.617933533688669
Stratonovich     = 8.594460387538478
Retractive       = 8.628692203861034
Retractive Norm. = 8.65157317775276
