Skip to content

Commit

Permalink
Merge pull request #11 from sauln/master
Browse files Browse the repository at this point in the history
handling of defaults for OWL-QN and tests
  • Loading branch information
fgregg committed May 24, 2018
2 parents 4f9ac98 + 4906832 commit 1784414
Show file tree
Hide file tree
Showing 2 changed files with 62 additions and 2 deletions.
15 changes: 14 additions & 1 deletion lbfgs/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
Python wrapper around liblbfgs.
"""

import warnings
from ._lowlevel import LBFGS, LBFGSError


Expand All @@ -21,6 +22,7 @@ def fmin_lbfgs(f, x0, progress=None, args=(), orthantwise_c=0,
Called with the current position x (a numpy.ndarray), a gradient
vector g (a numpy.ndarray) to be filled in and *args.
Must return the value at x and set the gradient vector g.
x0 : array-like
Initial values. A copy of this array is made prior to optimization.
Expand All @@ -47,7 +49,10 @@ def fmin_lbfgs(f, x0, progress=None, args=(), orthantwise_c=0,
zero, the library modifies function and gradient evaluations from
a client program suitably; a client program thus have only to return
the function value F(x) and gradients G(x) as usual. The default value
is zero.
is zero.
If orthantwise_c is set, then line_search cannot be the default
and must be one of 'armijo', 'wolfe', or 'strongwolfe'.
orthantwise_start: int, optional (default=0)
Start index for computing L1 norm of the variables.
Expand Down Expand Up @@ -168,6 +173,14 @@ def fmin_lbfgs(f, x0, progress=None, args=(), orthantwise_c=0,
"""

# Some input validation to make sure defaults with OWL-QN are adapted correctly
assert orthantwise_c >= 0, "Orthantwise_c cannot be negative"

if orthantwise_c > 0 and line_search != 'wolfe':
line_search = 'wolfe'
warnings.warn("When using OWL-QN, 'wolfe' is the only valid line_search. line_search has been set to 'wolfe'.")

opt = LBFGS()
opt.orthantwise_c = orthantwise_c
opt.orthantwise_start = orthantwise_start
Expand Down
49 changes: 48 additions & 1 deletion tests/test_lbfgs.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,56 @@ def f(x, g, *args):
g[0] = 2 * x
return x ** 2

xmin = fmin_lbfgs(f, 100.)
xmin = fmin_lbfgs(f, 100., line_search='armijo')
assert_array_equal(xmin, [0])

xmin = fmin_lbfgs(f, 100., line_search='strongwolfe')
assert_array_equal(xmin, [0])

class TestOWLQN:

def test_owl_qn(self):
def f(x, g, *args):
g[0] = 2 * x
return x ** 2

xmin = fmin_lbfgs(f, 100., orthantwise_c=1, line_search='wolfe')
assert_array_equal(xmin, [0])

def test_owl_line_search_default(self):
def f(x, g, *args):
g[0] = 2 * x
return x ** 2

with pytest.warns(UserWarning, match="OWL-QN"):
xmin = fmin_lbfgs(f, 100., orthantwise_c=1)

def test_owl_line_search_warning_explicit(self):
def f(x, g, *args):
g[0] = 2 * x
return x ** 2

with pytest.warns(UserWarning, match="OWL-QN"):
xmin = fmin_lbfgs(f, 100., orthantwise_c=1, line_search='default')
with pytest.warns(UserWarning, match="OWL-QN"):
xmin = fmin_lbfgs(f, 100., orthantwise_c=1, line_search='morethuente')
with pytest.warns(UserWarning, match="OWL-QN"):
xmin = fmin_lbfgs(f, 100., orthantwise_c=1, line_search='armijo')
with pytest.warns(UserWarning, match="OWL-QN"):
xmin = fmin_lbfgs(f, 100., orthantwise_c=1, line_search='strongwolfe')

@pytest.mark.xfail(strict=True)
def test_owl_wolfe_no_warning(self):
""" This test is an attempt to show that wolfe throws no warnings.
"""

def f(x, g, *args):
g[0] = 2 * x
return x ** 2

with pytest.warns(UserWarning, match="OWL-QN"):
xmin = fmin_lbfgs(f, 100., orthantwise_c=1, line_search='wolfe')


def test_2d():
def f(x, g, f_calls):
Expand Down

0 comments on commit 1784414

Please sign in to comment.