-
Notifications
You must be signed in to change notification settings - Fork 9
Stop modifying Location beyond requested evaluations #131
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -21,11 +21,9 @@ type LinesearchMethod struct { | |
x []float64 // Starting point for the current iteration. | ||
dir []float64 // Search direction for the current iteration. | ||
|
||
first bool // Indicator of the first iteration. | ||
nextMajor bool // Indicates that MajorIteration must be requested at the next call to Iterate. | ||
|
||
loc Location // Storage for intermediate locations. | ||
eval Operation // Indicator of valid fields in loc. | ||
first bool // Indicator of the first iteration. | ||
nextMajor bool // Indicates that MajorIteration must be commanded at the next call to Iterate. | ||
eval Operation // Indicator of valid fields in Location. | ||
|
||
lastStep float64 // Step taken from x in the previous call to Iterate. | ||
lastOp Operation // Operation returned from the previous call to Iterate. | ||
|
@@ -43,87 +41,78 @@ func (ls *LinesearchMethod) Init(loc *Location) (Operation, error) { | |
ls.first = true | ||
ls.nextMajor = false | ||
|
||
copyLocation(&ls.loc, loc) | ||
// Indicate that all fields of ls.loc are valid. | ||
// Indicate that all fields of loc are valid. | ||
ls.eval = FuncEvaluation | GradEvaluation | ||
if ls.loc.Hessian != nil { | ||
if loc.Hessian != nil { | ||
ls.eval |= HessEvaluation | ||
} | ||
|
||
ls.lastStep = math.NaN() | ||
ls.lastOp = NoOperation | ||
|
||
return ls.initNextLinesearch(loc.X) | ||
return ls.initNextLinesearch(loc) | ||
} | ||
|
||
func (ls *LinesearchMethod) Iterate(loc *Location) (Operation, error) { | ||
switch ls.lastOp { | ||
case NoOperation: | ||
// TODO(vladimir-ch): We have previously returned with an error and | ||
// Init was not called. What to do? What about ls's internal state? | ||
// TODO(vladimir-ch): Either Init has not been called, or the caller is | ||
// trying to resume the optimization run after Iterate previously | ||
// returned with an error. Decide what is the proper thing to do. See also #125. | ||
|
||
case MajorIteration: | ||
// We previously requested MajorIteration but since we're here, the | ||
// previous location was not good enough to converge the full | ||
// optimization. Start the next linesearch and store the next | ||
// evaluation point in loc.X. | ||
return ls.initNextLinesearch(loc.X) | ||
// The previous updated location did not converge the full | ||
// optimization. Initialize a new Linesearch. | ||
return ls.initNextLinesearch(loc) | ||
|
||
default: | ||
// Store the result of the previously requested evaluation into ls.loc. | ||
if ls.lastOp&FuncEvaluation != 0 { | ||
ls.loc.F = loc.F | ||
} | ||
if ls.lastOp&GradEvaluation != 0 { | ||
copy(ls.loc.Gradient, loc.Gradient) | ||
} | ||
if ls.lastOp&HessEvaluation != 0 { | ||
ls.loc.Hessian.CopySym(loc.Hessian) | ||
} | ||
// Update the indicator of valid fields of ls.loc. | ||
// Update the indicator of valid fields of loc. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Move the indicator below the nextMajor check. It's not needed there, so it obfuscates. There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Although not necessary, I would like |
||
ls.eval |= ls.lastOp | ||
|
||
if ls.nextMajor { | ||
ls.nextMajor = false | ||
|
||
// Linesearcher previously indicated that it had finished, but we | ||
// needed to evaluate invalid fields of ls.loc. Now we have them and | ||
// can announce MajorIteration. | ||
|
||
copyLocation(loc, &ls.loc) | ||
// Linesearcher previously finished, and the invalid fields of loc | ||
// have now been validated. Announce MajorIteration. | ||
ls.lastOp = MajorIteration | ||
return ls.lastOp, nil | ||
} | ||
} | ||
|
||
projGrad := floats.Dot(ls.loc.Gradient, ls.dir) | ||
if ls.Linesearcher.Finished(ls.loc.F, projGrad) { | ||
// Form an operation that evaluates invalid fields of ls.loc. | ||
ls.lastOp = complementEval(&ls.loc, ls.eval) | ||
// Continue the linesearch. | ||
|
||
f := math.NaN() | ||
if ls.eval&FuncEvaluation != 0 { | ||
f = loc.F | ||
} | ||
projGrad := math.NaN() | ||
if ls.eval&GradEvaluation != 0 { | ||
projGrad = floats.Dot(loc.Gradient, ls.dir) | ||
} | ||
|
||
if ls.Linesearcher.Finished(f, projGrad) { | ||
// Form an operation that evaluates invalid fields of loc. | ||
ls.lastOp = complementEval(loc, ls.eval) | ||
if ls.lastOp == NoOperation { | ||
// ls.loc is complete and MajorIteration can be announced directly. | ||
copyLocation(loc, &ls.loc) | ||
// loc is complete and MajorIteration can be announced directly. | ||
ls.lastOp = MajorIteration | ||
} else { | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Could you add a comment here like "Continue the linesearch" just to confirm to the code reader that this is the main line? (followed by the update valid fields) |
||
ls.nextMajor = true | ||
} | ||
return ls.lastOp, nil | ||
} | ||
|
||
step, op, err := ls.Linesearcher.Iterate(ls.loc.F, projGrad) | ||
step, op, err := ls.Linesearcher.Iterate(f, projGrad) | ||
if err != nil { | ||
return ls.error(err) | ||
} | ||
if !op.isEvaluation() { | ||
panic("linesearch: Linesearcher returned invalid operation") | ||
} | ||
|
||
if step == ls.lastStep { | ||
// Linesearcher is requesting another evaluation at the same point | ||
// which is stored in ls.loc.X. | ||
copy(loc.X, ls.loc.X) | ||
} else { | ||
// We are moving to a new location. | ||
if step != ls.lastStep { | ||
// We are moving to a new location, and not, say, evaluating extra | ||
// information at the current location. | ||
|
||
// Compute the next evaluation point and store it in loc.X. | ||
floats.AddScaledTo(loc.X, ls.x, step, ls.dir) | ||
|
@@ -135,8 +124,7 @@ func (ls *LinesearchMethod) Iterate(loc *Location) (Operation, error) { | |
} | ||
|
||
ls.lastStep = step | ||
copy(ls.loc.X, loc.X) // Move ls.loc to the next evaluation point | ||
ls.eval = NoOperation // and invalidate all its fields. | ||
ls.eval = NoOperation // Indicate all invalid fields of loc. | ||
} | ||
|
||
ls.lastOp = op | ||
|
@@ -149,40 +137,39 @@ func (ls *LinesearchMethod) error(err error) (Operation, error) { | |
} | ||
|
||
// initNextLinesearch initializes the next linesearch using the previous | ||
// complete location stored in ls.loc. It fills xNext and returns an evaluation | ||
// to be performed at xNext. | ||
func (ls *LinesearchMethod) initNextLinesearch(xNext []float64) (Operation, error) { | ||
copy(ls.x, ls.loc.X) | ||
// complete location stored in loc. It fills loc.X and returns an evaluation | ||
// to be performed at loc.X. | ||
func (ls *LinesearchMethod) initNextLinesearch(loc *Location) (Operation, error) { | ||
copy(ls.x, loc.X) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Do we need to set ls.lastOp to be NoOperation up here so that's what it equals if an error occurs? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. On error I call There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Oh, sorry, I missed that. Thanks. |
||
|
||
var step float64 | ||
if ls.first { | ||
ls.first = false | ||
step = ls.NextDirectioner.InitDirection(&ls.loc, ls.dir) | ||
step = ls.NextDirectioner.InitDirection(loc, ls.dir) | ||
} else { | ||
step = ls.NextDirectioner.NextDirection(&ls.loc, ls.dir) | ||
step = ls.NextDirectioner.NextDirection(loc, ls.dir) | ||
} | ||
|
||
projGrad := floats.Dot(ls.loc.Gradient, ls.dir) | ||
projGrad := floats.Dot(loc.Gradient, ls.dir) | ||
if projGrad >= 0 { | ||
return ls.error(ErrNonNegativeStepDirection) | ||
} | ||
|
||
op := ls.Linesearcher.Init(ls.loc.F, projGrad, step) | ||
op := ls.Linesearcher.Init(loc.F, projGrad, step) | ||
if !op.isEvaluation() { | ||
panic("linesearch: Linesearcher returned invalid operation") | ||
} | ||
|
||
floats.AddScaledTo(xNext, ls.x, step, ls.dir) | ||
if floats.Equal(ls.x, xNext) { | ||
floats.AddScaledTo(loc.X, ls.x, step, ls.dir) | ||
if floats.Equal(ls.x, loc.X) { | ||
// Step size is so small that the next evaluation point is | ||
// indistinguishable from the starting point for the current iteration | ||
// due to rounding errors. | ||
return ls.error(ErrNoProgress) | ||
} | ||
|
||
ls.lastStep = step | ||
copy(ls.loc.X, xNext) // Move ls.loc to the next evaluation point | ||
ls.eval = NoOperation // and invalidate all its fields. | ||
ls.eval = NoOperation // Invalidate all fields of loc. | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. "Indicate all invalid fields of loc" to mirror the comment in Init. |
||
|
||
ls.lastOp = op | ||
return ls.lastOp, nil | ||
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds good, we'll leave to another PR for fix.