Skip to content

Commit

Permalink
Code cleanup
Browse files Browse the repository at this point in the history
  • Loading branch information
fchollet committed Aug 2, 2015
1 parent 82add4a commit c302241
Showing 1 changed file with 15 additions and 14 deletions.
29 changes: 15 additions & 14 deletions nelder_mead.py
Expand Up @@ -5,20 +5,24 @@
Reference: https://en.wikipedia.org/wiki/Nelder%E2%80%93Mead_method
'''

def nelder_mead(f, x_start,
step=0.1, no_improve_thr=10e-6, no_improv_break=10, max_iter=0,
alpha = 1., gamma = 2., rho = -0.5, sigma = 0.5):

def nelder_mead(f, x_start,
step=0.1, no_improve_thr=10e-6,
no_improv_break=10, max_iter=0,
alpha=1., gamma=2., rho=-0.5, sigma=0.5):
'''
@param f (function): function to optimize, must return a scalar score
@param f (function): function to optimize, must return a scalar score
and operate over a numpy array of the same dimensions as x_start
@param x_start (numpy array): initial position
@param step (float): look-around radius in initial step
@no_improv_thr, no_improv_break (float, int): break after no_improv_break iterations with
@no_improv_thr, no_improv_break (float, int): break after no_improv_break iterations with
an improvement lower than no_improv_thr
@max_iter (int): always break after this number of iterations.
Set it to 0 to loop indefinitely.
@alpha, gamma, rho, sigma (floats): parameters of the algorithm
@alpha, gamma, rho, sigma (floats): parameters of the algorithm
(see Wikipedia page for reference)
return: tuple (best parameter array, best score)
'''

# init
Expand All @@ -37,7 +41,7 @@ def nelder_mead(f, x_start,
iters = 0
while 1:
# order
res.sort(key = lambda x: x[1])
res.sort(key=lambda x: x[1])
best = res[0][1]

# break after max_iter
Expand All @@ -53,7 +57,7 @@ def nelder_mead(f, x_start,
prev_best = best
else:
no_improv += 1

if no_improv >= no_improv_break:
return res[0]

Expand Down Expand Up @@ -106,11 +110,8 @@ def nelder_mead(f, x_start,
# test
import math
import numpy as np
def f(x):
return math.sin(x[0])*math.cos(x[1])*(1./(abs(x[2])+1))

print nelder_mead(f, np.array([0.,0.,0.]))


def f(x):
return math.sin(x[0]) * math.cos(x[1]) * (1. / (abs(x[2]) + 1))

print nelder_mead(f, np.array([0., 0., 0.]))

0 comments on commit c302241

Please sign in to comment.