<a href="https://colab.research.google.com/github/dnguyend/EmbeddedGeometry/blob/main/colab/KMCTests.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Overview
We use [JAX](https://jax.readthedocs.io/en/latest/index.html) for numerical differentiation; in particular, [jvp](https://jax.readthedocs.io/en/latest/_autosummary/jax.jvp.html) computes directional derivatives.



To find the directional derivative of a function $c(x, y)$ in the variable $x$ in the direction $\xi$, we only need to write

`jvp(lambda x: c(x, y), (x,), (xi,))`

where **lambda x: c(x, y)** creates a function of $x$ in which $y$ is held fixed as a parameter. The call returns the pair (value, directional derivative). JVP stands for "Jacobian-vector product".
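As a minimal sketch of the call above, the following uses a toy cost $c(x, y) = \sum_i x_i^2 y_i$ (chosen only for illustration, it is not a cost used later in this notebook) and compares the `jvp` result with the hand-computed derivative $\sum_i 2 x_i y_i \xi_i$.

```python
import jax.numpy as jnp
from jax import jvp

def c(x, y):
    # toy cost for illustration only: c(x, y) = sum_i x_i^2 * y_i
    return jnp.sum(x * x * y)

x = jnp.array([1.0, 2.0, 3.0])
y = jnp.array([0.5, -1.0, 2.0])
xi = jnp.array([1.0, 0.0, -1.0])

# jvp returns (c(x, y), derivative of c in x along xi); y is frozen by the lambda
value, dc = jvp(lambda x: c(x, y), (x,), (xi,))

# hand-computed directional derivative: sum_i 2 * x_i * y_i * xi_i
expected = jnp.sum(2 * x * y * xi)
print(value, dc, expected)
```

The primal point and the tangent are passed as one-element tuples because `jvp` accepts one entry per positional argument of the function.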

In [1]:
!git clone https://github.com/dnguyend/EmbeddedGeometry.git

Cloning into 'EmbeddedGeometry'...
remote: Enumerating objects: 12, done.
remote: Counting objects: 100% (12/12), done.
remote: Compressing objects: 100% (9/9), done.
remote: Total 12 (delta 2), reused 5 (delta 1), pack-reused 0
Unpacking objects: 100% (12/12), 17.36 KiB | 4.34 MiB/s, done.


In [2]:
import jax.numpy as jnp
from jax import jvp, jacfwd, random
from jax import config

from IPython.display import display

from EmbeddedGeometry.src.kmc import (sym2, Teval, KMC, KMCDivergence, KMCSphere, KMCAntenna)


config.update("jax_enable_x64", True)


Import the kmc module, which contains the main code. We can view the code by running `!cat EmbeddedGeometry/src/kmc.py`.

In [3]:
import EmbeddedGeometry.src.kmc as kmc

The function `kmc.KMC.kmc` is the inner product. We show the source code of the inner product, Gamma, Curvature, and the conjugate of Gamma (for the conjugate connection in the Euclidean ambient space).

In [4]:
import inspect
print(inspect.getsource(kmc.KMC.kmc))
print(inspect.getsource(kmc.KMC.Gamma))
print(inspect.getsource(kmc.KMC.Curvature))
print(inspect.getsource(kmc.KMC.conjGamma))

    def kmc(self, q, Omg1, Omg2):
        """The metric. return a number
        """
        n = self.n        
        x, y = self.split(q)
        C = self.bddc(x, y, self.params)
        return -0.5*jnp.sum(Omg1[:n]*(C@Omg2[n:])) \
            - 0.5*jnp.sum(Omg2[:n]*(C@Omg1[n:]))

    def Gamma(self, q, Omg1, Omg2):
        n = self.n
        x, y = self.split(q)                
        
        C, Cx = jvp(lambda x: self.bddc(x, y, self.params), (x,), (Omg1[:n],))
        Cy = jvp(lambda y: self.bddc(x, y, self.params), (y,), (Omg1[n:],))[1]
        return jnp.concatenate([jla.solve(C.T, Cx.T@Omg2[:n]),
                                jla.solve(C, Cy@Omg2[n:])])

    def Curvature(self, q, Omg1, Omg2, Omg3):
        n = self.n
        x, y = self.split(q)                
        
        DG1 = jvp(lambda q: self.Gamma(q, Omg2, Omg3), (q, ), (Omg1, ))[1]
    
        DG2 = jvp(lambda q: self.Gamma(q, Omg1, Omg3), (q, ), (Omg2, ))[1]
        GG1 = self.Gamma(q, Omg1, self.Gamma(q, Om
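To make the `kmc` method above concrete, here is a self-contained sketch of the same pairing. It assumes that `bddc` returns the mixed second derivative $C_{ij} = \partial^2 c/\partial x_i \partial y_j$ (an assumption, since `bddc` is not shown here); the helper names `mixed_derivative` and `km_metric` and the toy cost are hypothetical and not part of `kmc.py`.

```python
import jax.numpy as jnp
from jax import grad, jacfwd

def c(x, y):
    # toy cost (half squared distance); its mixed derivative is -I
    return 0.5 * jnp.sum((x - y) ** 2)

def mixed_derivative(c, x, y):
    # C[i, j] = d^2 c / dx_i dy_j: gradient in x, then Jacobian in y
    return jacfwd(lambda y_: grad(lambda x_: c(x_, y_))(x))(y)

def km_metric(c, q, Omg1, Omg2):
    # same formula as KMC.kmc above, with C computed by autodiff
    n = q.shape[0] // 2
    x, y = q[:n], q[n:]
    C = mixed_derivative(c, x, y)
    return -0.5 * jnp.sum(Omg1[:n] * (C @ Omg2[n:])) \
        - 0.5 * jnp.sum(Omg2[:n] * (C @ Omg1[n:]))

q = jnp.array([0.1, 0.2, 0.3, 0.4])
Omg1 = jnp.array([1.0, 2.0, 3.0, 4.0])
Omg2 = jnp.array([0.0, 1.0, 1.0, 0.0])

g12 = km_metric(c, q, Omg1, Omg2)
g21 = km_metric(c, q, Omg2, Omg1)
print(g12, g21)
```

For this toy cost the pairing reduces to half the sum of the two cross inner products between the $x$ block of one vector and the $y$ block of the other, and it is symmetric in its two vector arguments.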

First, we test the KMC metric on two open subsets. We generate the function $c$ as a polynomial of degree up to 4 in the vector variable $q = (x, y)$.

In [5]:
n = 3

key = random.PRNGKey(0)    
key, sk = random.split(key)

Cvec = random.normal(sk, (2*n,))
key, sk = random.split(key)

Cmat = random.normal(sk, (2*n, 2*n))

key, sk = random.split(key)    
Ctensor = random.normal(sk, (2*n, 2*n, 2*n))
key, sk = random.split(key)        
Ctensor4 = random.normal(sk, (2*n, 2*n, 2*n, 2*n))

# making a complicated c

def c(x, y, params):
    q = jnp.block([x, y])
    Cvec, Cmat, Ctensor, Ctensor4 = params
    return Teval(Cvec, [q]) + Teval(Cmat, [q, q]) + Teval(Ctensor, [q, q, q]) \
        + Teval(Ctensor4, [q, q, q, q])
    




Generate a random vector then test metric compatibility. 
Let $q$ be the manifold variable.
Take the vector field $Y = Aq +B$ for a matrix $A$ of size $2n\times 2n$, $B$ a vector of size $2n$.
We check

$D_{\omega} \langle Y, Y\rangle_{KM} = 2\langle Y, \nabla_{\omega}Y\rangle_{KM}$
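Spelled out for this particular vector field, and assuming the convention $\nabla_{\omega} Y = D_{\omega}Y + \Gamma(q; \omega, Y)$ that the check in the next cell appears to use: since $Y(q) = Aq + B$ is affine, $D_{\omega}Y = A\omega$, so

$$\nabla_{\omega} Y = A\omega + \Gamma(q; \omega, Aq + B),$$

and the identity being tested reads

$$D_{\omega} \langle Aq+B, Aq+B\rangle_{KM} = 2\langle Aq+B,\ A\omega + \Gamma(q; \omega, Aq+B)\rangle_{KM}.$$

The left-hand side is evaluated with `jvp` and the right-hand side with `KM.kmc` and `KM.Gamma`.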

In [6]:
key, sk = random.split(key)
q = random.normal(sk, (2*n, ))

KM = KMC(n, c, [Cvec, Cmat, Ctensor, Ctensor4])


key, sk = random.split(key)   
A = random.normal(sk, (2*n, 2*n))

key, sk = random.split(key)   
B = random.normal(sk, (2*n, ))

key, sk = random.split(key)   
Omg =  random.normal(sk, (2*n,))

# a random point on the manifold
key, sk = random.split(key)
q = random.normal(sk, (2*n, ))

# metric compatibility: the two displayed numbers should agree
display(jvp(lambda q: KM.kmc(q, A@q+B, A@q+B), (q,), (Omg,))[1])
display(2*KM.kmc(q, A@q+B, A@Omg + KM.Gamma(q, Omg, A@q+B)))




Array(-15.13011365, dtype=float64)

Array(-15.13011365, dtype=float64)

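Curvature: we check the symmetries of the curvature tensor with four arguments, namely antisymmetry in the first pair, antisymmetry in the last pair, and symmetry under exchange of the two pairs.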

In [7]:
key, sk = random.split(key)    
Omg1 = random.normal(sk, (2*n,))

key, sk = random.split(key)   
Omg2 = random.normal(sk, (2*n,))

key, sk = random.split(key)   
Omg3 = random.normal(sk, (2*n,))

key, sk = random.split(key)   
Omg4 = random.normal(sk, (2*n,))


KM = KMC(n, c, [Cvec, Cmat, Ctensor, Ctensor4])

key, sk = random.split(key)    
Omg1 = random.normal(sk, (2*n,))

key, sk = random.split(key)   
Omg2 = random.normal(sk, (2*n,))

key, sk = random.split(key)   
Omg3 = random.normal(sk, (2*n,))

key, sk = random.split(key)   
Omg4 = random.normal(sk, (2*n,))
    
print(KM.Curv4(q, Omg1, Omg2, Omg3, Omg4))
print(KM.Curv4(q, Omg2, Omg1, Omg3, Omg4))
print(KM.Curv4(q, Omg1, Omg2, Omg4, Omg3))
print(KM.Curv4(q, Omg3, Omg4, Omg1, Omg2))    


248.038939041822
-248.03893904182198
-248.0389390418221
248.03893904182183
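The four printed numbers can also be reduced to residuals that should vanish. The sketch below does this for a stand-in tensor, the constant-curvature model $R(a,b,c,d) = \langle a,d\rangle\langle b,c\rangle - \langle a,c\rangle\langle b,d\rangle$, because `KM.Curv4` is not available outside the notebook; `curv4_const` and `symmetry_residuals` are hypothetical helpers, and any function with the same four-vector signature could be passed instead.

```python
import jax.numpy as jnp
from jax import random

def curv4_const(a, b, c, d):
    # constant-curvature model tensor, used here only as a stand-in for Curv4
    return jnp.dot(a, d) * jnp.dot(b, c) - jnp.dot(a, c) * jnp.dot(b, d)

def symmetry_residuals(curv4, v1, v2, v3, v4):
    r = curv4(v1, v2, v3, v4)
    return jnp.array([
        r + curv4(v2, v1, v3, v4),   # antisymmetry in the first pair
        r + curv4(v1, v2, v4, v3),   # antisymmetry in the last pair
        r - curv4(v3, v4, v1, v2),   # symmetry under exchange of the pairs
    ])

keys = random.split(random.PRNGKey(0), 4)
v1, v2, v3, v4 = [random.normal(k, (6,)) for k in keys]

res = symmetry_residuals(curv4_const, v1, v2, v3, v4)
print(res)
```

With the notebook objects, the analogous call would wrap `KM.Curv4` in a lambda that fixes `q`.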


# The Generalized Reflector Antenna

We test by two methods: using the geometry of the sphere with the general formulas, and using the simplified formulas for the reflector antenna. We first generate two objects:

In [8]:
n = 5
key = random.PRNGKey(0)    
key, sk = random.split(key)
Lbd = sym2(random.normal(sk, (n, n)))
Ant = KMCAntenna(Lbd)

def c(x, y, Lambda):
    xmy = (1-jnp.sum(x*(Lambda@y)))        
    return -0.5*jnp.log(xmy)

KMA = KMC(n, c, Lbd)
KMS = KMCSphere(KMA)



We compare the code of the two objects, for example for Gamma: KMS uses automatic differentiation, while Ant uses the simplification in the paper.

In [9]:

print(inspect.getsource(KMS.Gamma))
print(inspect.getsource(Ant.Gamma))

    def Gamma(self, q, Xi1, Xi2):
        n = self.n
        x, y = self.split(q)
        
        C, Cx = jvp(lambda x: self.bddc(x, y, self.params), (x,), (Xi1[:n],))
        Cy = jvp(lambda y: self.bddc(x, y, self.params), (y,), (Xi1[n:],))[1]
        iC = jla.inv(C)
        yiCx = jnp.sum(y*(iC@x))
        return jnp.concatenate(
            [iC.T@Cx.T@Xi2[:n] - iC.T@y*(
                jnp.sum(x*(iC.T@Cx.T@Xi2[:n])) - jnp.sum(Xi1[:n]*Xi2[:n])) / yiCx,
             iC@Cy@Xi2[n:] - iC@x*(jnp.sum(y*(iC@Cy@Xi2[n:])) - jnp.sum(Xi1[n:]*Xi2[n:])) / yiCx
             ])

    def Gamma(self, q, Xi, Eta):
        x, y = self.split(q)
        n = self.n
        iLbd = self.iLambda
        Lbd = self.Lambda
        
        xmy = (1-jnp.sum(x*(Lbd@y)))
        xILbdy = (1-jnp.sum(q[:n]*(iLbd@q[n:])))

        return jnp.block(
            [1/(xmy)*jnp.sum((Lbd@y)*Eta[:n])*Xi[:n]
             + 1/(xmy)*jnp.sum((Lbd@y)*Xi[:n])*Eta[:n]
             - (iLbd@q[n:] - q[:n])*jnp.sum(Xi[:n]*Eta[:n])/

## Check metric compatibility
Check that the projection and Gamma are metric compatible. Note that Xi is a tangent vector, so $q[:n]^TXi[:n]=0$ and $q[n:]^TXi[n:]=0$.

* For the vector field, we use $z\mapsto \Pi(z)\eta$
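The pattern of this check (extend $\eta$ to a vector field $z\mapsto \Pi(z)\eta$, differentiate with `jvp`, add Gamma) can be sketched on a much simpler example. The block below assumes the unit sphere with the Euclidean metric, where the tangent projection is $\Pi(z)\omega = \omega - z\,(z^T\omega)$ and the Christoffel function on tangent vectors is $\Gamma(z;\xi,\eta) = z\,(\xi^T\eta)$. This is not the Kim-McCann geometry used by `Ant` and `KMS`; it only illustrates the structure of the test.

```python
import jax.numpy as jnp
from jax import jvp, random

def proj(z, omg):
    # Euclidean tangent projection on the unit sphere (illustrative stand-in)
    return omg - z * jnp.sum(z * omg)

def Gamma(z, xi, eta):
    # Christoffel function of the embedded unit sphere, for tangent xi and eta
    return z * jnp.sum(xi * eta)

k1, k2, k3 = random.split(random.PRNGKey(0), 3)
x = random.normal(k1, (4,))
x = x / jnp.linalg.norm(x)              # a point on the sphere
xi = proj(x, random.normal(k2, (4,)))   # a tangent vector at x
eta = random.normal(k3, (4,))           # an ambient vector, extended by z -> proj(z, eta)

tangent_check = jnp.sum(x * xi)

# left side: derivative of <Y, Y> along xi, with Y(z) = proj(z, eta)
lhs = jvp(lambda z: jnp.sum(proj(z, eta) ** 2), (x,), (xi,))[1]

# right side: 2 <Y, D_xi Y + Gamma(xi, Y)>
Y, DY = jvp(lambda z: proj(z, eta), (x,), (xi,))
rhs = 2 * jnp.sum(Y * (DY + Gamma(x, xi, Y)))
print(tangent_check, lhs, rhs)
```

Here the two sides agree up to floating-point error, which is the same kind of agreement the notebook cell displays for the antenna metric.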

In [10]:
key, sk = random.split(key)    
q = Ant.randQpoint(sk)

key, sk = random.split(key)    
Xi = Ant.randvec(sk, q)

key, sk = random.split(key)    
Omg = random.normal(sk, (2*n,))

print("CHECK TANGENT")
print(jnp.sum(q[:n]*Xi[:n]), jnp.sum(q[n:]*Xi[n:]))

print("CHECK PROJECTION")
print(KMS.proj(q, Omg) - Ant.proj(q, Omg))
print(Ant.kmc(q, Ant.proj(q, Omg), Xi)
      - Ant.kmc(q, Omg, Xi))    


key, sk = random.split(key)    
Eta = Ant.randvec(sk, q)


# check KMS is same as Ant
print("KMS = Ant")
display(KMS.Gamma(q, Xi, Eta) - Ant.Gamma(q, Xi, Eta))
print("CHECK CONNECTION IS METRIC COMPATIBLE")
# Metric compatible
display(jvp(lambda z: Ant.kmc(z, Ant.proj(z, Eta), Ant.proj(z, Eta)), (q,), (Xi,))[1])
display(2*Ant.kmc(q, Eta, Ant.Gamma(q, Xi, Eta)))


CHECK TANGENT
1.1102230246251565e-16 2.7755575615628914e-16
CHECK PROJECTION
[-7.10542736e-15 -1.77635684e-15 -5.55111512e-15  3.19744231e-14
 -1.55431223e-15 -1.11022302e-16  5.55111512e-17 -1.11022302e-16
 -2.22044605e-16  0.00000000e+00]
0.0
KMS = Ant


Array([-2.57571742e-14, -8.21565038e-15, -2.13162821e-14,  1.35003120e-13,
       -7.10542736e-15, -4.32986980e-15,  2.10942375e-15, -4.88498131e-15,
       -8.43769499e-15,  1.83186799e-15], dtype=float64)

CHECK CONNECTION IS METRIC COMPATIBLE


Array(-0.52433215, dtype=float64)

Array(-0.52433215, dtype=float64)

## Check curvature
For the curvature, we generate four tangent vectors.

In [11]:
key, sk = random.split(key)    
Xi1 = Ant.randvec(sk, q)

key, sk = random.split(key)   
Xi2 = Ant.randvec(sk, q)

key, sk = random.split(key)   
Xi3 = Ant.randvec(sk, q)

key, sk = random.split(key)   
Xi4 = Ant.randvec(sk, q)
print("ANT and KMS are compatible")
print(Ant.Curv4(q, Xi1, Xi2, Xi3, Xi4) - KMS.Curv4(q, Xi1, Xi2, Xi3, Xi4))

ANT and KMS are compatible
-8.881784197001252e-14


### BIANCHI 
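We check the [Bianchi identities](https://en.wikipedia.org/wiki/Riemann_curvature_tensor#Symmetries_and_identities). The cell below first prints the curvature symmetries (antisymmetry in each pair of arguments and symmetry under exchange of the pairs) under the label FIRST BIANCHI, then evaluates the cyclic sum of the covariant derivative of the curvature, which should vanish by the second Bianchi identity.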
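For reference, the identities can be written as follows, assuming that `Curvature(q, Xi2, Xi3, Xi4)` stands for $R(\xi_2,\xi_3)\xi_4$ (the argument order is our reading of the code, not something stated in it). The first identity is the cyclic sum

$$R(\xi_1,\xi_2)\xi_3 + R(\xi_2,\xi_3)\xi_1 + R(\xi_3,\xi_1)\xi_2 = 0,$$

and the second is

$$(\nabla_{\xi_1}R)(\xi_2,\xi_3)\xi_4 + (\nabla_{\xi_2}R)(\xi_3,\xi_1)\xi_4 + (\nabla_{\xi_3}R)(\xi_1,\xi_2)\xi_4 = 0,$$

where

$$(\nabla_{\xi_1}R)(\xi_2,\xi_3)\xi_4 = \nabla_{\xi_1}\big(R(\xi_2,\xi_3)\xi_4\big) - R(\nabla_{\xi_1}\xi_2,\xi_3)\xi_4 - R(\xi_2,\nabla_{\xi_1}\xi_3)\xi_4 - R(\xi_2,\xi_3)\nabla_{\xi_1}\xi_4.$$

In the function `NablaR` in the next cell, the first term on the right corresponds to `Tot` and the three subtracted terms to `R12`, `R13` and `R14`.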

In [12]:
print("FIRST BIANCHI")
print(Ant.Curv4(q, Xi1, Xi2, Xi3, Xi4))
print(Ant.Curv4(q, Xi2, Xi1, Xi3, Xi4))
print(Ant.Curv4(q, Xi1, Xi2, Xi4, Xi3))
print(Ant.Curv4(q, Xi3, Xi4, Xi1, Xi2))

def NablaR(q, Xi1, Xi2, Xi3, Xi4):
    def xifunc(q, Xi):
        return Ant.proj(q, Xi)

    def cov(q, Xi1, Xi2):
        return jvp(lambda x: Ant.proj(x, Xi2), (q,), (Xi1,))[1] +\
            Ant.Gamma(q, Xi1, Xi2)

    Tot = jvp(lambda q: Ant.Curvature(q,
                                      xifunc(q, Xi2),
                                      xifunc(q, Xi3),
                                      xifunc(q, Xi4)),
              (q,), (Xi1,))[1] \
        + Ant.Gamma(q, Xi1, Ant.Curvature(q, Xi2, Xi3, Xi4))
    R12 = Ant.Curvature(q, cov(q, Xi1, Xi2), Xi3, Xi4)
    R13 = Ant.Curvature(q, Xi2, cov(q, Xi1, Xi3), Xi4)
    R14 = Ant.Curvature(q, Xi2, Xi3, cov(q, Xi1, Xi4))
    return Tot - R12 - R13 - R14
print("SECOND BIANCHI")
f1 = NablaR(q, Xi1, Xi2, Xi3, Xi4)
f2 = NablaR(q, Xi2, Xi3, Xi1, Xi4)
f3 = NablaR(q, Xi3, Xi1, Xi2, Xi4)
print(f1 + f2 + f3)



FIRST BIANCHI
-28.821347577031478
28.821347577031478
28.821347577031382
-28.82134757703144
SECOND BIANCHI
[ 1.27329258e-11  4.09272616e-12  1.09139364e-11 -1.52795110e-10
  1.13686838e-12 -2.54658516e-11  7.27595761e-12 -1.81898940e-11
 -7.27595761e-12  1.09139364e-11]


### The Cross curvature formula:

In [13]:
print(inspect.getsource(Ant.CrossSec))
s1 = Ant.Two(q, Xi1, Xi2)
s2 = Ant.Gamma(q, Xi1, Xi2)
s3 = KMA.Gamma(q, Xi1, Xi2)
print("GAUSS CODAZZI for Cross curvature")
print(s1 + s2 - s3)

print("verifying the simplified formula")
Xiz = jnp.block([Xi1[:n], jnp.zeros(n)])
bXiz = jnp.block([jnp.zeros(n), Xi1[n:]])
s1 = Ant.Curv4(q, Xiz, bXiz, bXiz, Xiz)
s2 = Ant.CrossSec(q, Xi1[:n], Xi1[n:])
print(s1 - s2)


    def CrossSec(self, q, xi, bxi):
        x, y = self.split(q)
        n = self.n
        iLbd = self.iLambda
        Lbd = self.Lambda
        
        x, y = q[:n], q[n:]        
        xmy = (1-jnp.sum(x*(Lbd@y)))
        xILbdy = (1-jnp.sum(x*(iLbd@y)))
        zr = jnp.zeros(n)

        return - 8*jnp.square(self.kmc(q, jnp.block([xi, zr]), jnp.block([zr, bxi]))) \
            + 1/(4*xmy*xILbdy)*jnp.sum(xi*xi)*jnp.sum(bxi*bxi)

GAUSS CODAZZI for Cross curvature
[ 3.60822483e-15 -4.44089210e-16  4.10782519e-15  1.01569415e-14
  2.66453526e-15 -4.44089210e-16  2.22044605e-16 -2.27595720e-15
 -2.22044605e-15  8.88178420e-16]
verifying the simplified formula
5.062616992290714e-14
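Read off from the source of `CrossSec` printed above, and assuming that `iLambda` denotes $\Lambda^{-1}$, the simplified formula being verified is

$$\mathrm{CrossSec}(q;\xi,\bar{\xi}) = -8\,\big\langle (\xi, 0), (0,\bar{\xi})\big\rangle_{KM}^2 + \frac{|\xi|^2\,|\bar{\xi}|^2}{4\,(1 - x^T\Lambda y)\,(1 - x^T\Lambda^{-1} y)},$$

with $q = (x, y)$. The cell compares it with the general curvature evaluated as `Curv4(q, Xiz, bXiz, bXiz, Xiz)`, where `Xiz` is $(\xi, 0)$ and `bXiz` is $(0, \bar{\xi})$.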


$\newcommand{\bomg}{\bar{\omega}}$
# THE DIVERGENCE GRAPH and Kim-McCann metric
This is not in the paper, but we illustrate the method based on the references [1] and [2]. We thank the authors of [2] for explaining their formulas to us.


Consider the function
$$c(x, y) = \phi(F(x)) - \phi(y) -\frac{1}{\beta}\log\left(1+\beta\, \mathrm{grad}_{\phi}(y)\cdot (F(x)-y)\right)
$$
which vanishes at $y=F(x)$ and has the form of what is called the $L^{(\beta)}$-divergence [1]. We use this form so that the curvature is not zero.
* We choose $F$ to be of the following form (entry-wise) to get nonzero higher derivatives.
$$F(x) =\frac{kx}{a+\sqrt{b+x^2}}$$
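The following sketch checks two properties stated above: $c$ vanishes on the graph $y = F(x)$, and $F$ acts entry-wise, so its Jacobian is diagonal. It assumes $\phi(p) = \frac{1}{\beta}\log(\beta |p|^2 + 1)$ and the constants $a, b, k, \beta$ used in the code cell of this section; the gradient of $\phi$ is taken by automatic differentiation here rather than by a closed form.

```python
import jax.numpy as jnp
from jax import grad, jacfwd

def F(x, a=0.2, b=0.3, k=0.5):
    # entry-wise map F(x) = k x / (a + sqrt(b + x^2))
    return k * x / (a + jnp.sqrt(b + x * x))

def phi(p, beta):
    # assumed potential: (1/beta) * log(beta * |p|^2 + 1)
    return jnp.log(beta * jnp.sum(p * p) + 1) / beta

def c(x, y, beta=0.1):
    g = grad(phi)(y, beta)   # gradient of phi at y
    return phi(F(x), beta) - phi(y, beta) \
        - jnp.log(1 + beta * jnp.sum(g * (F(x) - y))) / beta

x = jnp.array([0.3, -1.2, 2.0])

c_on_graph = c(x, F(x))                      # should be zero
JF = jacfwd(F)(x)                            # Jacobian of F at x
off_diagonal = JF - jnp.diag(jnp.diag(JF))   # should be the zero matrix
print(c_on_graph, jnp.diag(JF))
```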

We consider a family of projections
$$\Pi_{\alpha}(\omega, \bomg) = \frac{1}{2}((1+\alpha)\omega + (1-\alpha)F_x^{-1}\bomg, (1+\alpha)F_x\omega + (1-\alpha)\bomg)
$$

which generate a family of connections:
$$\begin{aligned}\Gamma_{\alpha}(x, (\xi, F_x\xi), (\eta, F_x\eta) )=
\frac{1}{2}\big(&(1+\alpha)K^x + (1-\alpha)F_x^{-1}K^y+ (1-\alpha)F_x^{-1}F_{xx}(x;\xi, \eta),\\
&(1+\alpha)F_xK^x + (1-\alpha)K^y -(1+\alpha)F_{xx}(x;\xi, \eta)\big).\end{aligned}
$$
Here $(K^x, K^y)$ is the ambient connection on $M\times \bar{M}$.
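The projection family $\Pi_{\alpha}$ above can be sketched directly from its formula. The helper `proj_alpha` below is hypothetical (it is not the `KMCDivergence.proj` method) and takes the base point $x$ rather than $q$; it checks that the image satisfies the tangency condition $\bar{\omega} = F_x\omega$ and that $\Pi_{\alpha}$ is idempotent, for one arbitrary value of $\alpha$.

```python
import jax.numpy as jnp
from jax import jacfwd

def F(x, a=0.2, b=0.3, k=0.5):
    return k * x / (a + jnp.sqrt(b + x * x))

def proj_alpha(x, omg, alpha):
    # Pi_alpha(omega, bar omega) at the point (x, F(x)), written from the formula
    n = x.shape[0]
    Fx = jacfwd(F)(x)
    w, bw = omg[:n], omg[n:]
    top = 0.5 * ((1 + alpha) * w + (1 - alpha) * jnp.linalg.solve(Fx, bw))
    bottom = 0.5 * ((1 + alpha) * (Fx @ w) + (1 - alpha) * bw)
    return jnp.concatenate([top, bottom])

x = jnp.array([0.3, -1.2, 2.0])
omg = jnp.array([1.0, -2.0, 0.5, 0.7, 0.1, -0.4])
alpha = 0.3

Fx = jacfwd(F)(x)
p = proj_alpha(x, omg, alpha)

tangency = p[3:] - Fx @ p[:3]              # should vanish
idempotency = proj_alpha(x, p, alpha) - p  # projecting twice changes nothing
print(tangency, idempotency)
```

Both residuals vanish up to rounding for any $\alpha$, since the second block of $\Pi_{\alpha}$ is $F_x$ applied to the first.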

When $\alpha=0$, we have the Levi-Civita connection [2], which can be derived by the method in our paper. When $\alpha \neq 0$, the connection is not metric compatible. 

We demonstrate numerically that the affine Gauss-Codazzi equation can be used to compute the curvature, and that the result agrees with the curvature formula given by our method.


[1] T.-K. L. Wong, Logarithmic divergences from optimal transport and Rényi geometry, Information Geometry 1 (2018), 39–78.

[2] F. Léger and F.X. Vialard, A geometric Laplace method, https://arxiv.org/abs/2212.04376, 2022.

In [14]:
n = 3

key = random.PRNGKey(0)    
key, sk = random.split(key)

def F(x):
    a, b, k = (.2, .3, .5)        
    return k*x/(a+jnp.sqrt(b+x*x))

def c(x, y, beta):
    def KL(p):
        return 1/beta*jnp.log(beta*jnp.sum(p*p)+1)

    def gradKL(p):
        return 2*p/(beta*jnp.sum(p*p)+1)
    
    return KL(F(x)) - KL(y) - 1/beta*jnp.log(
        1+beta*jnp.sum(gradKL(y)*(F(x)-y)))

beta = .1
KMAD = KMC(n, c, beta)
KMD = KMCDivergence(KMAD, F)

key, sk = random.split(key)    
q = KMD.randQpoint(sk)


Test the projection, metric compatibility, and tangency of the covariant derivative.

In [15]:
key, sk = random.split(key)        
Xi = KMD.randvec(sk, q)

# test projection
key, sk = random.split(key)
Omg = random.normal(sk, (2*n,))

key, sk = random.split(key)
# Omg1 = random.normal(sk, (2*n,))

JF = jacfwd(F)

pOmg = KMD.proj(q, Omg, .3)
print(pOmg[n:] - JF(q[:n])@pOmg[:n])

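pOmg = KMD.proj(q, Omg, 0.)
# as in the antenna check: the alpha = 0 projection should preserve
# the pairing with a tangent vector, <proj(Omg), Xi> = <Omg, Xi>
print(
    KMD.kmc(q, pOmg, Xi)
    - KMD.kmc(q, Omg, Xi))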

# test metric compatibility and connectivity
key, sk = random.split(key)        
Xi1 = KMD.randvec(sk, q)

key, sk = random.split(key)        
Xi2 = KMD.randvec(sk, q)

def checkTan(q, Omg):
    return Omg[n:] - JF(q[:n])@Omg[:n]

al1 = -2

print("check covariant derivative is a vector field")
print("Use PI_{-2} for vector field")
covxi2 = jvp(lambda q: KMD.proj(q, Xi2, al1), (q,), (Xi1,))[1] +\
    KMD.Gamma(q, Xi1, Xi2, 0)

print(checkTan(q, covxi2))

s1 = jvp(
    lambda q: KMD.kmc(q, KMD.proj(q, Xi2, al1),
                      KMD.proj(q, Xi2, al1)), (q,), (Xi1,))[1]

s2 = 2*KMD.kmc(q, Xi2, covxi2)
print("check metric compatible of Levi-Civita connection")
print(s1 - s2)


[0.00000000e+00 2.77555756e-17 5.55111512e-17]
0.0
check covariant derivative is a vector field
Use PI_{-2} for vector field
[-1.11022302e-16  1.73472348e-18  4.16333634e-17]
check metric compatible of Levi-Civita connection
-2.220446049250313e-16


In [16]:

# curvature by christoffel function.
key, sk = random.split(key)    
Xi3 = KMD.randvec(sk, q)

key, sk = random.split(key)   
Xi4 = KMD.randvec(sk, q)

print("CHECK THE CURVATURE is tangent to the manifold")
print(checkTan(q, KMD.Curvature(q, Xi1, Xi2, Xi3, al=2.)))

print("CHECK THE BIANCHI identities for the Levi-Civita connection")
# bianchi only for al = 0
print("FIRST BIANCHI")
print(KMD.Curv4(q, Xi1, Xi2, Xi3, Xi4, 0.))
print(KMD.Curv4(q, Xi2, Xi1, Xi3, Xi4, 0.))
print(KMD.Curv4(q, Xi1, Xi2, Xi4, Xi3, 0.))
print(KMD.Curv4(q, Xi3, Xi4, Xi1, Xi2, 0.))

# second Bianchi identity
def nablaR(KMD, q, Xi1, Xi2, Xi3, Xi4, al):
    def xifunc(q, Xi):
        return KMD.proj(q, Xi)

    # extend Xi1, Xi2 to vector fields by projection then take
    # covariant derivatives.
    def cov(q, Xi1, Xi2):
        return jvp(lambda x: KMD.proj(x, Xi2, al), (q,), (Xi1,))[1] \
            + KMD.Gamma(q, Xi1, Xi2, al)

    Tot = jvp(lambda x: KMD.Curvature(
        x,
        KMD.proj(x, Xi2, al),
        KMD.proj(x, Xi3, al),
        KMD.proj(x, Xi4, al), al),
              (q,), (Xi1,))[1] \
        + KMD.Gamma(q, Xi1, KMD.Curvature(q, Xi2, Xi3, Xi4, al), al)
    R12 = KMD.Curvature(q, cov(q, Xi1, Xi2), Xi3, Xi4, al)
    R13 = KMD.Curvature(q, Xi2, cov(q, Xi1, Xi3), Xi4, al)
    R14 = KMD.Curvature(q, Xi2, Xi3, cov(q, Xi1, Xi4), al)
    return Tot - R12 - R13 - R14
print("SECOND BIANCHI")    
f1 = nablaR(KMD, q, Xi1, Xi2, Xi3, Xi4, 0.)
f2 = nablaR(KMD, q, Xi2, Xi3, Xi1, Xi4, 0.)
f3 = nablaR(KMD, q, Xi3, Xi1, Xi2, Xi4, 0.)

print(f1 + f2 + f3)

# Curvature and affine Gauss Codazzi
al = .2
Form4 = KMD.projT(q, KMAD.gkmc(q, Xi4), al)

print(KMD.proj(q, KMD.Two(q, Xi1, Xi2, al), al))
print(KMD.projT(q, KMD.conjTwo(q, Xi1, Form4, al), al))

f1 = KMAD.Curv4(q, Xi1, Xi2, Xi3, Xi4)
f2 = jnp.sum(KMD.Curv4(q, Xi1, Xi2, Xi3, Xi4, al)) \
    + jnp.sum(KMD.Two(q, Xi1, Xi3, al)*KMD.conjTwo(q, Xi2, Form4, al)) \
    - jnp.sum(KMD.Two(q, Xi2, Xi3, al)*KMD.conjTwo(q, Xi1, Form4, al))

print("CURVATURE satisfies the affine Gauss-Codazzi equation")
print(f1 - f2)


CHECK THE CURVATURE is tangent to the manifold
[-2.86229374e-16  4.85722573e-17  4.68375339e-17]
CHECK THE BIANCHI identities for the Levi-Civita connection
FIRST BIANCHI
-0.008386961057168891
0.008386961057168891
0.008386961057168983
-0.008386961057168962
SECOND BIANCHI
[ 6.89032165e-15  7.80625564e-18  2.51534904e-16 -1.20650018e-15
  4.33680869e-18  1.17961196e-16]
[ 2.28983499e-16  8.67361738e-19  1.21430643e-17  3.38271078e-17
 -4.33680869e-19  2.16840434e-18]
[-1.51788304e-18  0.00000000e+00 -1.08420217e-18 -6.93889390e-18
  0.00000000e+00 -2.60208521e-18]
CURVATURE satisfies the affine Gauss-Codazzi equation
5.204170427930421e-17
