<a href="https://colab.research.google.com/github/dnguyend/par-trans/blob/main/examples/JAXFlagParallel.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Workbook to test parallel transport on Flag manifolds using JAX

You can choose either CPU or GPU in the Colab hardware settings. GPU time is limited on the free cloud tier, so you may want to run the GPU version on your own machine.
We use CPU here to avoid interruptions when the time limit is reached.

## We use the canonical metric, $\alpha = \frac{1}{2}$
The Levi-Civita connection and geodesic formulas hold for all $\alpha$, but parallel transport only works for $\alpha=\frac{1}{2}$.


In [1]:
!pip install git+https://github.com/dnguyend/par-trans

Collecting git+https://github.com/dnguyend/par-trans
  Cloning https://github.com/dnguyend/par-trans to /tmp/pip-req-build-hzk3lzn2
  Running command git clone --filter=blob:none --quiet https://github.com/dnguyend/par-trans /tmp/pip-req-build-hzk3lzn2
  Resolved https://github.com/dnguyend/par-trans to commit 7e7ac7ffa5629925f55389d240cfd0b60c94b70f
  Installing build dependencies ... done
  Getting requirements to build wheel ... done
  Preparing metadata (pyproject.toml) ... done
Building wheels for collected packages: par-trans
  Building wheel for par-trans (pyproject.toml) ... done
  Created wheel for par-trans: filename=par_trans-0.1.dev13+g7e7ac7f-py3-none-any.whl size=32087 sha256=c1ca1169979a41c5bf822b4a93923b5b280afb1e6d6257ed3965ffc98e087d4a
  Stored in directory: /tmp/pip-ephem-wheel-cache-p4q6w4rk/wheels/43/54/27/33e127e64ed29c538e4b5356c5a5801c527ecfa659afe8ec3c
Successfully built par-trans
Installing collected packages: pa

Import the libraries

In [2]:
from time import perf_counter
import timeit
import pandas as pd

import jax
import jax.numpy as jnp
from jax import jvp, random
from jax_par_trans.expv.utils import (cz, sym, asym)
from jax_par_trans.manifolds import Flag
from jax_par_trans.manifolds import Stiefel


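Run several tests:
* check that the Christoffel function gives a connection, i.e. the covariant derivative is horizontal
* check the metric-compatibility condition of the Levi-Civita connection

Here `cz` is $\max\circ \mathrm{abs}$ (the largest absolute entry), used to check whether a matrix is numerically zero.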
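For readers who want to reproduce these checks outside the package, here is a minimal sketch of what the helpers plausibly compute. The names `cz_sketch`, `sym_sketch` and `asym_sketch` are hypothetical, and we assume (without having checked the library source) that `sym` and `asym` return the symmetric and antisymmetric parts $\frac12(A \pm A^T)$.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def cz_sketch(a):
    """Largest absolute entry of a; a value near zero means a is numerically zero."""
    return jnp.max(jnp.abs(a))


def sym_sketch(a):
    """Symmetric part (a + a^T) / 2 (assumed meaning of `sym`)."""
    return 0.5 * (a + a.T)


def asym_sketch(a):
    """Antisymmetric part (a - a^T) / 2 (assumed meaning of `asym`)."""
    return 0.5 * (a - a.T)


a = jnp.arange(9.0).reshape(3, 3)
# The two parts recombine to the original matrix, so this should be exactly zero.
print(cz_sketch(sym_sketch(a) + asym_sketch(a) - a))
```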


In [3]:
jax.config.update("jax_enable_x64", True)
dvec = jnp.array([5, 2, 3])
alp = .5
flg = Flag(dvec, alp)
key = random.PRNGKey(0)

x, key = flg.rand_point(key)
v, key = flg.rand_vec(key, x)
va, key = flg.rand_vec(key, x)


dlt = 1e-6
t = .8

print("CHECK THAT THE CHRISTOFFEL FUNCTION gives a connection. Checking the covariant derivative is horizontal")

r1 = jvp(lambda t: flg.proj(x+t*v, va), (0.,), (1.,))[1] + flg.christoffel_gamma(x, v, va)
print(cz(sym(x.T@r1)))
print(cz(flg.proj_m(sym(x.T@r1)) - asym(x.T@r1)))


CHECK THAT THE CHRISTOFFEL FUNCTION gives a connection. Checking the covariant derivative is horizontal
1.0547118733938987e-15
1.1084883011491797e-15


# Check that the covariant derivative is metric compatible
For two tangent vectors `v` and `va`, the map $X: z\mapsto$ `flg.proj(z, va)` is a vector field. We compare $D_v\langle X, X\rangle$ with $2\langle X, D_vX + \Gamma(v, X)\rangle$; metric compatibility means the two agree.
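To see the same test on a manifold where every ingredient is known in closed form, here is a minimal sketch on the unit sphere $S^{n-1}\subset\mathbb{R}^n$ with the Euclidean metric. Nothing here comes from the package: `proj`, `christoffel_gamma` and `inner` are hypothetical stand-ins playing the roles of the `flg` methods, with $\Gamma(z; u, w) = z\langle u, w\rangle$ (the Levi-Civita Christoffel function of the unit sphere). The field $X(z) = P_z(Az)$ is an arbitrary choice that makes $\langle X, X\rangle$ non-constant.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def proj(z, w):
    """Tangent projection onto the unit sphere at z (valid for |z| = 1)."""
    return w - z * jnp.dot(z, w)


def christoffel_gamma(z, u, w):
    """Gamma(z; u, w) = z <u, w>: D_u W + Gamma(u, W) is the Levi-Civita derivative."""
    return z * jnp.dot(u, w)


def inner(z, a, b):
    """Euclidean metric restricted to the sphere (it does not depend on z)."""
    return jnp.dot(a, b)


key_x, key_v, key_a = jax.random.split(jax.random.PRNGKey(0), 3)
n = 5
x = jax.random.normal(key_x, (n,))
x = x / jnp.linalg.norm(x)                      # point on the sphere
v = proj(x, jax.random.normal(key_v, (n,)))     # tangent direction at x
a_mat = jax.random.normal(key_a, (n, n))


def field(z):
    """A non-constant tangent vector field X(z) = P_z(A z)."""
    return proj(z, a_mat @ z)


# Left side: D_v <X, X> by forward-mode AD.
lhs = jax.jvp(lambda z: inner(z, field(z), field(z)), (x,), (v,))[1]
# Right side: 2 <X, D_v X + Gamma(v, X)>.
dX = jax.jvp(field, (x,), (v,))[1]
rhs = 2 * inner(x, field(x), dX + christoffel_gamma(x, v, field(x)))
print(lhs, rhs)
```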

In [None]:

print("CHECK THAT THE Covariant derivative is metric compatible ")
# print((stf.inner(x+dlt*v, va, va) - stf.inner(x, va, va))/dlt)
print(jax.jvp(lambda z: flg.inner(z, flg.proj(z, va), flg.proj(z, va)), (x,), (v,))[1])
print(2*flg.inner(x, va,
                  jax.jvp(lambda x: flg.proj(x, va), (x,), (v,))[1]
                  + flg.christoffel_gamma(x, v, va)))




CHECK THAT THE Covariant derivative is metric compatible 
1.521056092249815
1.521056092249814


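Now check parallel transport. The function `flg.exp` gives the Riemannian exponential map, while `flg.dexp(x, v, t, ddexp=True)` computes $\gamma(t)$, $\dot{\gamma}(t)$ and $\ddot{\gamma}(t)$, where $\gamma(t)$ is the geodesic starting at $x$ with initial velocity $\dot{\gamma}(0) = v$. If `ddexp=False`, only $\gamma(t)$ and $\dot{\gamma}(t)$ are returned.

We verify that `dexp` indeed returns the time derivatives and that $\gamma$ satisfies the geodesic equation $\ddot{\gamma} + \Gamma(\dot{\gamma}, \dot{\gamma}) = 0$. We then check that $\Delta(t) =$ `flg.parallel_canonical(x, v, va, t)` satisfies the parallel transport equation $\dot{\Delta} + \Gamma(\dot{\gamma}, \Delta) = 0$, using both a numerical derivative and AD.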
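The same checks can be sketched on the unit sphere, where both the geodesic $\gamma(t) = \cos(t|v|)\,x + \sin(t|v|)\,v/|v|$ and the parallel transport are known in closed form: the component of the transported vector along $v/|v|$ rotates with the great circle, and the orthogonal part stays fixed. This is an illustration built on those textbook formulas, not the package's flag-manifold code; the function names are hypothetical, and second time derivatives are obtained by nesting `jax.jvp`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def christoffel_gamma(z, u, w):
    # Levi-Civita Christoffel function of the unit sphere: Gamma(z; u, w) = z <u, w>.
    return z * jnp.dot(u, w)


def geodesic(x, v, t):
    # Great circle through x with initial velocity v.
    s = jnp.linalg.norm(v)
    return jnp.cos(t * s) * x + jnp.sin(t * s) * v / s


def transport(x, v, w, t):
    # Rotate the component of w along e = v/|v|; keep the orthogonal part unchanged.
    s = jnp.linalg.norm(v)
    e = v / s
    a = jnp.dot(e, w)
    return w - a * e + a * (-jnp.sin(t * s) * x + jnp.cos(t * s) * e)


k1, k2, k3 = jax.random.split(jax.random.PRNGKey(1), 3)
n = 6
x = jax.random.normal(k1, (n,))
x = x / jnp.linalg.norm(x)
v = jax.random.normal(k2, (n,))
v = v - x * jnp.dot(x, v)            # make v tangent at x
w = jax.random.normal(k3, (n,))
w = w - x * jnp.dot(x, w)            # make w tangent at x
t = jnp.array(0.8)


def g(s):
    return geodesic(x, v, s)


def dg(s):
    # First time derivative of the geodesic.
    return jax.jvp(g, (s,), (jnp.ones_like(s),))[1]


gam, dgam = g(t), dg(t)
ddgam = jax.jvp(dg, (t,), (jnp.ones_like(t),))[1]
geo_res = jnp.max(jnp.abs(ddgam + christoffel_gamma(gam, dgam, dgam)))

# Transport equation: d/dt Delta + Gamma(gamma; gamma', Delta) = 0.
delta = transport(x, v, w, t)
ddelta = jax.jvp(lambda s: transport(x, v, w, s), (t,), (jnp.ones_like(t),))[1]
par_res = jnp.max(jnp.abs(ddelta + christoffel_gamma(gam, dgam, delta)))
print(geo_res, par_res)
```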

In [None]:
r1 = flg.exp(x, t*v)

print(cz(jvp(lambda t: flg.exp(x, t*v), (t,), (1.,))[1]
          - flg.dexp(x, v, t, ddexp=False)[1]))

print(cz(jvp(lambda t: flg.dexp(x, v, t)[1], (t,), (1.,))[1]
          - flg.dexp(x, v, t, ddexp=True)[2]))

gmms = flg.dexp(x, v, t, ddexp=True)
print(cz(gmms[2] + flg.christoffel_gamma(gmms[0], gmms[1], gmms[1])))

Delta = flg.parallel_canonical(x, v, va, t)
print("Check transport equation with numerical derivatives")
print((flg.parallel_canonical(x, v, va, t+dlt) - Delta)/dlt \
      + flg.christoffel_gamma(gmms[0], gmms[1], Delta))
print("Check transport equation with AD")
print(jvp(lambda t: flg.parallel_canonical(x, v, va, t), (t,), (1.,))[1] \
      + flg.christoffel_gamma(gmms[0], gmms[1], Delta))


1.5543122344752192e-15
3.552713678800501e-15
8.881784197001252e-15
Check transport equation with numerical derivatives
[[ 4.95447319e-07 -6.12083603e-06  6.24417485e-06 -9.78392925e-07
   3.15059757e-06  8.04643643e-08  7.21713000e-06]
 [-3.78544228e-07 -1.22460002e-06 -2.87914499e-06 -2.53627988e-06
  -2.69527048e-07  2.36610822e-07  1.57109348e-06]
 [-2.37453669e-06  1.10263911e-06 -5.31162860e-06  5.90278580e-07
   1.13662463e-06 -1.13939425e-06  5.71448738e-06]
 [ 9.99828388e-07 -1.77227055e-06  4.64163269e-06  1.66418671e-06
   1.12788867e-06  7.16979816e-08 -1.59539661e-06]
 [-1.45116296e-07 -1.39529618e-07 -2.22219320e-06 -2.62430867e-06
  -1.63806998e-06  3.08300231e-06 -4.92815873e-07]
 [ 5.44624470e-07  3.74393244e-06 -5.85930306e-07  1.85464579e-06
  -1.54196954e-06  2.99996532e-06 -3.30938412e-06]
 [ 3.39808842e-07 -8.10638026e-07  1.41371699e-06 -1.64166702e-07
  -7.64218444e-07 -1.08487353e-06 -8.85578843e-06]
 [ 7.83756526e-09  4.09550569e-06 -1.97927796e-06  1.62027391e

A quick check on the speed of `parallel_canonical`. This runs slower on CPU than on GPU.

In [None]:
%timeit flg.parallel_canonical(x, v, va, t)

219 ms ± 10.2 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)
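JAX dispatches work asynchronously, so a wall-clock timer can stop before a computation has finished, and the first call of a jitted function also pays for compilation. A common pattern is to warm up once and then call `jax.block_until_ready` inside the timed region. The sketch below applies it to a stand-in workload (a matrix exponential via `jax.scipy.linalg.expm`), not to the package's transport; the names `time_call` and `stand_in` are hypothetical.

```python
from time import perf_counter

import jax
import jax.numpy as jnp
from jax.scipy.linalg import expm

jax.config.update("jax_enable_x64", True)


def time_call(fn, *args, n_repeats=5):
    """Mean wall-clock time (seconds) of fn(*args), excluding a warm-up call."""
    jax.block_until_ready(fn(*args))       # warm-up: triggers JIT compilation
    times = []
    for _ in range(n_repeats):
        t0 = perf_counter()
        jax.block_until_ready(fn(*args))   # wait for the asynchronous result
        times.append(perf_counter() - t0)
    return sum(times) / len(times)


@jax.jit
def stand_in(a, t):
    # Hypothetical workload: exponential of a scaled antisymmetric matrix.
    return expm(t * (a - a.T))


a = jax.random.normal(jax.random.PRNGKey(0), (200, 200))
print("mean time: %.4f s" % time_call(stand_in, a, 0.8))
```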


# Test a bigger manifold
We again check the Levi-Civita connection: first that the covariant derivative is horizontal, then metric compatibility.

In [None]:
jax.config.update("jax_enable_x64", True)

# dvec = [200, 40, 60, 1700]: n = 2000, d = 200 + 40 + 60 = 300
dvec = jnp.array([200, 40, 60, 1700])
alp = .5
flg = Flag(dvec, alp)
key = random.PRNGKey(0)

x, key = flg.rand_point(key)
v, key = flg.rand_vec(key, x)
va, key = flg.rand_vec(key, x)

dlt = 1e-6
t = .8

print("CHECK COVARIANT DERIVATIVE RETURNS A VECTOR FIELD")
r1 = jvp(lambda t: flg.proj(x+t*v, va), (0.,), (1.,))[1] + flg.christoffel_gamma(x, v, va)
print(cz(sym(x.T@r1)))
print(cz(flg.proj_m(sym(x.T@r1)) - asym(x.T@r1)))

print("CHECK METRIC COMPATIBILITY")
print(jax.jvp(lambda z: flg.inner(z, flg.proj(z, va), flg.proj(z, va)), (x,), (v,))[1])
print(2*flg.inner(x, va,
                  jax.jvp(lambda x: flg.proj(x, va), (x,), (v,))[1]
                  + flg.christoffel_gamma(x, v, va)))





CHECK COVARIANT DERIVATIVE RETURNS A VECTOR FIELD
3.3861802251067274e-14
5.334621633323877e-14
CHECK METRIC COMPATIBILITY
-2544.7109397373765
-2544.7109397373983


# Check the parallel transport equation

In [None]:

r1 = flg.exp(x, t*v)

print(cz(jvp(lambda t: flg.exp(x, t*v), (t,), (1.,))[1]
          - flg.dexp(x, v, t, ddexp=False)[1]))

print(cz(jvp(lambda t: flg.dexp(x, v, t)[1], (t,), (1.,))[1]
          - flg.dexp(x, v, t, ddexp=True)[2]))

gmms = flg.dexp(x, v, t, ddexp=True)
print(cz(gmms[2] + flg.christoffel_gamma(gmms[0], gmms[1], gmms[1])))

Delta = flg.parallel_canonical(x, v, va, t)

print("CHECK THE transport equation using numerical derivative")
print(cz((flg.parallel_canonical(x, v, va, t+dlt) - Delta)/dlt \
      + flg.christoffel_gamma(gmms[0], gmms[1], Delta)))

print("CHECK THE transport equation using AD")
print(cz(jvp(lambda t: flg.parallel_canonical(x, v, va, t), (t,), (1.,))[1] \
      + flg.christoffel_gamma(gmms[0], gmms[1], Delta)))


1.6092682741941644e-13
7.389644451905042e-13
3.552713678800501e-12
CHECK THE transport equation using numerical derivative
0.0017154948947197823
CHECK THE transport equation using AD
1.4561685190983553e-12
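The numerical-derivative residual above (about $10^{-3}$) is far larger than the AD residual (about $10^{-12}$). One plausible contributor is that a forward difference with step `dlt` has truncation error of order `dlt` plus rounding error of order $\epsilon/$`dlt`, whereas `jvp` is exact up to rounding. The toy comparison below uses $f(t) = \sin t$ rather than the transport itself, and adds a central difference for reference.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)

f = jnp.sin
t, h = 1.0, 1e-6
exact = jnp.cos(t)

forward = (f(t + h) - f(t)) / h              # error O(h) + O(eps / h)
central = (f(t + h) - f(t - h)) / (2 * h)    # error O(h^2) + O(eps / h)
ad = jax.jvp(f, (t,), (1.0,))[1]             # exact up to rounding

for name, val in [("forward", forward), ("central", central), ("jvp", ad)]:
    print(f"{name:8s} error = {abs(float(val - exact)):.2e}")
```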


# Check execution time
In this case $n=2000$ and $d=300$. Note that the code runs much faster on a GPU.
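These sizes follow from `dvec`. Judging from the examples in this notebook (`dvec = [200, 40, 60, 1700]` gives $n=2000$, $d=300$), $n$ is the sum of all entries and $d$ is the sum of all but the last. A small helper under that inferred convention; the name `flag_dims` is hypothetical and not part of the package.

```python
def flag_dims(dvec):
    """Return (n, d) for block sizes dvec, assuming the last entry is the
    complement block (inferred from this notebook, not from the package)."""
    dvec = [int(k) for k in dvec]
    return sum(dvec), sum(dvec[:-1])


print(flag_dims([200, 40, 60, 1700]))  # (2000, 300)
print(flag_dims([5, 2, 3]))            # (10, 7)
```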

In [None]:
%timeit flg.dexp(x, v, t, ddexp=True)

1.09 s ± 285 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


In [None]:
%timeit flg.parallel_canonical(x, v, va, t)

17 s ± 451 ms per loop (mean ± std. dev. of 7 runs, 1 loop each)


# Compare with Stiefel
Still with $n=2000$ and $d=300$. The flag run should be slightly faster, since on flag manifolds we exploit the fact that $\alpha=\frac{1}{2}$ zeros out some matrices. The printed values are run times in seconds.

In [None]:
from jax_par_trans.manifolds import Stiefel
stf = Stiefel(flg.shape[0], flg.shape[1], .5)

t0 = perf_counter()
stf.dexp(x, v, t, ddexp=True)[2]
t1 = perf_counter()
print("Stiefel geodesic run time %f seconds" % (t1-t0))

t2 = perf_counter()
flg.dexp(x, v, t, ddexp=True)[2]
t3 = perf_counter()
print("Flag geodesic run time %f seconds" % (t3-t2))

t4 = perf_counter()
stf.parallel(x, v, va, t)
t5 = perf_counter()
print("Stiefel parallel run time %f seconds" % (t5-t4))

t6 = perf_counter()
flg.parallel_canonical(x, v, va, t)
t7 = perf_counter()
print("Flag parallel run time %f seconds" % (t7-t6))


Stiefel geodesic run time 1.048716 seconds
Flag geodesic run time 1.006441 seconds
Stiefel parallel run time 17.481977 seconds
Flag parallel run time 17.049371 seconds


# TEST ISOMETRY

We show that parallel transport preserves the inner products (the Gram matrix of 20 random tangent vectors) to about $10^{-10}$ accuracy up to $t=15$.

The ambient space has size $1100 \times 200$.

In [None]:

dvec = jnp.array([60, 40, 100, 900])
# dvec = jnp.array([3, 2, 5])
alp = .5

key = random.PRNGKey(0)

alp = 1.
flg = Flag(dvec, alp)
n = flg.n
d = flg.d
x = jnp.zeros((n, d)).at[:d, :].set(jnp.eye(d))

n_samples = 20

all_smpl = []
for _ in range(n_samples):
    spl, key = flg.rand_vec(key, x)
    all_smpl.append(spl)

all_smpl = jnp.array(all_smpl)
# sp1 = (q@all_smpl[:, None]).reshape(n_samples, n, d)

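def cal_cov(gm, smpls):
    """Gram matrix of pairwise inner products flg.inner(gm, smpls[i], smpls[j])."""
    mat = jnp.zeros((n_samples, n_samples))
    for i in range(n_samples):
        for j in range(i+1):
            mat = mat.at[i, j].set(flg.inner(gm, smpls[i, :, :], smpls[j, :, :]))
            if i != j:
                # the Gram matrix is symmetric: mirror into the upper triangle
                mat = mat.at[j, i].set(mat[i, j])
    return mat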

cov_0 = cal_cov(x, all_smpl)

v, key = flg.rand_vec(key, x)
v = v/jnp.sqrt(flg.inner(x, v, v))

cov_diff = []
t = 15
transported = []
for i in range(n_samples):
    transported.append(flg.parallel_canonical(x, v, all_smpl[i, :, :], t))

transported = jnp.array(transported)
gm = flg.exp(x, t*v)
cov_t = cal_cov(gm, transported)
print(cz(cov_t- cov_0))


8.731149137020111e-11
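The double loop in `cal_cov` fills the Gram matrix one entry at a time. A vectorized alternative with nested `jax.vmap` is sketched below. As a stand-in for `flg.inner` it uses the canonical metric of the Stiefel manifold, $\langle a, b\rangle_x = \operatorname{tr}\big(a^T(I - \tfrac12 x x^T)b\big)$; the function names are hypothetical, the package's own `inner` may differ, and the random samples are not projected to tangent vectors since only the Gram computation is being illustrated.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def inner_canonical(x, a, b):
    """Canonical Stiefel metric tr(a^T (I - x x^T / 2) b), a stand-in for flg.inner."""
    return jnp.sum(a * b) - 0.5 * jnp.sum((x.T @ a) * (x.T @ b))


def gram(inner, x, samples):
    """Gram matrix G[i, j] = inner(x, samples[i], samples[j]) via nested vmap."""
    row = jax.vmap(lambda a, b: inner(x, a, b), in_axes=(None, 0))
    return jax.vmap(lambda a: row(a, samples))(samples)


n, d, m = 8, 3, 4
x = jnp.eye(n, d)                       # standard point on the Stiefel manifold
samples = jax.random.normal(jax.random.PRNGKey(0), (m, n, d))
g = gram(inner_canonical, x, samples)
print(g.shape)  # (4, 4)
```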


# Check the effect of changing $n$, $d$ and $t$
This takes the most time, so we moved it to the end. It may run faster on a home machine than in a free Colab session. This is just a sample.

## Back to the question "Is JAX faster than NumPy?"
The answer from the [JAX FAQ](https://jax.readthedocs.io/en/latest/faq.html#): if you have a GPU, yes; if not, NumPy is generally faster.
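To check this claim on your own hardware, time the same operation in NumPy and in jitted JAX, excluding compilation and waiting for JAX's asynchronous result. The workload (a matrix product) and the helper name `best_time` are illustrative assumptions; results depend heavily on hardware and problem size.

```python
from time import perf_counter

import numpy as np
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def best_time(fn, n_repeats=5):
    """Best-of-n wall-clock time in seconds for a zero-argument callable."""
    best = float("inf")
    for _ in range(n_repeats):
        t0 = perf_counter()
        fn()
        best = min(best, perf_counter() - t0)
    return best


a_np = np.random.default_rng(0).standard_normal((500, 500))
a_jx = jnp.asarray(a_np)
matmul_jit = jax.jit(lambda a: a @ a.T)
jax.block_until_ready(matmul_jit(a_jx))     # compile once, outside the timing

t_np = best_time(lambda: a_np @ a_np.T)
t_jx = best_time(lambda: jax.block_until_ready(matmul_jit(a_jx)))
print("NumPy: %.2e s, JAX (%s): %.2e s" % (t_np, jax.default_backend(), t_jx))
```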

In [None]:
def test_one_set(flg, key, t_interval, n_samples=10, n_repeats=5):
    ret = []
    d = flg.shape[1]
    for _ in range(n_samples):
        ret_spl = []
        x, key = flg.rand_point(key)
        v, key = flg.rand_vec(key, x)
        va, key = flg.rand_vec(key, x)
        # warm-up calls so that JIT compilation is not included in the timings
        flg.dexp(x, v, 1.)
        par = flg.parallel_canonical(x, v, va, 1.)

        for t in t_interval:
            ret_t = []
            for _ in range(n_repeats):
                t0 = perf_counter()
                gmms = flg.dexp(x, v, t)
                t1 = perf_counter()
                t_gmms = t1 - t0

                t3 = perf_counter()
                par = flg.parallel_canonical(x, v, va, t)
                t4 = perf_counter()
                t_par = t4 - t3

                # check accuracy:
                geo_man = cz(gmms[0].T@gmms[0] - jnp.eye(d))
                par_tan = cz(sym(gmms[0].T@par))
                par_eq = cz(jvp(lambda t: flg.parallel_canonical(x, v, va, t), (t,), (1.,))[1] +
                            flg.christoffel_gamma(gmms[0], gmms[1], par))

                ret_t.append([t_gmms, t_par, geo_man, par_tan, par_eq])

            ret_spl.append(ret_t)
        ret.append(ret_spl)
    return jnp.array(ret)


def test_time():
    jax.config.update("jax_enable_x64", True)
    key = random.PRNGKey(0)

    # relative sizes of the non-complement blocks; each is scaled by d // dbase below
    dparts = [5, 4, 3]
    dbase = sum(dparts)

    t_interval = jnp.array([.5, 1., 2., 5., 20.])


    # first test, fixed d = 48
    d_list = jnp.array([48])
    n_list = jnp.array([100, 200, 1000])

    alp = .5

    all_ret_0 = {}
    for d in d_list:
        for n in n_list:
            print("Doing n=%d d=%d" % (n, d))
            if n <= d:
                continue
            dvec_d = d//dbase*jnp.array(dparts)
            dvec = jnp.concatenate([dvec_d, jnp.array([n-dvec_d.sum()])])
            print(dvec)
            flg = Flag(dvec, alp)
            ret = test_one_set(flg, key, t_interval, n_samples=5, n_repeats=2)
            all_ret_0[int(d), int(n)] = ret

    tbl = []
    for t_idx in range(t_interval.shape[0]):
        for idx, val in all_ret_0.items():
            tbl.append([idx[1], t_interval[t_idx]] + list(val[:, t_idx, :, :].mean(axis=((0, 1)))))

    raw_tbl = []
    for idx, val in all_ret_0.items():
        for t_idx in range(t_interval.shape[0]):
            for i_s in range(val.shape[0]):
                for i_r in range(val.shape[2]):
                    raw_tbl.append([idx[1], t_interval[t_idx]] + list(val[i_s, t_idx, i_r, :]))

    import pandas as pd
    pd.DataFrame(raw_tbl).to_pickle('flg_by_n.pkl')

    # second test, fixed n = 1000
    d_list = jnp.array([12, 48, 96])
    n_list = jnp.array([1000])

    all_ret_1 = {}
    for d in d_list:
        for n in n_list:
            print("Doing n=%d d=%d" % (n, d))
            if n <= d:
                continue
            dvec_d = d//dbase*jnp.array(dparts)
            dvec = jnp.concatenate([dvec_d, jnp.array([n-dvec_d.sum()])])
            print(dvec)
            flg = Flag(dvec, alp)
            ret = test_one_set(flg, key, t_interval, n_samples=5, n_repeats=2)
            all_ret_1[int(d), int(n)] = ret

    tbl1 = []
    for t_idx in range(t_interval.shape[0]):
        for idx, val in all_ret_1.items():
            tbl1.append([idx[0], t_interval[t_idx]] + list(val[:, t_idx, :, :].mean(axis=((0, 1)))))

    raw_tbl1 = []
    for idx, val in all_ret_1.items():
        for t_idx in range(t_interval.shape[0]):
            for i_s in range(val.shape[0]):
                for i_r in range(val.shape[2]):
                    raw_tbl1.append([idx[0], t_interval[t_idx]] + list(val[i_s, t_idx, i_r, :]))

    pd.DataFrame(raw_tbl1).to_pickle('flg_by_d_1000.pkl')


test_time()

Doing n=100 d=48
[20 16 12 52]
Doing n=200 d=48
[ 20  16  12 152]
Doing n=1000 d=48
[ 20  16  12 952]
Doing n=1000 d=12
[  5   4   3 988]
Doing n=1000 d=48
[ 20  16  12 952]
Doing n=1000 d=96
[ 40  32  24 904]


In [None]:
def display_test():
    import pandas as pd
    jax.config.update("jax_enable_x64", True)
    by_n_tbl = pd.read_pickle('flg_by_n.pkl')
    # by_n_tbl.iloc[:, 2:] = np.array(by_n_tbl.iloc[:, 2:])
    by_n_tbl.iloc[:, 1] = [f"{a:04.1f}" for a in by_n_tbl.iloc[:, 1].values]
    by_n_tbl.columns = ['n', 't', 'geo_time', 'par_time', 'err_geo', 'err_tan', 'err_eq']
    by_n_tbl['log_err_eq'] = [jnp.log10(a) for a in by_n_tbl.err_eq.values]

    by_n_prep = by_n_tbl.pivot_table(index='n',
                                     columns='t',
                                     values=['par_time', 'log_err_eq'],
                                     aggfunc='mean')
    def str1(a):
        return '%.1f' % a

    def str2(a):
        return '%.2f' % a

    # print(by_n_prep.to_latex(formatters=5*[str1] + 5*[str2]))
    display(pd.DataFrame(by_n_prep))
    # alp_tbl = jnp.array([.5, 1.])
    # by_n_tbl.loc[:, 'alp'] = alp_tbl[by_n_tbl.loc[:, 'alp'].values]
    by_d_tbl = pd.read_pickle('flg_by_d_1000.pkl')
    by_d_tbl.iloc[:, 1] = [f"{a:04.1f}" for a in by_d_tbl.iloc[:, 1].values]
    by_d_tbl.columns = ['d', 't', 'geo_time', 'par_time', 'err_geo', 'err_tan', 'err_eq']

    by_d_tbl['log_err_eq'] = [jnp.log10(a) for a in by_d_tbl.err_eq.values]

    by_d_prep = by_d_tbl.pivot_table(index='d',
                                     columns='t',
                                     values=['par_time', 'log_err_eq'],
                                     aggfunc='mean')

    # print(by_d_prep.to_latex(formatters=5*[str1] + 5*[str2]))
    display(pd.DataFrame(by_d_prep))
display_test()

Parallel transport for varying $n$ at $d=48$. `log_err_eq` is $\log_{10}$ of the transport-equation residual and `par_time` is the run time in seconds, both averaged over samples and repeats.

| n | log_err_eq, t=0.5 | log_err_eq, t=1 | log_err_eq, t=2 | log_err_eq, t=5 | log_err_eq, t=20 | par_time, t=0.5 | par_time, t=1 | par_time, t=2 | par_time, t=5 | par_time, t=20 |
|---|---|---|---|---|---|---|---|---|---|---|
| 100 | -13.208051 | -13.10091 | -12.969102 | -12.750513 | -12.359731 | 0.443269 | 0.389752 | 0.381164 | 0.420295 | 0.70941 |
| 200 | -12.988756 | -12.913317 | -12.764105 | -12.530907 | -12.144784 | 0.444969 | 0.362069 | 0.383263 | 0.42757 | 0.693335 |
| 1000 | -12.53092 | -12.318598 | -12.121603 | -11.950722 | -11.577841 | 0.495819 | 0.479026 | 0.48753 | 0.545054 | 1.024165 |

Parallel transport for varying $d$ at $n=1000$, same columns.

| d | log_err_eq, t=0.5 | log_err_eq, t=1 | log_err_eq, t=2 | log_err_eq, t=5 | log_err_eq, t=20 | par_time, t=0.5 | par_time, t=1 | par_time, t=2 | par_time, t=5 | par_time, t=20 |
|---|---|---|---|---|---|---|---|---|---|---|
| 12 | -12.778708 | -12.577478 | -12.346293 | -12.277035 | -11.851062 | 0.334972 | 0.309166 | 0.309159 | 0.301368 | 0.323764 |
| 48 | -12.53092 | -12.318598 | -12.121603 | -11.950722 | -11.577841 | 0.416719 | 0.436846 | 0.482856 | 0.625868 | 0.996419 |
| 96 | -12.418254 | -12.256769 | -12.081365 | -11.753128 | -11.387582 | 0.755964 | 0.879113 | 1.309572 | 2.440413 | 7.495804 |
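`display_test` reshapes the raw per-run rows into the wide tables above with `pivot_table`. The toy sketch below repeats that reshaping on synthetic rows (the numbers are placeholders, not results from this notebook). Because `t` is stored as a zero-padded string such as `"05.0"`, the string order of the columns matches the numeric order.

```python
import pandas as pd

# Synthetic rows in the layout used by display_test: one row per (n, t, repeat).
raw = pd.DataFrame(
    {
        "n": [100, 100, 100, 100, 200, 200, 200, 200],
        "t": ["00.5", "00.5", "01.0", "01.0", "00.5", "00.5", "01.0", "01.0"],
        "par_time": [1.0, 3.0, 2.0, 4.0, 5.0, 7.0, 6.0, 8.0],
    }
)

# Average over repeats: one row per n, one column per value of t.
wide = raw.pivot_table(index="n", columns="t", values="par_time", aggfunc="mean")
print(wide)
```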
