Skip to content

Commit

Permalink
Added tests for Lipschitz constant/diag/mat
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffrey-hokanson committed Apr 23, 2019
1 parent 279853c commit 5b4b348
Showing 1 changed file with 29 additions and 18 deletions.
47 changes: 29 additions & 18 deletions tests/test_lipschitz.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
from __future__ import print_function
import numpy as np
import scipy.linalg
from psdr import LipschitzMatrix
from psdr import LipschitzMatrix, LipschitzConstant, DiagonalLipschitzMatrix
from psdr.demos import OTLCircuit

np.random.seed(0)
Expand All @@ -12,34 +12,45 @@ def test_lipschitz_grad(N = 10):
X = func.domain.sample(N)
grads = func.grad(X)

lip = LipschitzMatrix(ftol = 1e-10, gtol = 1e-10, verbose = True)
lip.fit(grads = grads)

lip_mat = LipschitzMatrix(ftol = 1e-10, gtol = 1e-10)
lip_diag = DiagonalLipschitzMatrix(ftol = 1e-10, gtol = 1e-10)
lip_const = LipschitzConstant()


H = np.copy(lip.H)
for lip in [lip_mat, lip_diag, lip_const]:
lip.fit(grads = grads)
H = np.copy(lip.H)

for g in grads:
gap = np.min(scipy.linalg.eigvalsh(H - np.outer(g,g)))
print(gap)
assert gap >= -1e-6
for g in grads:
gap = np.min(scipy.linalg.eigvalsh(H - np.outer(g,g)))
print(gap)
assert gap >= -1e-6

def test_lipschitz_func(M = 20):

func = OTLCircuit()
X = func.domain.sample(M)
fX = func(X)


lip_mat = LipschitzMatrix(ftol = 1e-10, gtol = 1e-10)
lip_diag = DiagonalLipschitzMatrix(ftol = 1e-10, gtol = 1e-10)
lip_const = LipschitzConstant()

for lip in [lip_mat, lip_diag, lip_const]:

lip = LipschitzMatrix(ftol = 1e-10, gtol = 1e-10, verbose = True)
lip.fit(X, fX)

H = np.copy(lip.H)
lip.fit(X, fX)

H = np.copy(lip.H)

for i in range(M):
for j in range(i+1,M):
y = X[i] - X[j]
gap = y.dot(H.dot(y)) - (fX[i] - fX[j])**2
print(gap)
assert gap >= -1e-6

for i in range(M):
for j in range(i+1,M):
y = X[i] - X[j]
gap = y.dot(H.dot(y)) - (fX[i] - fX[j])**2
print(gap)
assert gap >= -1e-6


def test_solver(N = 50, M = 0):
Expand Down

0 comments on commit 5b4b348

Please sign in to comment.