Skip to content
This repository was archived by the owner on Nov 23, 2018. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
161 changes: 33 additions & 128 deletions global.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,17 @@
package optimize

import (
"fmt"
"math"
"sync"
"time"

"github.com/gonum/matrix/mat64"
)

// GlobalMethod is a global optimizer. Typically will require more function
// evaluations and no sense of local convergence
type GlobalMethod interface {
// Global tells method the max number of tasks, method returns how many it wants.
// This is needed to sync the Global goroutines and inside goroutines.
InitGlobal(tasks int) int
InitGlobal(dim, tasks int) int
// Global method may assume that the same task id always has the same pointer with it.
IterateGlobal(task int, loc *Location) (Operation, error)
Needser
Expand Down Expand Up @@ -72,40 +69,19 @@ type GlobalMethod interface {
// Something about Global cannot guarantee strict bounds on function evaluations,
// iterations, etc. in the precense of concurrency.
func Global(p Problem, dim int, settings *Settings, method GlobalMethod) (*Result, error) {
if p.Func == nil {
panic("optimize: objective function is undefined")
}
if dim <= 0 {
panic("optimize: impossible problem dimension")
}
startTime := time.Now()
if method == nil {
method = &GuessAndCheck{}
}
if err := p.satisfies(method); err != nil {
return nil, err
}
if p.Status != nil {
_, err := p.Status()
if err != nil {
return nil, err
}
}

if settings == nil {
settings = DefaultSettingsGlobal()
}

if settings.Recorder != nil {
// Initialize Recorder first. If it fails, we avoid the (possibly
// time-consuming) evaluation at the starting location.
err := settings.Recorder.Init()
if err != nil {
return nil, err
}
stats := &Stats{}
err := checkOptimization(p, dim, method, settings.Recorder)
if err != nil {
return nil, err
}

stats := &Stats{}
optLoc := newLocation(dim, method)
optLoc.F = math.Inf(1)

Expand All @@ -115,22 +91,21 @@ func Global(p Problem, dim int, settings *Settings, method GlobalMethod) (*Resul

stats.Runtime = time.Since(startTime)

// Don't need to check convergence because it can't possibly have converged.
// (No function evaluations and no starting location).
var err error
// Send initial location to Recorder
if settings.Recorder != nil {
err = settings.Recorder.Record(optLoc, InitIteration, stats)
// TODO(btracey): Handle this error? Fix when merge with Local.
if err != nil {
return nil, err
}
}

// Run optimization
var status Status
status, err = minimizeGlobal(&p, method, settings, stats, optLoc, startTime)

// Cleanup and collect results
if settings.Recorder != nil && err == nil {
// Send the optimal location to Recorder.
err = settings.Recorder.Record(optLoc, PostIteration, stats)
// TODO(btracey): Handle this error? Fix when merge with Local.
}
stats.Runtime = time.Since(startTime)
return &Result{
Expand All @@ -142,6 +117,7 @@ func Global(p Problem, dim int, settings *Settings, method GlobalMethod) (*Resul

func minimizeGlobal(p *Problem, method GlobalMethod, settings *Settings, stats *Stats, optLoc *Location, startTime time.Time) (status Status, err error) {
dim := len(optLoc.X)
statuser, _ := method.(Statuser)
gs := &globalStatus{
mux: &sync.RWMutex{},
stats: stats,
Expand All @@ -150,10 +126,11 @@ func minimizeGlobal(p *Problem, method GlobalMethod, settings *Settings, stats *
startTime: startTime,
optLoc: optLoc,
settings: settings,
statuser: statuser,
}

nTasks := settings.Concurrent
nTasks = method.InitGlobal(nTasks)
nTasks = method.InitGlobal(dim, nTasks)

// Launch optimization workers
var wg sync.WaitGroup
Expand All @@ -180,6 +157,7 @@ type globalStatus struct {
optLoc *Location
settings *Settings
method GlobalMethod
statuser Statuser
err error
}

Expand Down Expand Up @@ -209,119 +187,46 @@ func globalWorker(task int, m GlobalMethod, g *globalStatus, loc *Location, x []
// It uses a mutex to protect updates where necessary.
func (g *globalStatus) globalOperation(op Operation, loc *Location, x []float64) Status {
// Do a quick check to see if one of the other workers converged in the meantime.
var status Status
var err error
g.mux.RLock()
s := g.status
status = g.status
g.mux.RUnlock()
if s != NotTerminated {
return s
if status != NotTerminated {
return status
}
switch op {
case NoOperation:
case InitIteration:
panic("optimize: GlobalMethod return InitIteration")
panic("optimize: Method returned InitIteration")
case PostIteration:
panic("optimize: Method returned PostIteration")
case MajorIteration:
g.mux.Lock()
g.stats.MajorIterations++
if loc.F < g.optLoc.F {
copyLocation(g.optLoc, loc)
}
copyLocation(g.optLoc, loc)
g.mux.Unlock()

g.mux.RLock()
status := checkConvergenceGlobal(g.optLoc, g.settings)
status = checkConvergence(g.optLoc, g.settings, false)
g.mux.RUnlock()
if status != NotTerminated {
// Update g.status, preserving the first termination status.
g.mux.Lock()
if g.status == NotTerminated {
g.status = status
}
status = g.status
g.mux.Unlock()
return status
}
default:
if !op.isEvaluation() {
panic(fmt.Sprintf("optimize: invalid evaluation %v", op))
}
copy(x, loc.X)
if op&FuncEvaluation != 0 {
loc.F = g.p.Func(x)
g.mux.Lock()
g.stats.FuncEvaluations++
g.mux.Unlock()
}
if op&GradEvaluation != 0 {
g.p.Grad(loc.Gradient, x)
g.mux.Lock()
g.stats.GradEvaluations++
g.mux.Unlock()
}
if op&HessEvaluation != 0 {
g.p.Hess(loc.Hessian, x)
g.mux.Lock()
g.stats.HessEvaluations++
g.mux.Unlock()
}
default: // Any of the Evaluation operations.
status, err = evaluate(g.p, loc, op, x)
g.mux.Lock()
updateStats(g.stats, op)
g.mux.Unlock()
}

// TODO(btracey): Need to fix all these things to avoid deadlock.
// When re-do, need to make sure aren't overwritting a converged status.
g.mux.Lock()
g.stats.Runtime = time.Since(g.startTime)
if g.settings.Recorder != nil {
err := g.settings.Recorder.Record(loc, op, g.stats)
if err != nil {
if g.status == NotTerminated && g.err != nil {
g.status = Failure
g.err = err
}
}
}
s = checkLimits(loc, g.stats, g.settings)
status, err = iterCleanup(status, err, g.stats, g.settings, g.statuser, g.startTime, loc, op)
// Update the termination status if it hasn't already terminated.
if g.status == NotTerminated {
g.status = s
g.status = status
g.err = err
}
methodStatus, methodIsStatuser := g.method.(Statuser)
if methodIsStatuser {
s, err := methodStatus.Status()
if err != nil && g.status == NotTerminated {
g.status = s
g.err = err
}
}
s = g.status
g.mux.Unlock()
return s
}

func newLocation(dim int, method Needser) *Location {
// TODO(btracey): combine this with Local.
loc := &Location{
X: make([]float64, dim),
}
loc.F = math.Inf(1)
if method.Needs().Gradient {
loc.Gradient = make([]float64, dim)
}
if method.Needs().Hessian {
loc.Hessian = mat64.NewSymDense(dim, nil)
}
return loc
}

func checkConvergenceGlobal(loc *Location, settings *Settings) Status {
if loc.F < settings.FunctionThreshold {
return FunctionThreshold
}
if settings.FunctionConverge != nil {
status := settings.FunctionConverge.FunctionConverged(loc.F)
if status != NotTerminated {
return NotTerminated
}
}
return NotTerminated
return status
}

func DefaultSettingsGlobal() *Settings {
Expand Down
26 changes: 24 additions & 2 deletions guessandcheck.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,14 +4,23 @@

package optimize

import "github.com/gonum/stat/distmv"
import (
"math"
"sync"

"github.com/gonum/stat/distmv"
)

// GuessAndCheck is a global optimizer that evaluates the function at random
// locations. Not a good optimizer, but useful for comparison and debugging.
type GuessAndCheck struct {
Rander distmv.Rander

eval []bool

mux *sync.Mutex
bestF float64
bestX []float64
}

func (g *GuessAndCheck) Needs() struct{ Gradient, Hessian bool } {
Expand All @@ -22,14 +31,27 @@ func (g *GuessAndCheck) Done() {
// No cleanup needed
}

func (g *GuessAndCheck) InitGlobal(tasks int) int {
func (g *GuessAndCheck) InitGlobal(dim, tasks int) int {
g.eval = make([]bool, tasks)
g.bestF = math.Inf(1)
g.bestX = resize(g.bestX, dim)
g.mux = &sync.Mutex{}
return tasks
}

func (g *GuessAndCheck) IterateGlobal(task int, loc *Location) (Operation, error) {
// Task is true if it contains a new function evaluation.
if g.eval[task] {
g.eval[task] = false
g.mux.Lock()
if loc.F < g.bestF {
g.bestF = loc.F
copy(g.bestX, loc.X)
} else {
loc.F = g.bestF
copy(loc.X, g.bestX)
}
g.mux.Unlock()
return MajorIteration, nil
}
g.eval[task] = true
Expand Down
Loading