<a href="https://colab.research.google.com/github/dnguyend/regularMTW/blob/main/colab/MTW_tensor_for_SphereAndHyperboloid.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Main Theorems in section 4 on nonnegative MTW tensor on sphere and hyperboloid
A mix of symbolic and numerical differentiation.

## First, the expressions for the three coefficients $R_1, R_{23}, R_4$.

In [1]:
import sympy as sp

# Placeholders for a cost profile and its derivatives: later cells substitute
# s0 -> h0, s1 -> h1, ..., s4 -> h4, where h_k is the k-th u-derivative of h0.
s0, s1, s2, s3, s4 = sp.symbols('s0 s1 s2 s3 s4')
p0, p1, p2, p3 = sp.symbols('p0 p1 p2 p3')

u, w = sp.symbols('u w')

# Common denominator of R1, R23 and R4.
D = (- s2-s0*s1**2+s2*s0**2)*(s1**2-s2*s0)*s1**5

# Expanded variant that was previously assigned to R1 and then immediately
# overwritten by the form below, so it never reached any computation.  It is kept
# under its own name for reference only.
# NOTE(review): it does not agree with R1 below, e.g. at
# (s0, s1, s2, s3, s4) = (2, 1, 1, 0, 0) it evaluates to -13/9 while R1 gives -13.
# Confirm which normalization is intended before using it.
R1_expanded = (s0**6*s2*s4 - s0**6*s3**2 - s0**5*s1**2*s4 + 2*s0**5*s1*s2*s3 - s0**5*s2**3 - 3*s0**4*s2*s4 + 3*s0**4*s3**2 + 2*s0**3*s1**2*s4 - 4*s0**3*s1*s2*s3 + 2*s0**3*s2**3
      - 2*s0**2*s1**3*s3 + 2*s0**2*s1**2*s2**2 + 3*s0**2*s2*s4 - 3*s0**2*s3**2 - s0*s1**4*s2 - s0*s1**2*s4 + 2*s0*s1*s2*s3 - s0*s2**3 + s1**6 + 2*s1**3*s3 - 2*s1**2*s2**2 - s2*s4 + s3**2
    ) / (- s2-s0*s1**2+s2*s0**2)/s1**5/(1-s0**2)**2

# The three coefficients of the MTW tensor, each over the common denominator D.
R1 = (((s1**2 - s0*s2)*s4 + s0*s3**2 - 2*s1*s2*s3 + s2**3)*(- s2-s0*s1**2+s2*s0**2)*(1-s0**2)**2
      + ((s3*s1-s2**2)*(1-s0**2) + s1**2*(s1**2-s2*s0))**2)/D

R23 = ((s3*s1-s2**2)*(1-s0**2) + s1**2*(s1**2-s2*s0))* s1**2*(s1**2-s2*s0)/D
R4 = s1**2*(s1**2-s2*s0)* s1**2*(s1**2-s2*s0)/D

display(D, R1, R23, R4)



s1**5*(-s0*s2 + s1**2)*(s0**2*s2 - s0*s1**2 - s2)

((1 - s0**2)**2*(s0**2*s2 - s0*s1**2 - s2)*(s0*s3**2 - 2*s1*s2*s3 + s2**3 + s4*(-s0*s2 + s1**2)) + (s1**2*(-s0*s2 + s1**2) + (1 - s0**2)*(s1*s3 - s2**2))**2)/(s1**5*(-s0*s2 + s1**2)*(s0**2*s2 - s0*s1**2 - s2))

(s1**2*(-s0*s2 + s1**2) + (1 - s0**2)*(s1*s3 - s2**2))/(s1**3*(s0**2*s2 - s0*s1**2 - s2))

(-s0*s2 + s1**2)/(s1*(s0**2*s2 - s0*s1**2 - s2))

We use $h_0$ for the function substituted for $s_0$, and $h_1, h_2, h_3, \dots$ for its derivatives (substituted for $s_1, s_2, s_3, \dots$).

First, the power cost $h_0(u) = (-u)^a - 1$ on the sphere. We verify the long expressions used in the proof. The proof itself is in the paper; here we only confirm those expressions symbolically.

In [2]:
# Power cost on the sphere, h0(u) = (-u)**a - 1, followed by its first four u-derivatives.
a = sp.symbols('a')
hs_power = [(-u)**a-1]
for order in range(1, 5):
    hs_power.append(sp.diff(hs_power[-1], u))
h0, h1, h2, h3, h4 = hs_power

# substitute the expressions  used in the proof of the theorem.

In [3]:
# Each pair (expression in s0..s4, closed form claimed in the proof) must agree once
# the power-cost derivatives are substituted, i.e. every displayed residual is 0.
power_subs = list(zip((s0, s1, s2, s3, s4), (h0, h1, h2, h3, h4)))

ff1 = s1*s3-s2**2
ff1a = - a**2*(a-1)*(s0+1)**2*u**(-4)

ff2 = (s1**2-s0*s2)*s4 + s0*s3**2 - 2*s1*s2*s3 + s2**3
ff2a = a**2*(a-1)*u**(-6)*(s0+1)**2*(2*s0-a**2+3*a)

Da = - a**7*(s0+1)**8*u**(-9)*(s0+a)*(s0+a-1)

identity_checks = [
    (s1**2-s2*s0, a*u**(-2)*(s0+1)*(s0+a)),
    (ff1, ff1a),
    (ff2, ff2a),
    (D, Da),
]
for lhs, rhs in identity_checks:
    display(sp.simplify((lhs - rhs).subs(power_subs)))

0

0

0

0

## Checking R1

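SymPy cannot simplify $R_1$ completely. Instead, we compare it with the claimed closed form numerically. After substituting $a \in \{2, 3, 4\}$ and ten values of $u$, the difference is zero up to floating-point round-off.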

In [4]:

# R1 for the power cost, and the closed form from the paper:
#   R1 = (-u) * f0(s0) / (a^4 * (s0+1)^2 * (s0+a-1)),  with s0 = h0(u).
expr1 = R1.subs([(s0, h0), (s1, h1), (s2, h2), (s3, h3), (s4, h4)])
f0 = 2*(1-a)*s0**3 +(a**3-4*a**2+9*a-6)*s0**2 + (a**3+6*a**2-12*a+6)*s0 + (a**4-a**3-2*a**2+5*a-2)
display(sp.Poly(f0, s0))
expr1a = ((-u)*f0/(a**4*(s0+1)**2*(s0+a-1))).subs([(s0, h0), (s1, h1), (s2, h2), (s3, h3), (s4, h4)])

# The simplified difference does not depend on the sample point, so compute it once
# rather than re-running expand/simplify inside the loop for each of the 30 points.
residual1 = sp.simplify(sp.expand(expr1 - expr1a))
for av in range(2, 5):
    for ut in range(1, 11):
        # u = -2**(1/a) * ut/10, so (-u)**a = 2*(ut/10)**a and s0 = h0(u) lies in (-1, 1].
        uv = -2**(1/av)/10*ut
        print(residual1.subs([(a, av), (u, uv)]))


Poly((2 - 2*a)*s0**3 + (a**3 - 4*a**2 + 9*a - 6)*s0**2 + (a**3 + 6*a**2 - 12*a + 6)*s0 + a**4 - a**3 - 2*a**2 + 5*a - 2, s0, domain='ZZ[a]')

1.18010413259022e-15
2.78635697972690e-16
-6.28265579483215e-16
2.27974661977656e-16
0
-1.07754852295086e-16
0
1.97977995927964e-16
-8.71362284698461e-17
0
0
0
0
-8.57562313438692e-16
0
-2.26571682586959e-16
-5.85640521946886e-17
0
2.63188080501189e-17
0
0
7.44179567553618e-13
0
0
-1.00705295698041e-15
6.39793643342800e-16
-1.14639792781909e-16
0
-3.41733796317944e-17
-2.22044604925032e-17


# Sympy can simplify $R_{23}$ and $R_4$

In [5]:
# R23 for the power cost against its closed form
#   (-u) * ((2a-1)*s0 + a^2 - a + 1) / (a^2 * (s0+1)^2 * (s0+a-1)),  s0 = h0(u).
power_subs = [(s0, h0), (s1, h1), (s2, h2), (s3, h3), (s4, h4)]
closed_R23 = (-u)*((2*a-1)*s0+a**2-a+1)/(a**2*(s0+1)**2*(s0+a-1))
expr23 = R23.subs(power_subs)
expr23a = closed_R23.subs(power_subs)
sp.simplify(expr23 - expr23a)

0

In [6]:
# R4 for the power cost against its closed form
#   (-u) * (s0 + a) / (a * (s0+1)^2 * (s0+a-1)),  s0 = h0(u).
power_subs = [(s0, h0), (s1, h1), (s2, h2), (s3, h3), (s4, h4)]
closed_R4 = (-u)*(s0+a)/(a*(s0+1)**2*(s0+a-1))
expr4 = R4.subs(power_subs)
expr4a = closed_R4.subs(power_subs)
sp.simplify(expr4 - expr4a)

0

Proving $f_0$ is positive by substituting $a=b+2$

In [7]:
b = sp.symbols('b')
# f0 rewritten as a polynomial in b = a - 2; all_coeffs() lists the highest degree first.
f0_in_b = sp.poly(f0.subs(a, b+2), b)
coeffs_in_b = f0_in_b.all_coeffs()
display(f0_in_b)
display(sp.factor(coeffs_in_b[-2]))  # coefficient of b
display(sp.factor(coeffs_in_b[-1]))  # constant term

Poly(b**4 + (s0**2 + s0 + 7)*b**3 + (2*s0**2 + 12*s0 + 16)*b**2 + (-2*s0**3 + 5*s0**2 + 24*s0 + 17)*b - 2*s0**3 + 4*s0**2 + 14*s0 + 8, b, domain='ZZ[s0]')

-(s0 + 1)*(2*s0**2 - 7*s0 - 17)

-2*(s0 - 4)*(s0 + 1)**2

# Hyperboloid: $s_0 = -(-u)^a$. Check the expressions in the proof:

In [8]:
# Hyperboloid cost s0 = -(-u)**a and its first four u-derivatives.
hs_hyp = [-(-u)**a]
for order in range(1, 5):
    hs_hyp.append(sp.diff(hs_hyp[-1], u))
h0, h1, h2, h3, h4 = hs_hyp
hyp_subs = list(zip((s0, s1, s2, s3, s4), hs_hyp))

# Closed forms used in the proof; every displayed residual should be 0.
ex1a = a*s0**2*u**(-2)

ex2 = (s1**2-s0*s2)*s4+s0*s3**2 - 2*s1*s2*s3 + s2**3
ex2a = -2*a**2*(1-a)*s0**3*u**(-6)

ex3 = (s3*s1-s2**2)*(1-s0**2)+s1**2*(s1**2-s0*s2)
ex3a = a**2*u**(-4)*s0**2*((2*a-1)*s0**2+1-a)

for lhs, rhs in ((s1**2-s0*s2, ex1a), (ex2, ex2a), (ex3, ex3a)):
    display(sp.simplify((lhs - rhs).subs(hyp_subs)))

0

0

0

Squared Riemannian distance cost: $s_0 = \cos w$ with $u = w^2/2$.

In [9]:
u, w = sp.symbols('u w')

# Squared-distance cost: s0 = cos(w) with u = w**2/2, hence du/dw = w and
# d/du = (1/w) * d/dw.  Each h_k is obtained from h_{k-1} by that chain rule.
h0 = sp.cos(w)
dudw = w

hs_dist = [h0]
for order in range(1, 5):
    hs_dist.append(sp.diff(hs_dist[-1], w)/dudw)
h1, h2, h3, h4 = hs_dist[1:]


# Verifying the formulas
$R_1$ is hard for SymPy to simplify, so we supply the expected closed form and check that the difference vanishes. SymPy finds the expressions for $R_{23}$ and $R_4$ on its own.

In [10]:
dist_subs = [(s0, h0), (s1, h1), (s2, h2), (s3, h3), (s4, h4)]

# R1: compare against the closed form supplied by hand (residual should be 0).
R1s = sp.simplify(sp.factor(R1.subs(dist_subs)))
R1a = (4*w**2+ w*sp.sin(2*w) - 3 + 3*sp.cos(2*w))/(w**2*sp.sin(w)**2)
display(sp.simplify(R1s - R1a))

# R23 and R4: let SymPy produce the closed forms directly.
for coeff in (R23, R4):
    display(sp.factor(sp.simplify(coeff.subs(dist_subs))))

0

-2*(w*cos(w) - sin(w))/sin(w)**3

w*(2*w - sin(2*w))/(2*sin(w)**4)

# Check numerically


In [11]:
# Fetch the companion repository that provides regularMTW.src.*; skip the clone when a
# checkout already exists so re-running the notebook does not hit a git error here.
![ -d regularMTW ] || git clone https://github.com/dnguyend/regularMTW.git

Cloning into 'regularMTW'...
remote: Enumerating objects: 24, done.[K
remote: Counting objects: 100% (24/24), done.[K
remote: Compressing objects: 100% (20/20), done.[K
remote: Total 24 (delta 5), reused 14 (delta 3), pack-reused 0[K
Receiving objects: 100% (24/24), 36.43 KiB | 2.80 MiB/s, done.
Resolving deltas: 100% (5/5), done.


In [12]:
import jax
# Enable 64-bit floats first, before any JAX array is created.  The old spelling
# `from jax.config import config` is deprecated and no longer importable in newer
# JAX releases; `jax.config.update` is the supported form.
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp
import jax.numpy.linalg as jla
from jax.scipy.optimize import minimize
from jax import jvp, jacfwd, jacrev, random, grad
from regularMTW.src.tools import asym, sym2, sym, Lyapunov, vcat, LambertW
from regularMTW.src.space_form import SimpleKH, grand, splitzero
from regularMTW.src.simple_mtw import (GenHyperbolicSimpleMTW,
                                       LambertSimpleMTW)


def deriv_antenna(s, pp):
    """Invert the antenna profile s = p0*exp(p1*u) + p2 and return derivatives.

    pp is (p0, p1, p2).  Returns (u, h0, h1, h2, h3, h4), where u solves the
    equation above and h_k is the k-th u-derivative of p0*exp(p1*u) + p2 at u.
    """
    p0, p1, p2 = pp
    u = 1/p1*jnp.log((s-p2)/p0)
    growth = jnp.exp(p1*u)
    # Every derivative of order k >= 1 is p0*p1**k*exp(p1*u); the constant p2 drops out.
    higher = [p0*p1**k*growth for k in range(1, 5)]
    return (u, p0*growth + p2, *higher)


def deriv_hyperb(s, pp, branch=None):
    """Invert s = p0*exp(p1*u) + p2*exp(p3*u) with p3 = -p1 and return derivatives.

    With x = exp(p3*u) the equation becomes p2*x**2 - s*x + p0 = 0, so
    x = (s ± sqrt(s**2 - 4*p0*p2)) / (2*p2) and u = log(x)/p3.

    * If p0*p2 < 0 the discriminant is positive and the roots have product
      p0/p2 < 0, so exactly one root is positive.  The formula uses the sign(p2)
      branch, which is that root, and ``branch`` is ignored.
    * Otherwise (p0*p2 >= 0) a real solution needs s**2 >= 4*p0*p2 and there are two
      candidate roots.  ``branch`` (+1 or -1) multiplies sign(p2) in front of the
      square root to choose between them.

    Parameters
    ----------
    s : value of the profile to invert.
    pp : (p0, p1, p2).
    branch : +1 or -1, required when p0*p2 >= 0.

    Returns
    -------
    (u, h0, h1, h2, h3, h4) with h_k the k-th u-derivative of
    p0*exp(p1*u) + p2*exp(p3*u) at u.

    Raises
    ------
    ValueError if p0*p2 >= 0 and no branch is given.
    """
    p0, p1, p2 = pp
    p3 = - p1
    if p0*p2 < 0:
        root_sign = 1
    else:
        if branch is None:
            raise ValueError("branch (+1 or -1) is required when p0*p2 >= 0")
        root_sign = branch
    x = (s + root_sign*jnp.sign(p2)*jnp.sqrt(s**2-4*p2*p0))/(2*p2)
    u = 1/p3*jnp.log(x)

    e1 = jnp.exp(p1*u)
    e3 = jnp.exp(p3*u)
    hs = [p0*p1**k*e1 + p2*p3**k*e3 for k in range(5)]
    return (u, *hs)


def deriv_quad_hyperb(s, pp, branch=None):
    """Invert s = p0*exp(p1*u) + p2*exp(2*p1*u) and return derivatives.

    With x = exp(p1*u) > 0 the equation is p2*x**2 + p0*x - s = 0, with roots
    x = (-p0 ± sqrt(p0**2 + 4*p2*s)) / (2*p2) whose product is -s/p2.  This function
    requires s to have the same sign as p2, so that product is negative and exactly
    one root is positive:

        x = (sign(p2)*sqrt(p0**2 + 4*p2*s) - p0) / (2*p2).

    That root is always used.  ``branch`` is accepted only for backward
    compatibility and has no effect.  Previously, when p0*p2 >= 0, ``branch=-1``
    selected the negative root and the log produced NaN.

    Parameters
    ----------
    s : value of the profile to invert; must have the same sign as p2.
    pp : (p0, p1, p2); p3 = 2*p1 is derived.
    branch : ignored (kept so existing callers still work).

    Returns
    -------
    (u, h0, h1, h2, h3, h4) with h_k the k-th u-derivative of
    p0*exp(p1*u) + p2*exp(2*p1*u) at u.

    Raises
    ------
    ValueError if sign(s) != sign(p2).
    """
    p0, p1, p2 = pp
    if jnp.sign(p2) != jnp.sign(s):
        raise ValueError("s and p2 should have same sign")
    p3 = 2*p1
    # The unique positive root (see docstring); the other root is negative.
    x = (jnp.sign(p2)*jnp.sqrt(4*s*p2+p0**2)-p0)/(2*p2)
    u = 1/p1*jnp.log(x)

    e1 = jnp.exp(p1*u)
    e3 = jnp.exp(p3*u)
    hs = [p0*p1**k*e1 + p2*p3**k*e3 for k in range(5)]
    return (u, *hs)


def deriv_gh(s, GH):
    """Return (u, h0, h1, h2, h3, h4) using the helpers of a GenHyperbolicSimpleMTW-like object.

    u = GH.ufunc(s) (presumably the inverse of the cost profile; confirm in
    regularMTW), and GH.dsfunc(u, 4) must yield exactly five values: the profile
    and its first four derivatives at u.
    """
    u_val = GH.ufunc(s)
    d0, d1, d2, d3, d4 = GH.dsfunc(u_val, 4)
    return (u_val, d0, d1, d2, d3, d4)


def deriv_lambertA(s, LB):
    """Return (u, h0, h1, h2, h3, h4) using the helpers of a LambertSimpleMTW-like object.

    u = LB.ufunc(s) (presumably the inverse of the cost profile; confirm in
    regularMTW), and LB.dsfunc(u, 4) must yield exactly five values: the profile
    and its first four derivatives at u.
    """
    u_val = LB.ufunc(s)
    d0, d1, d2, d3, d4 = LB.dsfunc(u_val, 4)
    return (u_val, d0, d1, d2, d3, d4)


def deriv_lambertB(s, aa, branch=0):
    """Invert s = (a0 + a1*u)*exp(a2*u) with the Lambert W function and return derivatives.

    Substituting v = a0 + a1*u turns the equation into
    (a2*v/a1)*exp(a2*v/a1) = a2*s*exp(a0*a2/a1)/a1, hence
    u = W(a2*s*exp(a0*a2/a1)/a1)/a2 - a0/a1.

    Parameters
    ----------
    s : value of the profile to invert.
    aa : (a0, a1, a2) coefficients.
    branch : passed as the second argument of ``LambertW`` (presumably the branch
        index; the previous code hard-coded 0 there and ignored this argument).
        Defaults to 0, which reproduces the old behavior.

    Returns
    -------
    (u, h0, h1, h2, h3, h4) with h_k the k-th u-derivative of (a0 + a1*u)*exp(a2*u) at u.
    """
    a0, a1, a2 = aa
    u = 1/a2*LambertW(a2*s*jnp.exp(a0*a2/a1)/a1, branch)[0] - a0/a1
    ex = jnp.exp(a2*u)
    # If h = (c0 + c1*u)*ex then h' = a2*h + c1*ex, and the new slope is c1*a2.
    # Only the slope has to be tracked to produce successive derivatives.
    h = [(a0 + a1*u)*ex]
    slope = a1
    for _ in range(4):
        h.append(h[-1]*a2 + slope*ex)
        slope = slope*a2

    return u, h[0], h[1], h[2], h[3], h[4]


def deriv_power_al(s, eps, al, al0):
    """Invert s = eps*(-u)**al + al0 on u < 0 and return the first four derivatives.

    u = -|eps*(s - al0)|**(1/al); h0 equals s whenever eps*(s - al0) >= 0.
    Returns (u, h0, h1, h2, h3, h4) with h_k the k-th u-derivative of
    eps*(-u)**al + al0.  For al equal to 2, 3 or 4 the derivatives that are
    constant or identically zero are returned exactly instead of via (h0-al0)/u**k.
    """
    u = - jnp.abs(eps*s-eps*al0)**(1/al)

    h0 = eps*(-u)**al + al0
    lift = h0 - al0  # = eps*(-u)**al
    h1 = al*lift/u
    if al == 2:
        return u, h0, h1, al*(al-1)*eps, 0, 0

    h2 = (al-1)*al*lift/u**2
    if al == 3:
        return u, h0, h1, h2, - al*(al-1)*(al-2)*eps, 0.

    h3 = al*(al-1)*(al-2)*lift/u**3
    if al == 4:
        return u, h0, h1, h2, h3, al*(al-1)*(al-2)*(al-3)*eps

    return u, h0, h1, h2, h3, al*(al-1)*(al-2)*(al-3)*lift/u**4


def deriv_power(s, al, p0=1):
    """Invert s = |p0*u|**al - 1 with u = -|s+1|**(1/al)/p0 and return four derivatives.

    For p0 > 0 (so u <= 0) h_k is the k-th u-derivative of (-p0*u)**al - 1.
    Returns (u, h0, h1, h2, h3, h4).  For al equal to 2, 3 or 4 the derivatives
    that are constant or identically zero are returned exactly.
    """
    u = - 1/p0*jnp.abs(s+1)**(1/al)
    r = jnp.abs(p0*u)

    h0 = r**al - 1
    h1 = - al*p0*r**(al-1)
    if al == 2:
        return u, h0, h1, al*(al-1)*p0**2, 0, 0

    h2 = (al-1)*al*p0**2*r**(al-2)
    if al == 3:
        return u, h0, h1, h2, - al*(al-1)*(al-2)*p0**3, 0.

    h3 = - al*(al-1)*(al-2)*p0**3*r**(al-3)
    if al == 4:
        return u, h0, h1, h2, h3, al*(al-1)*(al-2)*(al-3)*p0**4

    return u, h0, h1, h2, h3, al*(al-1)*(al-2)*(al-3)*p0**4*r**(al-4)



  # 64-bit mode for JAX is configured once in the imports cell above; no import needed here.


## Check Gauss-Codazzi

In [13]:

def check_gauss_codazzi():
    """Cross-check SimpleKH (SKH) against GenHyperbolicSimpleMTW (GH) at one random point.

    Prints groups of numbers that are expected to agree (they do in the recorded
    output):

    1. KMMetric from GH and from SKH for random Omg1, Omg2;
    2. the jvp of KMMetric(., Omg2, Omg2) along Omg1 versus
       2*KMMetric(Omg2, GH.Gamma(Omg1, Omg2)), i.e. compatibility of Gamma with
       the metric;
    3. the same identity for the projected vectors Xi1, Xi2, evaluated with
       GH.Gamma, SKH.GammaAmbient and SKH.Gamma;
    4. <Xi1x, R(Xi1x, Xi1y)Xi1y> built from SKH.Gamma, then SKH.crossCurvSphere,
       then GH.crossCurv plus the second-fundamental-form term (Gauss equation).
    """
    n = 5

    # Fixed PRNG seed so the printed values are reproducible.
    key = random.PRNGKey(0)

    eps = 1
    nm = 1

    # need nm = 0 then eps != -1
    # nm = n then eps != 1
    A = jnp.diag(jnp.array(nm*[-1.] + (n-nm)*[1.]))

    # Random GH parameters with p0 < 0, p1 <= p3 <= 0, p2 > 0 (pex is sorted descending).
    ppt, key = grand(key, (4,))
    ppt = jnp.abs(ppt)
    pex = sorted([ppt[1], ppt[3]], reverse=True)
    pp = [-ppt[0], -pex[0], ppt[2], -pex[1]]

    GH = GenHyperbolicSimpleMTW(n, A, pp, branch=-1)
    # SKH uses GH's inverse and derivative functions as its cost profile.
    SKH = SimpleKH(n, A, eps, lambda s: deriv_gh(s, GH))
    qs, key = SKH.gen_qs(key)

    # NOTE(review): Omg3 and Omg4 are drawn but not used below (only in the
    # commented-out Xi3/Xi4 lines); they still advance `key`.
    Omg1, key = grand(key, (2*n,))
    Omg2, key = grand(key, (2*n,))
    Omg3, key = grand(key, (2*n,))
    Omg4, key = grand(key, (2*n,))

    # (1) metric computed by the two classes
    print(GH.KMMetric(qs, Omg1, Omg2))
    print(SKH.KMMetric(qs, Omg1, Omg2))

    # (2) d/dt <Omg2, Omg2> along Omg1 vs 2 <Omg2, Gamma(Omg1, Omg2)>
    print(jvp(lambda qq: GH.KMMetric(qq, Omg2, Omg2), (qs,), (Omg1,))[1])
    print(2*GH.KMMetric(qs, Omg2, GH.Gamma(qs, Omg1, Omg2)))

    Xi1 = SKH.projSphere(qs, Omg1)
    Xi2 = SKH.projSphere(qs, Omg2)
    # Xi3 = SKH.projSphere(qs, Omg3)
    # Xi4 = SKH.projSphere(qs, Omg4)

    # (3) same identity for projected vectors, with three versions of Gamma
    print(jvp(lambda qq: SKH.KMMetric(qq, Xi2, Xi2), (qs,), (Xi1,))[1])
    print(2*GH.KMMetric(qs, Xi2, GH.Gamma(qs, Xi1, Xi2)))
    print(2*SKH.KMMetric(qs, Xi2, SKH.GammaAmbient(qs, Xi1, Xi2)))
    print(2*SKH.KMMetric(qs, Xi2, SKH.Gamma(qs, Xi1, Xi2)))

    # presumably splits Xi1 into its two halves padded with zeros - TODO confirm in space_form
    Xi1x, Xi1y = splitzero(Xi1)

    def Curv3(self, q, Omg1, Omg2, Omg3):
        """Curvature R(Omg1, Omg2)Omg3 assembled from self.Gamma and its directional derivatives."""
        D1 = jvp(lambda q: self.Gamma(q, Omg2, Omg3), (q,), (Omg1,))[1]
        D2 = jvp(lambda q: self.Gamma(q, Omg1, Omg3), (q,), (Omg2,))[1]
        G1 = self.Gamma(q, Omg1, self.Gamma(q, Omg2, Omg3))
        G2 = self.Gamma(q, Omg2, self.Gamma(q, Omg1, Omg3))
        return D1 - D2 + G1 - G2

    # (4a) cross-curvature from Curv3 vs (4b) SKH.crossCurvSphere
    print(SKH.KMMetric(qs, Xi1x, Curv3(SKH, qs, Xi1x, Xi1y, Xi1y)))
    print(SKH.crossCurvSphere(qs, Xi1)[0])

    # second fundamental form
    def TwoSphere(self, qs, Xi1, Xi2):
        """Second fundamental form: derivative of the projection plus the normal part of GammaAmbient."""
        GA = self.GammaAmbient(qs, Xi1, Xi2)
        return self.DprojSphere(qs, Xi1, Xi2) + GA - self.projSphere(qs, GA)

    # Gauss Codazzi
    # (4c) ambient cross-curvature + <II(Xi1x, Xi1x), II(Xi1y, Xi1y)>; should match (4a), (4b)
    print(GH.crossCurv(qs, Xi1) + GH.KMMetric(qs,
                                              TwoSphere(SKH, qs, Xi1x, Xi1x),
                                              TwoSphere(SKH, qs, Xi1y, Xi1y)))

check_gauss_codazzi()


4.92723097515855
4.92723097515855
-35.59008257244199
-35.590082574318195
-59.203216613387646
-59.20321661516178
-59.20321661516178
-59.203216615150424
-48.203979785520254
-48.20397978577626
-48.2039797857704


# Check Sphere and Hyperboloid numerically

* In some instances the numerical inverse function does not converge fast enough, which produces NaN or a mismatch; we ignore those instances.

In [14]:

def testSphere(n=5, n_trials=100, seed=0):
    """Check numerically that the cross-curvature is positive on the sphere for power costs.

    Each trial draws an exponent al in [2, 10) and a scale p0 in [0.4, 2).  It builds
    SimpleKH with the profile deriv_power(., al, p0), draws a random point qs and a
    vector from gennull_sphere, and requires crossCurvSphere to return four positive
    numbers (curv, sR1, sR23, R4).  NaN values count as failures.

    On the first failure it prints "BAD" with the trial index and the four values,
    like the hyperboloid tests.  It always ends by printing "GOOD" or "BAD".

    Parameters
    ----------
    n : ambient dimension (default 5).
    n_trials : number of random trials (default 100).
    seed : PRNG seed (default 0).
    """
    key = random.PRNGKey(seed)

    eps = 1
    nm = 0  # no negative eigenvalues in A: the sphere

    # need nm = 0 then eps != -1
    # nm = n then eps != 1
    A = jnp.diag(jnp.array(nm*[-1.] + (n-nm)*[1.]))

    # Power case: random exponent al and scale p0 for deriv_power.
    bad = False
    for i in range(n_trials):
        sk, key = random.split(key)
        al, p0 = random.uniform(sk, (2,), minval=2, maxval=10)
        p0 = p0/5

        # Bind al and p0 now so the profile never sees a later iteration's values.
        skp = SimpleKH(n, A, eps, lambda q, al=al, p0=p0: deriv_power(q, al, p0))
        qs, key = skp.gen_qs(key)
        Xinull, key = skp.gennull_sphere(key, qs)
        curv, sR1, sR23, R4 = skp.crossCurvSphere(qs, Xinull)
        if not jnp.all(jnp.array([curv, sR1, sR23, R4]) > 0):
            print("BAD", i, curv, sR1, sR23, R4)
            bad = True
            break

    print("BAD" if bad else "GOOD")

testSphere()


GOOD


# Check GH

In [16]:

def test_hyperboloid_GH_Antenna(n=5, n_trials=1000, seed=0):
    """Check numerically that the cross-curvature is positive on the hyperboloid for antenna costs.

    Each trial draws p0 < 0, p1 < 0, p2 > 0 and uses deriv_antenna
    (s = p0*exp(p1*u) + p2) as the profile on the hyperboloid model
    (one negative eigenvalue in A, eps = -1).  It then requires crossCurvSphere at a
    random point and a gennull_sphere vector to return four positive numbers.
    NaN values count as failures.

    On the first failure it prints "BAD" with the trial index and the four values,
    like the other hyperboloid tests.  It always ends by printing "GOOD" or "BAD".

    Parameters
    ----------
    n : ambient dimension (default 5).
    n_trials : number of random trials (default 1000).
    seed : PRNG seed (default 0).
    """
    key = random.PRNGKey(seed)

    # nm is number of negative eigenvalues
    # nm = 0 is the sphere, nm = 1 is the hyperboloid model
    # constraint is - sum x_{negative} + sum x_pos = eps
    eps = -1
    nm = 1

    # need nm = 0 then eps != -1
    # nm = n then eps != 1
    A = jnp.diag(jnp.array(nm*[-1.] + (n-nm)*[1.]))

    # Antenna case: need p0 < 0, p1 < 0, p2 > 0.
    bad = False
    for i in range(n_trials):
        ppt, key = grand(key, (3,))
        ppt = jnp.abs(ppt)
        pp = [-ppt[0], -ppt[1], ppt[2]]
        # Bind pp now so the profile never sees a later iteration's parameters.
        ska = SimpleKH(n, A, eps, lambda q, pp=pp: deriv_antenna(q, pp))
        qs, key = ska.gen_qs(key)
        Xinull, key = ska.gennull_sphere(key, qs)
        curv, sR1, sR23, R4 = ska.crossCurvSphere(qs, Xinull)
        if not jnp.all(jnp.array([curv, sR1, sR23, R4]) > 0):
            print("BAD", i, curv, sR1, sR23, R4)
            bad = True
            break

    print("BAD" if bad else "GOOD")

test_hyperboloid_GH_Antenna()

GOOD


In [17]:
def test_hyperboloid_GH_1():
    """Positivity check of the cross-curvature on the hyperboloid for SINH-type GH costs.

    Runs 1000 trials.  Parameters are drawn with p0 < 0, p1 < 0, p2 > 0, p3 >= 0 and
    p1 + p3 <= 0, then passed to GenHyperbolicSimpleMTW.  Its inverse and derivative
    helpers (via deriv_gh) define the SimpleKH model on the hyperboloid
    (one negative eigenvalue in A, eps = -1).  crossCurvSphere at a random point and a
    gennull_sphere vector must return four positive numbers.

    On the first failure it prints the values and then "BAD"; otherwise it prints
    "GOOD".
    """
    dim = 5
    rng_key = random.PRNGKey(0)

    # One negative eigenvalue with eps = -1 gives the hyperboloid model
    # (no negative eigenvalues would be the sphere).
    # need nm = 0 then eps != -1; nm = n then eps != 1
    eps = -1
    n_neg = 1
    A = jnp.diag(jnp.array([-1.]*n_neg + [1.]*(dim - n_neg)))

    for trial in range(1000):
        raw, rng_key = grand(rng_key, (4,))
        raw = jnp.abs(raw)
        big, small = sorted([raw[1], raw[3]], reverse=True)
        GH = GenHyperbolicSimpleMTW(dim, A, [-raw[0], -big, raw[2], small])

        model = SimpleKH(dim, A, eps, lambda s: deriv_gh(s, GH))
        qs, rng_key = model.gen_qs(rng_key)
        Xinull, rng_key = model.gennull_sphere(rng_key, qs)
        curv, sR1, sR23, R4 = model.crossCurvSphere(qs, Xinull)
        if not jnp.all(jnp.array([curv, sR1, sR23, R4]) > 0):
            print("BAD", curv, sR1, sR23, R4)
            print("BAD")
            return
    print("GOOD")
test_hyperboloid_GH_1()

GOOD


In [21]:
def test_hyperboloid_GH_2():
    """Positivity check of the cross-curvature on the hyperboloid for the GH branch case.

    Runs 1000 trials.  Parameters are drawn with p0 < 0, p1 < 0, p2 > 0 and
    p1 <= p3 <= 0 (for example [-1, -1, 2, -.5]), and GenHyperbolicSimpleMTW is built
    with branch=-1.  Trials with GH.rng[1] > 1e8 are skipped (presumably an
    impractically large range upper bound; confirm in simple_mtw).

    A NaN curvature is reported with "NAN" and ignored.  Any other non-positive value
    prints "BAD" with the trial index and values, then "BAD" again, and returns.
    Otherwise "GOOD" is printed.
    """
    dim = 5
    rng_key = random.PRNGKey(0)

    # One negative eigenvalue with eps = -1 gives the hyperboloid model.
    # need nm = 0 then eps != -1; nm = n then eps != 1
    eps = -1
    n_neg = 1
    A = jnp.diag(jnp.array([-1.]*n_neg + [1.]*(dim - n_neg)))

    for trial in range(1000):
        raw, rng_key = grand(rng_key, (4,))
        raw = jnp.abs(raw)
        big, small = sorted([raw[1], raw[3]], reverse=True)

        GH = GenHyperbolicSimpleMTW(dim, A, [-raw[0], -big, raw[2], -small], branch=-1)
        if GH.rng[1] > 1e8:
            continue
        model = SimpleKH(dim, A, eps, lambda s: deriv_gh(s, GH))
        qs, rng_key = model.gen_qs(rng_key)
        Xinull, rng_key = model.gennull_sphere(rng_key, qs)
        curv, sR1, sR23, R4 = model.crossCurvSphere(qs, Xinull)
        if jnp.isnan(curv):
            print("NAN", trial, curv, sR1, sR23, R4)
            continue
        if not jnp.all(jnp.array([curv, sR1, sR23, R4]) > 0):
            print("BAD", trial, curv, sR1, sR23, R4)
            print("BAD")
            return
    print("GOOD")

test_hyperboloid_GH_2()

NOTFOUND 19 1.5349633830652465e-07
NOTFOUND 19 1.5349633830652465e-07
NOTFOUND 19 2.0665719802082094e-07
NOTFOUND 19 2.0665719802082094e-07
GOOD


In [20]:
def test_hyperboloid_Lambert():
    """Positivity check of the cross-curvature on the hyperboloid for Lambert-type costs.

    Runs 1000 trials with coefficients (a0, a1, a2), where a0 is unconstrained,
    a1 > 0 and a2 < 0 (for example [-1, 1, -2]).  The profile is
    deriv_lambertB(., aa, 0), i.e. s = (a0 + a1*u)*exp(a2*u).

    A NaN curvature is reported with "NAN, will ignore" and skipped.  Any other
    non-positive value prints "BAD" with the trial index and values, then "BAD"
    again, and returns.  Otherwise "GOOD" is printed.
    """
    dim = 5
    rng_key = random.PRNGKey(0)

    # One negative eigenvalue with eps = -1 gives the hyperboloid model.
    # need nm = 0 then eps != -1; nm = n then eps != 1
    eps = -1
    n_neg = 1
    A = jnp.diag(jnp.array([-1.]*n_neg + [1.]*(dim - n_neg)))

    for trial in range(1000):
        raw, rng_key = grand(rng_key, (3,))
        aa = [raw[0], jnp.abs(raw[1]), - jnp.abs(raw[2])]
        model = SimpleKH(dim, A, eps, lambda s: deriv_lambertB(s, aa, 0))
        qs, rng_key = model.gen_qs(rng_key)
        Xinull, rng_key = model.gennull_sphere(rng_key, qs)
        curv, sR1, sR23, R4 = model.crossCurvSphere(qs, Xinull)
        if jnp.isnan(curv):
            print("NAN, will ignore", trial, curv, sR1, sR23, R4)
            continue
        if not jnp.all(jnp.array([curv, sR1, sR23, R4]) > 0):
            print("BAD", trial, curv, sR1, sR23, R4)
            print("BAD")
            return
    print("GOOD")
test_hyperboloid_Lambert()

NAN, will ignore 343 nan nan nan nan
GOOD
