<a href="https://colab.research.google.com/github/dnguyend/regularMTW/blob/main/colab/SINHHomogeneousOne.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

## Potential for cost with $s(u) = p_0e^{-ru} + p_1e^{ru}$ of order $\alpha = 1$
$\newcommand{\tphi}{\underline{\phi}}$
Assume $\tphi$ is absolutely-homogeneous of order $1$. Assume there is $p$ such that $F=\phi^p$ has invertible gradient, $grad_{F}$ is invertible. Then $\tphi$ is $c$-convex inside $x^T grad_{\tphi}(x) < \frac{1}{r}$

We verify numerically the duals and the related inequalities

In [2]:
import jax.numpy as jnp
import matplotlib.pyplot as plt
pp = jnp.array([0.1648312216190048, -1.2214846630837304, -0.2614981455243139, 1.2214846630837304])
p0 = pp[0]
p2 = pp[2]
r = pp[3]
pw = 3.
cc = jnp.array([1.49862139, 0.05708533, 0.09740697, 0.23343014])

def uf0(s):
    return 1/r*jnp.log((-s+jnp.sqrt(s**2-4*p0*p2))/(2*jnp.abs(p2)))

def uf1(s):
  return -1/(r*jnp.sqrt(s**2-4*p0*p2))

def F(x):
  return jnp.sum(cc*jnp.abs(x)**pw)

def gradF(x):
  return pw*cc*jnp.abs(x)**(pw-1)*jnp.sign(x)

def gradFInv(y):
  return (jnp.abs(y)/(pw*cc))**(1/(pw-1))*jnp.sign(y)

def tPhi(x):
  return F(x)**(1/pw)

def gradtPhi(x):
  return 1/pw*F(x)**(1/pw-1)*gradF(x)

def HessF(x, eta):
  return pw*(pw-1)*cc*x**(pw-2)*eta



def Lfunc(x, y):
   return uf0(jnp.sum(x*y)) + tPhi(x)

def gradLy(x, y):
    s = jnp.sum(x*y)
    return uf1(s)*y + gradtPhi(x)

def cexp(x, nug):
  """
  """
  xnug = jnp.sum(x*nug)
  return 2*r*jnp.sqrt(-p0*p2)*((1+r*xnug)/(1-r*xnug))**0.5/(1+r*xnug)*nug

def yopt_in_x(x):
  """optimal map
  defined if tPhi(x) < 1/r
  range is F(gradFInv(y)) >= (2*r*(-p0*p2)**.5/pw)**(pw/(pw-1))
  """
  return 2*r*jnp.sqrt(-p0*p2)/(1-r**2*tPhi(x)**2)**0.5*gradtPhi(x)


def xopt_in_y(y):
  """inverse optimal map, defined if F(gradFInv(y)) >= (2*r*(-p0*p2)**.5/pw)**(pw/(pw-1))
  range is x gradphi(x) < 1/r
  """
  # return jnp.sqrt(4*r**2*p0*p2*F(gradFInv(pw*y))**(2/pw-2) + 1)/(r*F(gradFInv(y))**(1/pw))*gradFInv(y)
  # return jnp.sqrt(4*p0*p2/pw**2*F(gradFInv(y))**(-2) + 1/(r*F(gradFInv(y))**(2/pw)))*gradFInv(y)
  return jnp.sqrt(4*r**2*p0*p2/pw**2 + F(gradFInv(y))**(2-2/pw))/(r*F(gradFInv(y)))*gradFInv(y)

x = .1*jnp.array([0.94408732,  0.57970352,  0.08071361, -0.74484489])
print(xopt_in_y(yopt_in_x(x)))
print(x)
y1 = cexp(x, gradtPhi(x))
y1a = yopt_in_x(x)
print(y1 - y1a)

def Fyopt_in_x(x):
  # return (2*r*jnp.sqrt(-p0*p2)/(1-r**2*tPhi(x)**2)**0.5)**(pw/(pw-1))*F(gradFInv(gradtPhi(x)))
  return (2*r*jnp.sqrt(-p0*p2)/(pw*(1-r**2*tPhi(x)**2)**0.5))**(pw/(pw-1))

print(F(gradFInv(y1a))- Fyopt_in_x(x))
print(F(gradFInv(y1a)), (2*r*jnp.sqrt(-p0*p2)/(pw*(1-r**2*tPhi(x)**2)**.5))**(pw/(pw-1)))


[ 0.09440917  0.05797062  0.0080714  -0.07448483]
[ 0.09440874  0.05797035  0.00807136 -0.07448449]
[0. 0. 0. 0.]
1.4901161e-08
0.07048926 0.07048924


Now do cases.
For $\psi(y)$ there are 2 cases:
* $F(g_F^{-1}(y)) \leq c_{off}$ corresponding to $x=0$, value $- u(0)$
* $F(g_F^{-1}(y)) > c_{off}$ corresponding to $x_{opt}$

