Skip to content

Commit

Permalink
makes the bfgs operator more natural to work with
Browse files Browse the repository at this point in the history
  • Loading branch information
Niru Maheswaranathan committed Dec 26, 2015
1 parent 348f5b7 commit ab56b6b
Showing 1 changed file with 27 additions and 10 deletions.
37 changes: 27 additions & 10 deletions proxalgs/operators.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def poissreg(x0, rho, x, y):


@curry
def bfgs(x0, rho, f, fgrad):
def bfgs(x0, rho, f_df, maxiter=50, method='BFGS'):
"""
Proximal operator for minimizing an arbitrary function using BFGS
Expand All @@ -140,13 +140,14 @@ def bfgs(x0, rho, f, fgrad):
rho : float
Momentum parameter for the proximal step (larger value -> stays closer to x0)
f : function
The function to use when applying the proximal operator. Must take as input a parameter vector (array_like) and
return a real number (floating point value)
f_df : function
The objective function and gradient
df : function
A function that computes the gradient of `f` with respect to the parameters. Must take as input a parameter
vector (array_like) and returns another ndarray of the same size.
maxiter : int, optional
Maximum number of iterations to take (default: 50)
method : str, optional
Which scipy.optimize algorithm to use (default: 'BFGS')
Returns
-------
Expand All @@ -155,12 +156,28 @@ def bfgs(x0, rho, f, fgrad):
"""

# keep track of the original shape
orig_shape = x0.shape

# specify the objective function and gradient for the proximal operator
g = lambda x: f(x) + (rho / 2) * np.sum((x.reshape(x0.shape) - x0) ** 2)
dg = lambda x: fgrad(x) + rho * (x.reshape(x0.shape) - x0)
def f_df_augmented(x):

xk = x.reshape(orig_shape)

obj, grad = f_df(xk)

g = obj + (rho / 2.) * np.sum((xk - x0) ** 2)
dg = (grad + rho * (xk - x0)).ravel()

return g, dg

# minimize via BFGS
return opt.fmin_bfgs(g, x0, dg, disp=False)
options = {'maxiter': maxiter, 'disp': False}
return opt.minimize(f_df_augmented,
x0.ravel(),
method=method,
jac=True,
options=options).x.reshape(orig_shape)


@curry
Expand Down

0 comments on commit ab56b6b

Please sign in to comment.