<a href="https://colab.research.google.com/github/dnguyend/regularMTW/blob/main/colab/SINHHomogeneousPower.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Potential for cost with $s(u) = p_0e^{-ru} + p_1e^{ru}$ of order $\alpha > 1$
$\newcommand{\tphi}{\underline{\phi}}$
Assume $\tphi$ is absolutely-homogeneous of order $\alpha > 1$ and $grad_{\tphi}$ is invertible. Then $\tphi$ is $c$-convex inside $x^T grad_{\tphi}(x) < \frac{1}{r}$

We verify numerically the duals and the related inequalities


In [1]:
!pip install jaxopt

Collecting jaxopt
  Downloading jaxopt-0.8.2-py3-none-any.whl (170 kB)
[?25l     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m0.0/170.3 kB[0m [31m?[0m eta [36m-:--:--[0m[2K     [91m━━━━━━━━━━━━━━━━[0m[91m╸[0m[90m━━━━━━━━━━━━━━━━━━━━━━━[0m [32m71.7/170.3 kB[0m [31m2.1 MB/s[0m eta [36m0:00:01[0m[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m170.3/170.3 kB[0m [31m2.7 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jaxopt
Successfully installed jaxopt-0.8.2


In [4]:
import jaxopt

In [5]:
import jax
import jax.numpy as jnp
import matplotlib.pyplot as plt
pp = jnp.array([0.1648312216190048, -1.2214846630837304, -0.2614981455243139, 1.2214846630837304])
p0 = pp[0]
p2 = pp[2]
r = pp[3]
pw = 3.
cc = jnp.array([1.49862139, 0.05708533, 0.09740697, 0.23343014])

def uf0(s):
    return 1/r*jnp.log((-s+jnp.sqrt(s**2-4*p0*p2))/(2*jnp.abs(p2)))

def uf1(s):
  return -1/(r*jnp.sqrt(s**2-4*p0*p2))

def tPhi(x):
  return jnp.sum(cc*jnp.abs(x)**pw)

def gradtPhi(x):
  return pw*cc*jnp.abs(x)**(pw-1)*jnp.sign(x)

def gradtPhiInv(y):
  return (jnp.abs(y)/(pw*cc))**(1/(pw-1))*jnp.sign(y)

def HessPhi(x, eta):
  return pw*(pw-1)*cc*x**(pw-2)*eta


# Some helper functions
* al_func defines the equation used to define one side of the optimal map, and al_solves it. pw is the vairable for $\alpha$

In [6]:
def grand(key, dims):
  key, sk = jax.random.split(key)
  return jax.random.normal(sk, dims), key

def al_func(t, al, a, b, c):
    return a*t**(2*al) + b*t**(2*al-2) + c

def al_solve(al, a, b, c):
    # solve for the unique positive root of a*t**(al) + b*t**(al-1) + c = 0
    # with al > 1, a >=0, b > 0, c <0
    # change variable a*z**al + b*z**(al-1) + c = 0
    # newton step with some scaling to be in positive region
    # if al = 1 need b+c < 0 and a !=0
    z = 1.
    err = a*z**al + b*z**(al-1) + c
    for i in range(10):
       step = - err/(al*a*z**(al-1) + (al-1)*b*z**(al-2))
       if z+step < 0:
           z = 1e-3
       else:
          z = z + step
       err = a*z**al + b*z**(al-1) + c
       # print(err)
       if abs(err) < 1e-10:
          break
    return jnp.sqrt(z)

def d_a_al_solve(al, a, b, c, t):
  return - t**(2*al)/(2*al*a*t**(2*al-1)+(2*al-2)*b*t**(2*al-3))

ss = al_solve(1.2, 1., 3., -2)
print(al_func(ss, 1.2, 1., 3., -2))


0.0


## Check derivative of a_solve

In [7]:
dlt = 1e-4
ss = al_solve(1.2, 1., 3., -2)
dt = (al_solve(1.2, 1.+dlt, 3., -2) - al_solve(1.2, 1., 3., -2))/dlt
print(dt)
# print(1.*ss**(2*1.2) + 3.*ss**(2*1.2-2) -2)
# print(dlt*ss**(2*1.2) + 1.*2*1.2*ss**(2*1.2-1)*dt + 3.*(2*1.2-2)*ss**(2*1.2-3)*dt)

def da_al_solve(al, a, b, c, t):
  return - t**(2*al)/(2*al*a*t**(2*al-1)+(2*al-2)*b*t**(2*al-3))

d_a_al_solve(1.2, 1, 3., -2, ss)

-0.02503395


Array(-0.02491712, dtype=float32, weak_type=True)

In [8]:

def Lfunc(x, y):
   return uf0(jnp.sum(x*y)) + tPhi(x)

def gradLy(x, y):
    s = jnp.sum(x*y)
    return uf1(s)*y + gradtPhi(x)

def cexp(x, nug):
  """
  """
  xnug = jnp.sum(x*nug)
  return 2*r*jnp.sqrt(-p0*p2)*((1+r*xnug)/(1-r*xnug))**0.5/(1+r*xnug)*nug

def yopt_in_x(x):
  """optimal map
  defined if tPhi(x) < 1/r
  range is F(gradFInv(y)) >= (2*r*(-p0*p2)**.5/pw)**(pw/(pw-1))
  """
  return 2*r*jnp.sqrt(-p0*p2)/(1-r**2*pw**2*tPhi(x)**2)**0.5*gradtPhi(x)


def xopt_in_y(y):
  """inverse optimal map
  range is x gradphi(x) < 1/r
  """
  xdir = gradtPhiInv(y)
  s = jnp.sum(y*xdir)
  t = al_solve(pw, s**2, -4*p0*p2, -1/r**2)
  return t*xdir

Check xopt_y is invert of yopt_x. For invertibility, we need $\tphi(x) < \frac{1}{r\alpha}$

In [15]:
key = jax.random.PRNGKey(0)
n = 4

for i in range(5):
  xa, key = grand(key, (n,))
  xa = .4*xa

  # print(gradtPhi(gradtPhiInv(x)))

  xb = xopt_in_y(yopt_in_x(xa))
  print(xa, xb, xa-xb, tPhi(xa)-1/(r*pw))
  if (not jnp.any(jnp.isnan(xb))) and (not jnp.allclose(xa, xb)):
    print("BAD")
    break


[ 0.45515138 -0.05732573 -0.23661454  0.3178649 ] [ 0.45515135 -0.05732573 -0.23661456  0.31786492] [ 2.9802322e-08  0.0000000e+00  1.4901161e-08 -2.9802322e-08] -0.12278822
[-0.02642903  0.31711265  0.4712014  -0.11868888] [-0.02642903  0.31711265  0.47120136 -0.11868888] [1.8626451e-09 0.0000000e+00 2.9802322e-08 0.0000000e+00] -0.26046276
[-0.88205785 -0.56554955 -0.20421378 -0.06932341] [nan nan nan nan] [nan nan nan nan] 0.7667914
[ 0.16512908  0.49673113 -0.04761466  0.04782562] [ 0.16512907  0.4967311  -0.04761466  0.04782562] [1.4901161e-08 2.9802322e-08 0.0000000e+00 3.7252903e-09] -0.25911146
[-0.34447813 -0.03593322  0.6477703   0.07359286] [-0.34447813 -0.03593322  0.6477703   0.07359286] [ 0.0000000e+00 -3.7252903e-09  0.0000000e+00  7.4505806e-09] -0.18506023


# The Jacobian of xopt_in_y (inverse of the optimal map $T$)

In [17]:
def jac_xopt_in_y(y, eta):
  """inverse optimal map
  range is x gradphi(x) < 1/r
  """
  xdir, dxdir = jax.jvp(gradtPhiInv, (y,), (eta,))
  if jnp.allclose(y, 0):
    dxdir = jnp.zeros(n)
  # xdir = jtmp[0]
  #  dxdir =
  s = jnp.sum(y*xdir)
  t = al_solve(pw, s**2, -4*p0*p2, -1/r**2)
  # print(da_al_solve(pw, s**2, -4*p0*p2, -1/r**2, t))
  dt = da_al_solve(pw, s**2, -4*p0*p2, -1/r**2, t)*2*s*(jnp.sum(eta*xdir)+jnp.sum(y*dxdir))
  # print(da_al_solve(pw, s**2, -4*p0*p2, -1/r**2, t)*2*s)
  # print(dxdir)
  return dt*xdir + t*dxdir

y, key = grand(key, (n,))
eta, key = grand(key, (n,))
print((xopt_in_y(y+dlt*eta) - xopt_in_y(y))/dlt)
jac_xopt_in_y(y, eta)
# print((xopt_in_y(dlt*eta) - xopt_in_y(jnp.zeros(n)))/dlt)
jac_xopt_in_y(jnp.zeros(n), eta)

[-0.04708767 -0.56385994  0.30517578 -0.38206577]


Array([0., 0., 0., 0.], dtype=float32)

In [18]:

x = .1*jnp.array([0.94408732,  0.57970352,  0.08071361, -0.74484489])
print(xopt_in_y(yopt_in_x(x)))
print(x)

# print("1", yopt_in_x(x)/gradtPhi(x))
# print("2", gradtPhi(x))

y1 = cexp(x, gradtPhi(x))
y1a = yopt_in_x(x)
print(y1 - y1a)


[ 0.09440874  0.05797035  0.00807136 -0.07448449]
[ 0.09440874  0.05797035  0.00807136 -0.07448449]
[0. 0. 0. 0.]


Now do cases.
For $\psi(y)$ there is one cases:
* everything corresponds to $x_{opt}$

For $\phi^{cc}(x)$, there are two cases:

* $\phi(x) < \frac{1}{r}$, corresponding to $y = Tx$, $T$ is the optimal map,  $\phi^{cc}(x) < \phi(x)$
* $\phi(x) >= \frac{1}{r}$, no optimal $y$, but a limit point $-u(x^Tz_{opt}) - \psi(z_{opt})$ with $\phi(z) \to \frac{1}{\alpha r}$,
$$\phi^{cc}(x) = \frac{1}{r\alpha }(\log(r\alpha x) + 1).$$

Compute $\psi$ by the formula and then by an optimizer to verify

In [21]:
def psi_opt(y, max_iter=10):
  def L(x, y):
    return uf0(jnp.sum(x*y)) + tPhi(x)
  x0 = .2*jnp.ones(n)
  solver = jaxopt.LBFGS(fun=L, maxiter=max_iter)
  ret = solver.run(x0, y=y)
  return -ret.state.value
def psi(y):
    xopt = xopt_in_y(y)
    return - uf0(jnp.sum(xopt*y)) - tPhi(xopt)


y = jnp.arange(n)

print(psi_opt(y) )

psi(y)

y, _  = grand(key, (n,))
y *= .3

print(psi_opt(y), 10)
print(psi(y))

2.5844443
1.4519486 10
1.4519489


Check the inequality for $\tphi^c$

In [22]:

for i in range(1000):
  x1, _  = grand(key, (n,))
  y, _  = grand(key, (n,))
  s = uf0(jnp.sum(x1*y)) + tPhi(x1) + psi(y)
  if s < 0:
    print("BAD", s)
    break


Now $\tphi^{cc}$, by our formula, and by an optimizer

In [24]:
def phicc(x):
  if tPhi(x) < 1/r:
    return tPhi(x)
  else:
    return 1/(r*pw)*(jnp.log(r*pw*tPhi(x))+1)

def func1(y, x):
    xopt = xopt_in_y(y)
    return uf0(jnp.sum(x*y))  - uf0(jnp.sum(xopt*y)) - tPhi(xopt)

def grad_func1(y, x):
    xopt = xopt_in_y(y)
    return uf1(jnp.sum(x*y))*x  \
    - uf1(jnp.sum(xopt*y))*xopt \
    - uf1(jnp.sum(xopt*y))*jac_xopt_in_y(y, y) \
    - jac_xopt_in_y(y, gradtPhi(xopt))

def phicc_opt(x, N):
  def prox_func(y, x):
      return jax.grad(lambda y: func1(y, x))(y)
  y = jnp.zeros(n)
  c = .5e-1
  for _ in range(N):
    prx = grad_func1(y, x)
    if jnp.any(jnp.isnan(prx)):
      print("NAN", prx, y, x)

    y = y - c*prx
    #print("*", -func(y, x))
    #print("**", uf0(jnp.sum(x*y))  + psi(y))
  return -func1(y, x)



some unused codes

In [None]:
# @title
def grad_func1(y, x):
    xopt = xopt_in_y(y)
    return uf1(jnp.sum(x*y))*x  \
    - uf1(jnp.sum(xopt*y))*xopt \
    - uf1(jnp.sum(xopt*y))*jac_xopt_in_y(y, y) \
    - jac_xopt_in_y(y, gradtPhi(xopt))

dlt = 1e-3
print(jnp.sum(grad_func1(y, x)*eta))
print((func1(y+dlt*eta, x)-func1(y, x))/dlt)
# jac_xopt_in_y(y, eta)
print(phicc(xx1))
phicc_opt(xx1, N=1000)
# grad_func1(jnp.zeros(n), xx1)
# jac_xopt_in_y(jnp.zeros(n), jnp.zeros(n))
# jax.grad(lambda y: func1(y, xx1))(jnp.zeros(n))

Test $\tphi^{cc}$ by the formula is always not worse, and close to value from optimizer

In [26]:
for i in range(1000):
  xx1, key  = grand(key, (n,))
  y, key  = grand(key, (n,))
  s = uf0(jnp.sum(xx1*y)) + tPhi(xx1) + psi(y)
  if s < 0:
    print("BAD", s)
    break
  val1 = phicc(xx1)
  val2 = phicc_opt(xx1, N=200)
  print(val1, val2)
  if val1 < val2:
    print('BD2', val1, val2)
    break


# F(gradFInv(yy1))**(2-2/pw)
# print(psi(yy1))
# 4*r**2*p0*p2/pw**2 + F(gradFInv(yy1))**(2-2/pw)
# 4*r**2*p0*p2/pw**2 + F(gradFInv(y))**(2-2/pw)

1.3296468 1.2937222
0.5032124 0.43299294
0.68692267 0.66056335
0.75698197 0.52511364
1.346364 1.3091965
0.6129313 0.59135556
0.4516207 0.40248188
1.1864575 1.1501479
1.1232792 1.0875063
1.1591604 1.1227174
0.6388596 0.48137903
0.9078086 0.89449584
0.11392354 0.11392342
0.58395326 0.57338417
0.8002012 0.7684451
1.176072 1.1396219
0.17890972 0.17532444
0.49291164 0.42205942
0.6104025 0.583619
1.2050675 1.1695542
0.061530784 0.06153086
BD2 0.061530784 0.06153086


This takes long - to numerically verify that $\phi^{cc}$ by the formula is slightly better than $\phi^{cc}$ by the optimizer, the later is actually an unattainable supremum.

In [27]:
for i in range(500):
  xx1, key  = grand(key, (n,))
  y, key  = grand(key, (n,))
  s = uf0(jnp.sum(xx1*y)) + tPhi(xx1) + psi(y)
  if s < 0:
    print("BAD", s)
    break
  val1 = phicc(xx1)
  val2 = phicc_opt(xx1, N=100)
  print(val1, val2)
  if val1 < val2:
    print('BD2', val1, val2)
    if jnp.abs(val1-val2)> 1e-4:
      break


0.96307987 0.937008
0.5361524 0.4397688
0.91269356 0.88650334
0.61805326 0.60334194
0.9600341 0.9101348
0.61280113 0.46255112
0.67384744 0.6469767
0.7583858 0.53124964
1.1069251 1.0595607
0.32683098 0.30527666
1.096282 1.0443336
0.7731905 0.7387986
1.2794398 1.2262731
0.8186334 0.7769756
0.6032765 0.5651525
0.5986661 0.4539407
0.667712 0.48905623
0.28507006 0.27603304
0.3598184 0.33586454
0.93682444 0.9108948
0.17288439 0.17224789
0.13314621 0.13314599
1.1656659 1.1131141
0.23695847 0.23117837
0.78023475 0.73915493
1.406173 1.3531568
0.31260794 0.30447817
0.26200187 0.25674757
0.18309027 0.18309015
0.67154276 0.6283815
0.9587801 0.9217937
0.30688193 0.29865804
0.60813886 0.56989396
1.0661086 1.014945
0.6822185 0.5043715
0.66875106 0.62407255
1.6053097 1.5512254
0.81040204 0.7822519
1.1398015 1.0873744
0.7854798 0.5363257
0.9454259 0.8947475
0.01142715 -0.0025195926
1.1237799 1.0723176
0.5828395 0.44782406
0.06436152 0.06372694
0.78982973 0.7618093
0.7806953 0.5386643
0.41722417 0.37701

# Now do the divergence

In [28]:
def gen_in_cball(key):
  """ generate two points in the ball $x^Tgrad_{\tphi}(x) < \frac{1}{r}
  """
  x, key = grand(key, (n+1,))
  t = x[-1]/jnp.sqrt(1+x[-1]*x[-1])
  return jnp.sum(x[:-1]*gradtPhi(x[:-1]))**(-1/pw)*t*x[:n], key

xx, key = gen_in_cball(key)
xp, key = gen_in_cball(key)
jnp.sum(xp*gradtPhi(xp)), 1/r

(Array(0.1449045, dtype=float32), Array(0.8186759, dtype=float32))

SHow 3 expressions for divergence are the same

In [29]:
def Div(x, xp):
  yp = yopt_in_x(xp)
  return uf0(jnp.sum(x*yp)) + tPhi(x) + psi(yp)

def Div2(x, xp):
  yp = yopt_in_x(xp)
  return uf0(jnp.sum(x*yp)) - uf0(jnp.sum(xp*yp)) + tPhi(x) - tPhi(xp)

def Div3(x, xp):
  # yp = yopt_in_x(xp)
  yp = 2*r*jnp.sqrt(-p0*p2)/(1-r**2*pw**2*tPhi(xp)**2)**0.5*gradtPhi(xp)
  xyp = jnp.sum(x*yp)
  xyp = 2*r*jnp.sqrt(-p0*p2)/(1-r**2*pw**2*tPhi(xp)**2)**0.5*jnp.sum(x*gradtPhi(xp))
  xpyp = jnp.sum(xp*yp)
  xpyp = 2*r*jnp.sqrt(-p0*p2)/(1-r**2*pw**2*tPhi(xp)**2)**0.5*jnp.sum(xp*gradtPhi(xp))

  Xp = xpyp**2-4*p0*p2
  Xp = -4*r**2*p0*p2*jnp.sum(xp*gradtPhi(xp))**2/(1-r**2*pw**2*tPhi(xp)**2)-4*p0*p2
  # Xp = -4*p0*p2*(r**2*pw**2*jnp.sum(tPhi(xp))**2 +1-r**2*pw**2*tPhi(xp)**2)/(1-r**2*pw**2*tPhi(xp)**2)
  Xp = -4*p0*p2/(1-r**2*pw**2*tPhi(xp)**2)
  XYp = xpyp+ jnp.sqrt(Xp)
  XYp = 2*(-p0*p2)**.5*(r*jnp.sum(xp*gradtPhi(xp))+ 1 )/(1-r**2*pw**2*tPhi(xp)**2)**.5
  XxYp = xyp+ jnp.sqrt(xyp**2-4*p0*p2)
  XxYp = 2*r*jnp.sqrt(-p0*p2)/(1-r**2*pw**2*tPhi(xp)**2)**0.5*jnp.sum(x*gradtPhi(xp))
  XxYp = 2*r*jnp.sqrt(-p0*p2)/(1-r**2*pw**2*tPhi(xp)**2)**0.5*jnp.sum(x*gradtPhi(xp)) \
     + ((2*r*jnp.sum(x*gradtPhi(xp))*jnp.sqrt(-p0*p2)/(1-r**2*pw**2*tPhi(xp)**2)**0.5)**2-4*p0*p2)**.5
  XxYp = 2*r*jnp.sqrt(-p0*p2)/(1-r**2*pw**2*tPhi(xp)**2)**0.5*jnp.sum(x*gradtPhi(xp)) \
     + 2*(-p0*p2)**.5*(r**2*(jnp.sum((x)*gradtPhi(xp))**2
                             - jnp.sum((xp)*gradtPhi(xp))**2)+1)**.5/(1-r**2*pw**2*tPhi(xp)**2)**.5
  XxYp = 2*(-p0*p2)**.5/(1-r**2*pw**2*tPhi(xp)**2)**0.5*(
      r*jnp.sum(x*gradtPhi(xp)) + (r**2*(jnp.sum((x)*gradtPhi(xp))**2
                             - jnp.sum((xp)*gradtPhi(xp))**2)+1)**.5)


  # return  tPhi(x) - tPhi(xp) + uf0(jnp.sum(xyp)) - uf0(jnp.sum(xpyp))
  # return tPhi(x) - tPhi(xp) - 1/r*jnp.log(XxYp/XYp)
  return tPhi(x) - tPhi(xp) - 1/r*jnp.log(
      (
      r*jnp.sum(x*gradtPhi(xp)) + (r**2*(jnp.sum(x*gradtPhi(xp))**2
                             - jnp.sum(xp*gradtPhi(xp))**2)+1)**.5) /
      (r*jnp.sum(xp*gradtPhi(xp))+ 1 ))


print(Div(xx, xp))
print(Div2(xx, xp))
print(Div3(xx, xp))



0.09900358
0.09900357
0.09900355


In [None]:
# @title
def Div_dev(x, xp):
  yp = yopt_in_x(xp)
  ret = uf0(jnp.sum(x*yp)) + tPhi(x) + psi(yp)
  print(ret)
  print(xp - xopt_in_y(yp))

Div_dev(xx, xp)
print(xopt_in_y(yopt_in_x(xopt_in_y(xp))))
print(xopt_in_y(xp))