For $\phi^{cc}(x)$, there are three cases:
* $x=0$, then not a single $y_{opt}$, but $c$-subgradient $F(g_F^{-1}(y)) \leq c_{off}$
* $\phi(x) < \frac{1}{r}$, corresponding to $y = Tx$, $T$ is the optimal map,  $\phi^{cc}(x) < \phi(x)$
* $\phi(x) >= \frac{1}{r}$, no optimal $y$, but a limit point $-u(x^Tz_{opt}) - \psi(z_{opt})$ with $\phi(z) \to \frac{1}{r}$,
$$\phi^{cc}(x) = \frac{1}{r}(\log(rx) + 1).$$

In [3]:
!pip install jaxopt

Collecting jaxopt
  Downloading jaxopt-0.8.2-py3-none-any.whl (170 kB)
[2K     [90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━[0m [32m170.3/170.3 kB[0m [31m3.1 MB/s[0m eta [36m0:00:00[0m
Installing collected packages: jaxopt
Successfully installed jaxopt-0.8.2


In [4]:
import jaxopt

In [5]:
import jax
def grand(key, dims):
  key, sk = jax.random.split(key)
  return jax.random.normal(sk, dims), key

In [6]:
def psi_opt(y, max_iter=10):
  def L(x, y):
    return uf0(jnp.sum(x*y)) + tPhi(x)
  x0 = .2*jnp.ones(n)
  solver = jaxopt.LBFGS(fun=L, maxiter=max_iter)
  ret = solver.run(x0, y=y)
  return -ret.state.value


In [7]:
def psi(y):
  if 4*r**2*p0*p2/pw**2 + F(gradFInv(y))**(2-2/pw) > 0:
    xopt = jnp.sqrt(4*r**2*p0*p2/pw**2 + F(gradFInv(y))**(2-2/pw))/(r*F(gradFInv(y)))*gradFInv(y)
    return - uf0(jnp.sum(xopt*y)) - tPhi(xopt)
  else:
    return - uf0(0)

n = 4
y = jnp.arange(n)

print(psi_opt(y) )
print(4*r**2*p0*p2/pw**2 + F(gradFInv(y))**(2-2/pw))
psi(y)

2.2288203
7.664603


Array(2.2288098, dtype=float32)

In [8]:
# y = 0.06*jnp.arange(n)
key = jax.random.PRNGKey(0)
y, _  = grand(key, (n,))
y *= .3
print(4*r**2*p0*p2/pw**2 + F(gradFInv(y))**(2-2/pw))
print(-uf0(0))
print(psi_opt(y), 10)
print(psi(y))

0.03794784
0.18891156
0.37703034 10
0.37704134


In [9]:
key = jax.random.PRNGKey(0)
for i in range(1000):
  x1, _  = grand(key, (n,))
  y, _  = grand(key, (n,))
  s = uf0(jnp.sum(x1*y)) + tPhi(x1) + psi(y)
  if s < 0:
    print("BAD", s)
    break


In [10]:
def phicc(x):
  if tPhi(x) < 1/r:
    return tPhi(x)
  else:
    return 1/r*(jnp.log(r*tPhi(x))+1)

In [12]:
def func1(y, x):
    xopt = jnp.sqrt(4*r**2*p0*p2/pw**2 + F(gradFInv(y))**(2-2/pw))/(r*F(gradFInv(y)))*gradFInv(y)
    return uf0(jnp.sum(x*y))  - uf0(jnp.sum(xopt*y)) - tPhi(xopt)

def phicc_opt(x, N):
  def func(y, x):
    if 4*r**2*p0*p2/pw**2 + F(gradFInv(y))**(2-2/pw) <= 0:
      return uf0(jnp.sum(x*y))-uf0(0)
    else:
      return func1(y, x)
  def prox_func(y, x):
    if 4*r**2*p0*p2/pw**2 + F(gradFInv(y))**(2-2/pw) < 0:
      return uf1(jnp.sum(x*y))*x
    else:
      return jax.grad(lambda y: func1(y, x))(y)
  y = jnp.zeros(n)
  c = .5e-1
  for _ in range(N):
    prx = prox_func(y, x)
    if jnp.any(jnp.isnan(prx)):
      print("NAN", 4*r**2*p0*p2/pw**2 + F(gradFInv(y)), prx, y, x)

    y = y - c*prx
    #print("*", -func(y, x))
    #print("**", uf0(jnp.sum(x*y))  + psi(y))
  return -func(y, x)



This will take long - we want to numerically verify that $\phi^{cc}$ by the formula is better than $\phi^{cc}$ by the optimizer, the later is an unattainable supremum

In [14]:
for i in range(1000):
  xx1, key  = grand(key, (n,))
  y, key  = grand(key, (n,))
  s = uf0(jnp.sum(xx1*y)) + tPhi(xx1) + psi(y)
  if s < 0:
    print("BAD", s)
    break
  val1 = phicc(xx1)
  val2 = phicc_opt(xx1, N=100)
  print(val1, val2)
  if val1 < val2:
    print('BD2', val1, val2)
    break


# F(gradFInv(yy1))**(2-2/pw)
# print(psi(yy1))
# 4*r**2*p0*p2/pw**2 + F(gradFInv(yy1))**(2-2/pw)
# 4*r**2*p0*p2/pw**2 + F(gradFInv(y))**(2-2/pw)

1.2150909 1.1828356
1.743231 1.7045631
1.0688471 1.0437515
0.5852026 0.58520246
1.0814492 1.067175
0.947126 0.92645186
1.2065063 1.1813748
1.4557902 1.4196835
0.5507516 0.5507197
0.96760416 0.9470215
1.000813 0.9769477
0.8427013 0.83750945
1.161602 1.1331115
1.5340661 1.4980628
1.680555 1.6603172
1.1030328 1.0745838
1.2366524 1.2227383
0.49700725 0.49700204
1.1093317 1.0808673
0.8976912 0.88406
1.0957927 1.0705225
1.507985 1.4713757
1.032396 1.008733
0.7004567 0.70043653
1.5662924 1.528971
0.3215709 0.32155547
0.64192986 0.64151156
1.0674777 1.0492201
1.068898 1.0429294
1.2311839 1.1993642
0.64428735 0.6440913
0.64267987 0.64256513
1.2261535 1.195493
0.6126525 0.6123899
0.6695124 0.6686831
0.63503283 0.63264555
1.0467614 1.02299
1.3558645 1.3233407
0.9283416 0.918535
1.644448 1.6068482
1.1010556 1.0737374
1.2493585 1.2167606
0.8490762 0.8370606
1.0976745 1.0694612
1.1045749 1.0878326
1.1968932 1.1847637
0.61338145 0.6133794
1.4371601 1.403328
1.00419 0.99385655
1.4788164 1.4421055
1.09

In [None]:
# @title
print(phicc(xx1), tPhi(xx1), 1/r)
def phicc_opt2(x, N):

  def func(y, x):
    if 4*r**2*p0*p2/pw**2 + F(gradFInv(y))**(2-2/pw) <= 0:
      return uf0(jnp.sum(x*y))-uf0(0)
    else:
      return func1(y, x)
  def prox_func(y, x):
    if 4*r**2*p0*p2/pw**2 + F(gradFInv(y))**(2-2/pw) < 0:
      return uf1(jnp.sum(x*y))*x
    else:
      return jax.grad(lambda y: func1(y, x))(y)
  y = jnp.zeros(n)
  c = .5e-1
  for _ in range(N):
    prx = prox_func(y, x)
    if jnp.any(jnp.isnan(prx)):
      print("NAN", 4*r**2*p0*p2/pw**2 + F(gradFInv(y)), prx, y, x)

    y = y - c*prx
    #print("*", -func(y, x))
    #print("**", uf0(jnp.sum(x*y))  + psi(y))
  return -func(y, x),y
ret = phicc_opt2(xx1, 100)

In [15]:
for i in range(500):
  xx1, key  = grand(key, (n,))
  y, key  = grand(key, (n,))
  s = uf0(jnp.sum(xx1*y)) + tPhi(xx1) + psi(y)
  if s < 0:
    print("BAD", s)
    break
  val1 = phicc(xx1)
  val2 = phicc_opt(xx1, N=100)
  print(val1, val2)
  if val1 < val2:
    print('BD2', val1, val2)
    if jnp.abs(val1-val2)> 1e-4:
      break


0.6590964 0.65908647
1.4698291 1.433796
1.7342665 1.696888
0.97420913 0.9521086
1.0793352 1.0620325
1.1259992 1.1099954
1.7268472 1.6884737
1.1276463 1.10876
1.106209 1.0789034
1.2597655 1.2325587
1.1324885 1.1132028
1.277241 1.2562809
0.79649717 0.7894986
1.3854764 1.3501797
0.95369977 0.9372428
1.1068441 1.0784357
1.4963436 1.4592998
1.1557436 1.1274066
0.8831485 0.87156594
1.782952 1.7440885
0.9680937 0.9482436
0.94594896 0.9382921
1.0616493 1.0443516
0.9605377 0.9399394
0.48156297 0.47828528
1.5750957 1.5384157
0.65640575 0.6554173
0.65862876 0.6580397
0.5176183 0.5166129
0.49294034 0.49293295
1.4154333 1.3859527
0.5889712 0.58888733
0.6081043 0.6080965
1.3760842 1.3575516
1.4146364 1.3799503
1.4126945 1.3825504
0.21463731 0.21348928
0.7221829 0.7204728
1.2603183 1.2448342
1.1801903 1.1527219
1.486042 1.4493325
0.8039324 0.8009268
0.9672014 0.94702035
1.2443103 1.2118278
0.43670604 0.4367039
0.9893779 0.97764766
1.1505673 1.1280317
1.4349613 1.3992487
1.1232775 1.0978084
0.56674004