Skip to content
This repository has been archived by the owner on Nov 23, 2018. It is now read-only.

Commit

Permalink
Update NelderMead to new Method interface
Browse files Browse the repository at this point in the history
  • Loading branch information
vladimir-ch committed Sep 17, 2015
1 parent 4600f07 commit 65041af
Show file tree
Hide file tree
Showing 2 changed files with 40 additions and 35 deletions.
73 changes: 40 additions & 33 deletions neldermead.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ const (
nmContractedOutside
nmInitialize
nmShrink
nmMajor
)

type nmVertexSorter struct {
Expand Down Expand Up @@ -134,7 +135,7 @@ func (n *NelderMead) Init(loc *Location) (RequestType, error) {
copy(n.values, n.InitialValues)
sort.Sort(nmVertexSorter{n.vertices, n.values})
computeCentroid(n.vertices, n.centroid)
return n.returnNext(nmReflected, loc.X)
return n.returnNext(nmMajor, loc)
}

// No simplex provided. Begin initializing initial simplex. First simplex
Expand Down Expand Up @@ -166,8 +167,7 @@ func computeCentroid(vertices [][]float64, centroid []float64) {
}

func (n *NelderMead) Iterate(loc *Location) (RequestType, error) {
xNext := loc.X
dim := len(xNext)
dim := len(loc.X)
switch n.lastIter {
case nmInitialize:
n.values[n.fillIdx] = loc.F
Expand All @@ -177,67 +177,75 @@ func (n *NelderMead) Iterate(loc *Location) (RequestType, error) {
// Successfully finished building initial simplex.
sort.Sort(nmVertexSorter{n.vertices, n.values})
computeCentroid(n.vertices, n.centroid)
return n.returnNext(nmReflected, xNext)
return n.returnNext(nmMajor, loc)
}
copy(xNext, n.vertices[dim])
xNext[n.fillIdx] += n.SimplexSize
return FuncEvaluation, nil // InitIteration
copy(loc.X, n.vertices[dim])
loc.X[n.fillIdx] += n.SimplexSize
return FuncEvaluation, nil
case nmMajor:
// Nelder Mead iterations start with Reflection step
return n.returnNext(nmReflected, loc)
case nmReflected:
n.reflectedValue = loc.F
switch {
case loc.F >= n.values[0] && loc.F < n.values[dim-1]:
n.replaceWorst(loc.X, loc.F)
return n.returnNext(nmReflected, xNext)
return n.returnNext(nmMajor, loc)
case loc.F < n.values[0]:
return n.returnNext(nmExpanded, xNext)
return n.returnNext(nmExpanded, loc)
default:
if loc.F < n.values[dim] {
return n.returnNext(nmContractedOutside, xNext)
return n.returnNext(nmContractedOutside, loc)
}
return n.returnNext(nmContractedInside, xNext)
return n.returnNext(nmContractedInside, loc)
}
case nmExpanded:
if loc.F < n.reflectedValue {
n.replaceWorst(loc.X, loc.F)
} else {
n.replaceWorst(n.reflectedPoint, n.reflectedValue)
}
return n.returnNext(nmReflected, xNext)
return n.returnNext(nmMajor, loc)
case nmContractedOutside:
if loc.F <= n.reflectedValue {
n.replaceWorst(loc.X, loc.F)
return n.returnNext(nmReflected, xNext)
return n.returnNext(nmMajor, loc)
}
n.fillIdx = 1
return n.returnNext(nmShrink, xNext)
return n.returnNext(nmShrink, loc)
case nmContractedInside:
if loc.F < n.values[dim] {
n.replaceWorst(loc.X, loc.F)
return n.returnNext(nmReflected, xNext)
return n.returnNext(nmMajor, loc)
}
n.fillIdx = 1
return n.returnNext(nmShrink, xNext)
return n.returnNext(nmShrink, loc)
case nmShrink:
copy(n.vertices[n.fillIdx], loc.X)
n.values[n.fillIdx] = loc.F
n.fillIdx++
if n.fillIdx != dim+1 {
return n.returnNext(nmShrink, xNext)
return n.returnNext(nmShrink, loc)
}
sort.Sort(nmVertexSorter{n.vertices, n.values})
computeCentroid(n.vertices, n.centroid)
return n.returnNext(nmReflected, xNext)
return n.returnNext(nmReflected, loc)
default:
panic("unreachable")
}
}

// returnNext finds the next location to evaluate, stores the location in xNext,
// and returns the data
func (n *NelderMead) returnNext(iter nmIterType, xNext []float64) (RequestType, error) {
dim := len(xNext)
// returnNext updates the location based on the iteration type and the current
// simplex, and returns the next request.
func (n *NelderMead) returnNext(iter nmIterType, loc *Location) (RequestType, error) {
n.lastIter = iter
switch iter {
case nmMajor:
// Fill loc with the current best point and value,
// and request a convergence check.
copy(loc.X, n.vertices[0])
loc.F = n.values[0]
return MajorIteration, nil
case nmReflected, nmExpanded, nmContractedOutside, nmContractedInside:
// x_new = x_centroid + scale * (x_centroid - x_worst)
var scale float64
Expand All @@ -251,21 +259,20 @@ func (n *NelderMead) returnNext(iter nmIterType, xNext []float64) (RequestType,
case nmContractedInside:
scale = -n.contraction
}
floats.SubTo(xNext, n.centroid, n.vertices[dim])
floats.Scale(scale, xNext)
floats.Add(xNext, n.centroid)
dim := len(loc.X)
floats.SubTo(loc.X, n.centroid, n.vertices[dim])
floats.Scale(scale, loc.X)
floats.Add(loc.X, n.centroid)
if iter == nmReflected {
copy(n.reflectedPoint, xNext)
// Nelder Mead iterations start with Reflection step
return FuncEvaluation, nil // MajorIteration
copy(n.reflectedPoint, loc.X)
}
return FuncEvaluation, nil // MinorIteration
return FuncEvaluation, nil
case nmShrink:
// x_shrink = x_best + delta * (x_i + x_best)
floats.SubTo(xNext, n.vertices[n.fillIdx], n.vertices[0])
floats.Scale(n.shrink, xNext)
floats.Add(xNext, n.vertices[0])
return FuncEvaluation, nil // SubIteration
floats.SubTo(loc.X, n.vertices[n.fillIdx], n.vertices[0])
floats.Scale(n.shrink, loc.X)
floats.Add(loc.X, n.vertices[0])
return FuncEvaluation, nil
default:
panic("unreachable")
}
Expand Down
2 changes: 0 additions & 2 deletions unconstrained_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -994,7 +994,6 @@ func newVariablyDimensioned(dim int, gradTol float64) unconstrainedTest {
}

func TestLocal(t *testing.T) {
t.Skip("NelderMead has not yet been updated to work correctly with the new Method interface")
var tests []unconstrainedTest
// Mix of functions with and without Grad() method.
tests = append(tests, gradFreeTests...)
Expand All @@ -1003,7 +1002,6 @@ func TestLocal(t *testing.T) {
}

func TestNelderMead(t *testing.T) {
t.Skip("NelderMead has not yet been updated to work correctly with the new Method interface")
var tests []unconstrainedTest
// Mix of functions with and without Grad() method.
tests = append(tests, gradFreeTests...)
Expand Down

0 comments on commit 65041af

Please sign in to comment.