<a href="https://colab.research.google.com/github/dnguyend/regularMTW/blob/main/colab/DemoSimpleMTW.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Project for the article "Families of costs with zero and nonnegative MTW tensor in optimal transport"
$\newcommand{\ft}{\mathfrak{t}}$
$\newcommand{\R}{\mathbb{R}}$

We compute the MTW tensor (or cross curvature) for the optimal transport problem on $\R^n$ with a cost function of the form $c(x, y) = u(x^{\ft}y)$, where $u$ is a scalar function with inverse $s$ and $x^{\ft}y$ is a nondegenerate bilinear pairing of vectors $x, y$ in an open subset of $\R^n$.

* In the paper, we show that the cross curvature is a rational function of derivatives of $s$, and that it vanishes only if $s$ satisfies a fourth-order nonlinear ODE. This ODE can be reduced to a second-order ODE depending on two parameters, from which $s$ can be solved explicitly.
* This notebook presents a numerical verification of this result.
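
As a point of comparison, here is a minimal JAX sketch of how the cross curvature of any cost $c(x, y)$ can be evaluated by automatic differentiation. It uses the Kim-McCann formula as we recall it, $\mathrm{cross}(\xi,\eta) = \big(c_{ij,r}\,c^{r,s}\,c_{s,kl} - c_{ij,kl}\big)\xi^i\xi^j\eta^k\eta^l$. Here $c_{ij,r} = \partial_{x_i}\partial_{x_j}\partial_{y_r}c$, and $c^{r,s}$ is the inverse of the mixed Hessian $c_{s,r}$. The MTW tensor is this quantity restricted to null pairs, $c_{s,r}\xi^s\eta^r = 0$. Sign conventions vary between references, but the vanishing test does not depend on them.

The sketch makes a few assumptions. It takes the Euclidean pairing $x^{\ft}y = x\cdot y$, which matches the notebook's tests with `A = jnp.eye(n)`. The names `second_dir`, `mixed_hessian`, `cross_curvature` and `null_direction` are hypothetical. This is not the library's `crossCurv`. In the usage lines we take $s(u) = \sinh u$, that is $p_0 = 1/2$, $p_1 = 1$, $p_2 = -1/2$, $p_3 = -1$ in the generalized hyperbolic form, so that $u = \operatorname{arcsinh}$.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)  # fourth derivatives need float64


def second_dir(f, v):
    """Return q -> D^2 f(q)[v, v], the second directional derivative of f along v."""
    first = lambda q: jax.jvp(f, (q,), (v,))[1]
    return lambda q: jax.jvp(first, (q,), (v,))[1]


def mixed_hessian(c, x, y):
    """M[s, r] = d^2 c / dx_s dy_r, i.e. c_{s,r}."""
    return jax.jacfwd(jax.grad(c, argnums=0), argnums=1)(x, y)


def cross_curvature(c, x, y, xi, eta):
    """(c_{ij,r} c^{r,s} c_{s,kl} - c_{ij,kl}) xi^i xi^j eta^k eta^l."""
    # a_r = c_{ij,r} xi^i xi^j: y-gradient of the second x-derivative along xi
    a = jax.grad(lambda y_: second_dir(lambda x_: c(x_, y_), xi)(x))(y)
    # b_s = c_{s,kl} eta^k eta^l: x-gradient of the second y-derivative along eta
    b = jax.grad(lambda x_: second_dir(lambda y_: c(x_, y_), eta)(y))(x)
    # c^{r,s} is the inverse of c_{s,r}, so the first term is a^T M^{-1} b
    term1 = a @ jnp.linalg.solve(mixed_hessian(c, x, y), b)
    # c_{ij,kl} xi^i xi^j eta^k eta^l as a fourth mixed directional derivative
    term2 = second_dir(lambda y_: second_dir(lambda x_: c(x_, y_), xi)(x), eta)(y)
    return term1 - term2


def null_direction(c, x, y, xi, eta0):
    """Project eta0 so that xi^T M eta = 0, i.e. (xi, eta) is a null pair."""
    w = mixed_hessian(c, x, y).T @ xi
    return eta0 - (w @ eta0) / (w @ w) * w


# Usage: s(u) = sinh(u) gives the cost u(x . y) = arcsinh(x . y)
cost = lambda x, y: jnp.arcsinh(x @ y)
x, y = jnp.array([0.3, -0.2, 0.5]), jnp.array([0.4, 0.1, -0.6])
xi = jnp.array([1.0, 0.5, -0.3])
eta = null_direction(cost, x, y, xi, jnp.array([0.2, -1.0, 0.7]))
print(cross_curvature(cost, x, y, xi, eta))  # ~0 up to rounding
```

A bilinear cost such as $-x\cdot y$ gives zero for every pair. A cost like $e^{x\cdot y}$ gives a nonzero value on null pairs, which is a useful sanity check that the zero above is not an artifact.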

In [3]:
!git clone https://github.com/dnguyend/regularMTW.git

Cloning into 'regularMTW'...
remote: Enumerating objects: 20, done.
remote: Counting objects: 100% (20/20), done.
remote: Compressing objects: 100% (17/17), done.
remote: Total 20 (delta 4), reused 15 (delta 3), pack-reused 0
Receiving objects: 100% (20/20), 30.56 KiB | 2.55 MiB/s, done.
Resolving deltas: 100% (4/4), done.


In [4]:
import jax.numpy as jnp
import jax.numpy.linalg as jla
from jax import jvp, jacfwd, jacrev, random, grad
from regularMTW.src.tools import asym, sym2, sym, Lyapunov, vcat
from regularMTW.src.simple_mtw import (
    GenHyperbolicSimpleMTW,
    LambertSimpleMTW,
    TrigSimpleMTW,
    GHPatterns, LambertPatterns,
    basePotential, grand, splitzero)

import jax
jax.config.update("jax_enable_x64", True)  # use float64 throughout



There are three families of costs of the form $u(x^{\ft}y)$ with zero MTW curvature. The generalized hyperbolic family has the form $s(u) = p_0e^{p_1u} + p_2e^{p_3u}$. Depending on the values of $p_0, p_1, p_2, p_3$, the inverse $u$ has several branches, which we organize into the class `GHPatterns`. The patterns depend on the signs of the parameters $p_i$ and on the branch number, and there are 26 of them. Here we test all patterns numerically for function inversion (between $s$ and $u$), metric compatibility of the connection, and curvature.
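
The metric-compatibility test in the cells below compares $D_{\Omega_0}\langle\Omega_1,\Omega_1\rangle$, computed with `jvp`, against $2\langle\Omega_1,\Gamma(\Omega_0,\Omega_1)\rangle$, where $\nabla_XY = D_XY + \Gamma(X,Y)$. By polarization, the two agree for all $\Omega_0, \Omega_1$ precisely when the connection is compatible with the metric.

Below is a minimal JAX sketch of that check. It builds $\Gamma$ from the Koszul formula. We assume the Kim-McCann pseudo-metric is $-\tfrac12\begin{pmatrix}0&M\\M^{T}&0\end{pmatrix}$, with $M$ the mixed Hessian of the cost. The library's `KMMetric` may be normalized differently. `km_metric` and `christoffel` are hypothetical names, not the library's `KMMetric` and `Gamma`.

```python
import jax
import jax.numpy as jnp

jax.config.update("jax_enable_x64", True)


def km_metric(c, n):
    """Assumed Kim-McCann pseudo-metric on q = (x, y): g = -1/2 [[0, M], [M^T, 0]],
    where M[s, r] = d^2 c / dx_s dy_r (the normalization is an assumption)."""
    def g(q):
        x, y = q[:n], q[n:]
        M = jax.jacfwd(jax.grad(c, argnums=0), argnums=1)(x, y)
        Z = jnp.zeros((n, n))
        return -0.5 * jnp.block([[Z, M], [M.T, Z]])
    return g


def christoffel(g, q, X, Y):
    """Gamma(X, Y) of the Levi-Civita connection, nabla_X Y = D_X Y + Gamma(X, Y),
    from the Koszul formula for constant vector fields X, Y."""
    dg = lambda V: jax.jvp(g, (q,), (V,))[1]   # D_V g at q, a symmetric matrix
    w = jax.grad(lambda p: X @ g(p) @ Y)(q)    # w . Z = (D_Z g)(X, Y)
    return 0.5 * jnp.linalg.solve(g(q), dg(X) @ Y + dg(Y) @ X - w)


# Same check as the notebook: D_{Omg0} g(Omg1, Omg1) vs 2 g(Omg1, Gamma(Omg0, Omg1))
n = 3
g = km_metric(lambda x, y: jnp.exp(x @ y), n)
q = jnp.array([0.3, -0.2, 0.5, 0.4, 0.1, -0.6])
Omg0 = jnp.array([1.0, 0.5, -0.3, 0.2, -1.0, 0.7])
Omg1 = jnp.array([-0.4, 0.8, 0.1, 0.6, 0.3, -0.2])
j1 = jax.jvp(lambda p: Omg1 @ g(p) @ Omg1, (q,), (Omg0,))[1]
j2 = 2 * Omg1 @ g(q) @ christoffel(g, q, Omg0, Omg1)
print(j1, j2)  # equal up to rounding
```

The Koszul construction is also symmetric in $X, Y$, which gives torsion-freeness.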

The classes derive from `baseSimpleMTW`. The attribute `rng` is the range of $x^{\ft}y$ on which $u$ is defined, and the attribute `branch` specifies the branch. For sinh-type costs, the range is the full line $\R$. For other costs there may be one, two, or (for the trigonometric type) infinitely many branches.
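
The branches come from the critical points of $s$. The derivative $s'(u) = p_0p_1e^{p_1u} + p_2p_3e^{p_3u}$ can vanish only if $p_0p_1$ and $p_2p_3$ have opposite signs. In that case, for $p_1 \ne p_3$, it vanishes at $u_c = \log\big(-p_2p_3/(p_0p_1)\big)/(p_1-p_3)$. Otherwise $s$ is strictly monotone and there is a single branch; sinh is of this kind.

A plain-Python sketch follows. The function names are hypothetical. We assume that the library's `uc` attribute is this critical point; `ufunc` below starts Newton at `uc ± 1`.

```python
import math


def s_gh(u, p):
    """Generalized hyperbolic s(u) = p0 e^{p1 u} + p2 e^{p3 u}."""
    p0, p1, p2, p3 = p
    return p0 * math.exp(p1 * u) + p2 * math.exp(p3 * u)


def ds_gh(u, p):
    """s'(u) = p0 p1 e^{p1 u} + p2 p3 e^{p3 u}."""
    p0, p1, p2, p3 = p
    return p0 * p1 * math.exp(p1 * u) + p2 * p3 * math.exp(p3 * u)


def gh_critical_point(p):
    """u_c with s'(u_c) = 0, or None if s is strictly monotone (one branch).
    Assumes p0 * p1 != 0."""
    p0, p1, p2, p3 = p
    ratio = -(p2 * p3) / (p0 * p1)
    if ratio <= 0 or p1 == p3:
        return None
    return math.log(ratio) / (p1 - p3)


print(gh_critical_point((0.5, 1.0, -0.5, -1.0)))  # sinh: None, a single branch
p = (1.0, 1.0, 2.0, -1.0)                         # s(u) = e^u + 2 e^{-u}
uc = gh_critical_point(p)                         # log(2) / 2
print(uc, s_gh(uc, p))                            # minimum value 2 * sqrt(2)
```

For this example, the left branch $u < u_c$ and the right branch $u > u_c$ both map onto $(2\sqrt2, \infty)$. Each branch therefore has its own inverse $u$.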

In general $u$ is a transcendental function of $s$ and must be computed numerically. When we randomly generate a pair $(x, y)\in (\R^n)^2$, $x^{\ft}y$ may fall outside the range, so we rescale $y$ to bring it into range.
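
Because the pairing is bilinear, multiplying $y$ by $\lambda$ multiplies $x^{\ft}y$ by $\lambda$. Choosing $\lambda = \tau/(x^{\ft}y)$ therefore puts the pairing at any target $\tau$. `gen_in_range` below picks $\tau$ inside `rng`, replacing an infinite end so that the interval has width 100.

A plain-Python sketch of the same idea follows. The Euclidean dot product stands in for `Adot` (the tests use $A = I$). The helper names are hypothetical.

```python
import math


def dot(x, y):
    return sum(a * b for a, b in zip(x, y))


def pick_target(lo, hi, r):
    """A value in (lo, hi], as in gen_in_range; r is any random number (r = 0 gives hi).
    An infinite end is replaced so the interval has width 100 (assumes at most one)."""
    if math.isinf(hi):
        hi = lo + 100
    elif math.isinf(lo):
        lo = hi - 100
    return lo + (hi - lo) / (abs(r) + 1)


def rescale_into_range(x, y, target):
    """Scale y so that x . y == target, using x . (l * y) = l * (x . y).
    Assumes x . y != 0."""
    return [target / dot(x, y) * b for b in y]


x, y = [1.0, 2.0, 0.0], [0.5, -1.0, 3.0]  # x . y = -1.5
target = pick_target(2.0, math.inf, 3.0)    # 2 + 100 / 4 = 27
y_new = rescale_into_range(x, y, target)
print(dot(x, y_new))                        # 27.0
```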

## A numerical procedure to invert $s$ and obtain $u$
It converges in most cases; below we test all 26 patterns.
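
The inversion in `ufunc` below is a Newton iteration on $s(u) = t$. Its step is shrunk by a factor of 0.8 until the residual $|s(u) - t|$ decreases. Here is a plain-Python restatement as a sketch, with these assumptions:

* `invert_newton` is a hypothetical name.
* Overflow is mapped to NaN so that it triggers the same backtracking.
* The tolerance and iteration cap differ from the notebook's `1e-7` and 20.
* Starting on one side of the critical point, as `uc ± 1` does, keeps the iteration on that branch in the example.

```python
import math


def _safe(f, u):
    """Evaluate f(u), mapping overflow to nan (treated as a bad step)."""
    try:
        return f(u)
    except OverflowError:
        return math.nan


def invert_newton(s, ds, target, u0, tol=1e-10, max_iter=50):
    """Solve s(u) = target by Newton's method, shrinking the step by 0.8
    until |s(u) - target| decreases (as in the notebook's ufunc)."""
    u, val = u0, _safe(s, u0)
    for _ in range(max_iter):
        if abs(val - target) <= tol:
            return u
        step = (val - target) / ds(u)
        scale = 1.0
        new_u = u - step
        new_val = _safe(s, new_u)
        while (math.isnan(new_val) or abs(new_val - target) > abs(val - target)) \
                and scale > 1e-12:
            scale *= 0.8
            new_u = u - scale * step
            new_val = _safe(s, new_u)
        u, val = new_u, new_val
    if abs(val - target) <= tol:
        return u
    raise RuntimeError("Newton iteration did not converge")


# sinh has a single branch; compare with the closed-form inverse
print(invert_newton(math.sinh, math.cosh, 5.0, 0.0), math.asinh(5.0))
```

For $s(u) = e^u + 2e^{-u}$, with critical point $u_c = \tfrac12\log 2$, starting at $u_c + 1$ returns the right-branch solution $\log\big((5+\sqrt{17})/2\big)$ of $s(u) = 5$. Starting at $u_c - 1$ returns the left-branch solution $\log\big((5-\sqrt{17})/2\big)$.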

In [12]:
print(len(GHPatterns.patterns))
def gen_in_range(key, GH, x):
    # y such that x.y is in range
    n = GH.n
    yt, key = grand(key, (GH.n+1,))

    s = GH.Adot(x, yt[:n])
    if GH.is_in_range(s):
        return yt[:n], key
    else:
        big = max(GH.rng)
        small = min(GH.rng)
        if jnp.isinf(big):
            big = small + 100
        elif jnp.isinf(small):
            small = big - 100
        return (small + (big-small)/(jnp.abs(yt[-1]) + 1))/s*yt[:n], key

def test_all_hyperbolic():
    """ test all patterns hyperbolic patterns
    """
    key = random.PRNGKey(0)

    def ufunc(self, s):
        """ truncated Newton with line search
        two cases, if there is a critical point then
        either left (branch = -1) and right (branch = 1)
        """
        p0, p1, p2, p3 = self.p
        if not self.is_in_range(s):
            raise ValueError("s is out of range")

        # u0 = p1/(p1-p3)*jnp.log(jnp.abs((p0*p1)/(p2*p3))
        # umin, umax = jnp.sort(self.rng)

        if self.branch is None:
            u0 = 0
        elif self.branch < 0:
            u0 = self.uc - 1
        elif self.branch > 0:
            u0 = self.uc + 1

        tol = 1e-7
        val, grd = self.dsfunc(u0, 1)
        for i in range(20):
            scl = 1.
            newu = u0-scl*(val - s)/grd
            newval, newgrd = self.dsfunc(newu, 1)
            while jnp.isnan(newval) or (jnp.abs(newval-s) > jnp.abs(val-s)):
                scl = scl*.8
                newu = u0-scl*(val - s)/grd
                newval, newgrd = self.dsfunc(newu, 1)
            u0 = newu
            val = newval
            grd = newgrd
            if jnp.abs(val - s) <= tol:
                # print("FOUND", i)
                return u0
        print("NOTFOUND", i, jnp.abs(val - s))
        return u0

    def test_one(kid, n, jj, key):
        pp, br, key = GHPatterns.rand_params(
            key, kid)
        A = jnp.eye(n)
        GH = GenHyperbolicSimpleMTW(n, A, pp, br)

        x, key = GH.grand(key)
        y, key = gen_in_range(key, GH, x)
        q = vcat(x, y)
        u = ufunc(GH, GH.Adot(x, y))
        diff0 = GH.Adot(x, y) - GH.dsfunc(u, 1)[0]
        print(jj, diff0)
        Omg0, key = grand(key, (2*n,))
        Omg1, key = grand(key, (2*n,))

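        # Levi-Civita connection: for a constant field Omg1, metric compatibility
        # gives D_{Omg0} <Omg1, Omg1> = 2 <Omg1, Gamma(Omg0, Omg1)>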
        j1 = jvp(lambda q: GH.KMMetric(q, Omg1, Omg1), (q,), (Omg0,))[1]
        j2 = 2*GH.KMMetric(q, Omg1, GH.Gamma(q, Omg0, Omg1))
        diff1 = (j1 - j2)/jnp.abs(j1)

        # curvature: the value from the curvature tensor (Curv3) should match
        # crossCurv, and crossCurv should vanish on a null vector OmgNull
        Omgx, Omgy = splitzero(Omg0)
        cc = GH.KMMetric(q, Omgx, GH.Curv3(q, Omgx, Omgy, Omgy))
        cc1 = GH.crossCurv(q, Omg0)
        if jnp.abs(cc) > 1e-3:
            diff2 = (cc - cc1)/jnp.abs(cc)
        else:
            diff2 = cc - cc1
        # print(diff2)
        OmgNull, key = GH.gennull(key, q)
        cc2 = GH.crossCurv(q, OmgNull)

        # compare magnitudes: a large negative difference is also a failure
        bad = jnp.max(jnp.abs(jnp.array([diff0, diff1, diff2, cc2]))) > 1e-3

        if bad:
            return bad, x, y, GH, jnp.array([diff0, diff1, diff2, cc2]), key
        return bad, None, None, GH, jnp.array([diff0, diff1, diff2, cc2]), key

    n = 3
    bad = False
    for kid in range(len(GHPatterns.patterns)):
        for jj in range(3):
            bad = test_one(kid, n, jj, key)
            if bad[0]:
                print("BAD", kid, jj)
                break
            else:
                key = bad[-1]
        if bad[0]:
            print("kid %d not close" % kid)
            print(bad[-2])
            # break
        else:
            print("kid %d is good" % kid)
test_all_hyperbolic()

26
0 -1.4210854715202004e-13
1 -1.5493502036889595e-09
2 -3.346656285430072e-12
kid 0 is good
0 1.1509371233842103e-10
1 7.353331099579918e-08
2 8.176543886406762e-10
kid 1 is good
0 9.14823772291129e-14
1 2.783209218648608e-10
2 -2.6645352591003757e-15
kid 2 is good
0 -5.4789506265251475e-14
1 9.075007412207015e-11
2 -1.9359652769779245e-11
kid 3 is good
0 -6.394884621840902e-13
1 -2.7474134078886436e-12
2 -1.852495046250624e-09
kid 4 is good
0 6.012691455836716e-09
1 4.408321530036119e-08
2 8.202309949301156e-09
kid 5 is good
0 1.4040506170903688e-11
1 1.0992924768388201e-08
2 7.708584478231155e-09
kid 6 is good
0 -1.982697853120108e-09
1 -3.422800293195749e-09
2 3.9745454150086346e-10
kid 7 is good
0 4.5671688653214915e-11
1 7.420911884992165e-09
2 1.7763568394002505e-14
kid 8 is good
0 -1.4210854715202004e-14
1 -3.410605131648481e-13
2 1.4210854715202004e-14
kid 9 is good
0 2.7015945036623634e-11
1 8.272227347561056e-10
2 1.0215631840448935e-08
kid 10 is good
0 -1.6833456850662287e

Note that there is one case in which the inversion algorithm fails to converge.

## The Lambert case
We evaluate the Lambert $W$ function numerically to solve for $u$ from $s$ in the Lambert case. The solver converges, so we can test the connection and the curvature.
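
For reference, the principal branch $W_0(z)$, for $z \ge -1/e$, is the solution $w \ge -1$ of $we^w = z$. A plain-Python sketch using Halley's method follows. Halley's method is our choice; `LambertSimpleMTW.ufunc` may use a different scheme. The exact parametrization of the Lambert family is not stated in this notebook, so the shifted equation $(u - a)e^u = t$ in the example is only an illustration of how such equations reduce to $W$. `lambert_w0` is a hypothetical name.

```python
import math


def lambert_w0(z, tol=1e-15, max_iter=100):
    """Principal branch W0(z), z >= -1/e: the solution w >= -1 of w * e^w = z,
    computed with Halley's method."""
    if z < -1 / math.e:
        raise ValueError("W0 is real only for z >= -1/e")
    if z == -1 / math.e:
        return -1.0                    # branch point
    w = math.log1p(z)                  # starting guess, exact at z = 0
    for _ in range(max_iter):
        ew = math.exp(w)
        f = w * ew - z                 # residual
        fp = ew * (w + 1)              # f'(w); f''(w) = e^w (w + 2)
        w_new = w - f / (fp - (w + 2) * f / (2 * w + 2))
        if abs(w_new - w) <= tol * (1 + abs(w_new)):
            return w_new
        w = w_new
    return w


# Example: (u - a) e^u = t  <=>  (u - a) e^{u - a} = t e^{-a},  so  u = a + W0(t e^{-a})
a, t = 0.5, 3.0
u = a + lambert_w0(t * math.exp(-a))
print(u, (u - a) * math.exp(u))  # the second value reproduces t
```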

In [14]:
def test_all_Lambert():
    key = random.PRNGKey(0)

    # test all patterns:

    def test_one(kid, n, jj, key):
        pp, br, key = LambertPatterns.rand_params(
            key, kid)
        A = jnp.eye(n)
        LB = LambertSimpleMTW(n, A, pp, br)

        x, key = LB.grand(key)
        y, key = gen_in_range(key, LB, x)
        q = vcat(x, y)
        u = LB.ufunc(LB.Adot(x, y))
        diff0 = LB.Adot(x, y) - LB.dsfunc(u, 1)[0]

        Omg0, key = grand(key, (2*n,))
        Omg1, key = grand(key, (2*n,))

        # Levi Civita connection
        j1 = jvp(lambda q: LB.KMMetric(q, Omg1, Omg1), (q,), (Omg0,))[1]
        j2 = 2*LB.KMMetric(q, Omg1, LB.Gamma(q, Omg0, Omg1))
        diff1 = j1 - j2

        # curvature
        # OmgNull, key = LB.gennull(key, q)
        Omgx, Omgy = splitzero(Omg0)
        cc = LB.KMMetric(q, Omgx, LB.Curv3(q, Omgx, Omgy, Omgy))
        diff2 = cc - LB.crossCurv(q, Omg0)
        OmgNull, key = LB.gennull(key, q)
        cc2 = LB.crossCurv(q, OmgNull)

        # compare magnitudes: a large negative difference is also a failure
        bad = jnp.max(jnp.abs(jnp.array([diff0, diff1, diff2, cc2]))) > 1e-5

        if bad:
            return bad, x, y, LB, key
        return bad, None, None, LB, key

    n = 3
    bad = False
    for kid in range(len(LambertPatterns.patterns)):
        for jj in range(5):
            bad = test_one(kid, n, jj, key)
            if bad[0]:
                print("BAD", kid, jj, bad)
                break
            else:
                key = bad[-1]
        if bad[0]:
            print("pattern %d is not close enough" % kid)
            #break
        else:
            print("pattern %d is good" % kid)

test_all_Lambert()


pattern 0 is good
pattern 1 is good
pattern 2 is good
pattern 3 is good
pattern 4 is good
pattern 5 is good
pattern 6 is good
pattern 7 is good


## The exponential-trigonometric case
We check a few cases. The numerical inversion converges, so we can test the connection and the curvature.
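
The infinitely many branches come from the oscillation of $s$. Suppose $p_1, p_3 = \alpha \pm i\beta$ and $p_2 = \bar p_0 = (A + iB)/2$ in the generalized hyperbolic form. Then $s(u) = e^{\alpha u}(A\cos\beta u + B\sin\beta u)$, and $s'$ vanishes at $u_k = (\varphi + \pi/2 + k\pi)/\beta$ for every integer $k$, with $s$ monotone between consecutive $u_k$.

A plain-Python sketch follows. We assume `TrigSimpleMTW` is equivalent to this form and that its integer `branch` indexes the intervals $(u_k, u_{k+1})$. Neither assumption is confirmed by the notebook, and the function names are hypothetical.

```python
import math


def s_trig(u, al, be, A, B):
    """s(u) = e^{al u} (A cos(be u) + B sin(be u))."""
    return math.exp(al * u) * (A * math.cos(be * u) + B * math.sin(be * u))


def ds_trig(u, al, be, A, B):
    """s'(u) = e^{al u} ((al A + be B) cos(be u) + (al B - be A) sin(be u))."""
    return math.exp(al * u) * ((al * A + be * B) * math.cos(be * u)
                               + (al * B - be * A) * math.sin(be * u))


def trig_critical_point(k, al, be, A, B):
    """k-th zero of s': writing s' = e^{al u} R cos(be u - phi) gives
    u_k = (phi + pi/2 + k pi) / be. Assumes be > 0 and (A, B) != (0, 0)."""
    phi = math.atan2(al * B - be * A, al * A + be * B)
    return (phi + math.pi / 2 + k * math.pi) / be


p = (0.3, 1.0, 1.0, 0.5)  # al, be, A, B (illustrative values)
for k in range(-3, 4):
    lo, hi = trig_critical_point(k, *p), trig_critical_point(k + 1, *p)
    # s is monotone on (u_k, u_{k+1}); this branch maps onto the interval between the values
    print(k, (lo, hi), sorted((s_trig(lo, *p), s_trig(hi, *p))))
```

Consecutive critical points are $\pi/\beta$ apart, and the critical values alternate in sign. This is why each branch range is a bounded interval, consistent with the check on `TG.rng` in `test_one` below.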

In [16]:
def test_all_trig():
    key = random.PRNGKey(0)

    # test all patterns:

    def test_one(kid, n, jj, key):
        # k branch with kid
        while True:
            pp, key = grand(key, (4,))
            A = jnp.eye(n)
            TG = TrigSimpleMTW(n, A, pp, branch=kid)
            if (jnp.abs(TG.rng[0]-TG.rng[1]) > 1e-1) \
               and (jnp.abs(TG.rng[0]-TG.rng[1]) < 1e3):
                break

        x, key = TG.grand(key)
        y, key = gen_in_range(key, TG, x)
        q = vcat(x, y)
        u = TG.ufunc(TG.Adot(x, y))
        diff0 = TG.Adot(x, y) - TG.dsfunc(u, 1)[0]

        Omg0, key = grand(key, (2*n,))
        Omg1, key = grand(key, (2*n,))

        # Levi Civita connection
        j1 = jvp(lambda q: TG.KMMetric(q, Omg1, Omg1), (q,), (Omg0,))[1]
        j2 = 2*TG.KMMetric(q, Omg1, TG.Gamma(q, Omg0, Omg1))
        diff1 = 2*(j1 - j2)/(jnp.abs(j1) + jnp.abs(j2))

        # curvature
        # OmgNull, key = TG.gennull(key, q)
        Omgx, Omgy = splitzero(Omg0)
        cc = TG.KMMetric(q, Omgx, TG.Curv3(q, Omgx, Omgy, Omgy))
        cc1 = TG.crossCurv(q, Omg0)
        diff2 = 2*(cc - cc1)/(jnp.abs(cc) + jnp.abs(cc1))
        OmgNull, key = TG.gennull(key, q)
        cc2 = TG.crossCurv(q, OmgNull)

        # compare magnitudes: a large negative difference is also a failure
        bad = jnp.max(jnp.abs(jnp.array([diff0, diff1, diff2, cc2]))) > 1e-3

        if bad:
            print("DIFFS", diff0, diff1, diff2, cc2)
            return bad, x, y, TG, key
        return bad, None, None, TG, key

    n = 4
    for kid in range(-3, 4):
        print("DOING %i" % kid)
        for jj in range(10):
            bad = test_one(kid, n, jj, key)
            if bad[0]:
                print("BAD", kid, jj, bad)
                key = bad[-1]
                break
            else:
                key = bad[-1]
        if bad[0]:
            print("pattern %d is not close enough" % kid)
            break
        else:
            print("pattern %d is good" % kid)
test_all_trig()

DOING -3
pattern -3 is good
DOING -2
pattern -2 is good
DOING -1
pattern -1 is good
DOING 0
pattern 0 is good
DOING 1
pattern 1 is good
DOING 2
pattern 2 is good
DOING 3
pattern 3 is good
