/
gradient.go
145 lines (132 loc) · 3.69 KB
/
gradient.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
// Copyright ©2017 The Gonum Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package fd
import "gonum.org/v1/gonum/floats"
// Gradient estimates the gradient of the multivariate function f at the
// location x. If dst is not nil, the result will be stored in-place into dst
// and returned, otherwise a new slice will be allocated first. Finite
// difference formula and other options are specified by settings. If settings is
// nil, the gradient will be estimated using the Forward formula and a default
// step size.
//
// Gradient panics if the length of dst and x is not equal, or if the derivative
// order of the formula is not 1.
func Gradient(dst []float64, f func([]float64) float64, x []float64, settings *Settings) []float64 {
if dst == nil {
dst = make([]float64, len(x))
}
if len(dst) != len(x) {
panic("fd: slice length mismatch")
}
// Default settings.
formula := Forward
step := formula.Step
var originValue float64
var originKnown, concurrent bool
// Use user settings if provided.
if settings != nil {
if !settings.Formula.isZero() {
formula = settings.Formula
step = formula.Step
checkFormula(formula)
if formula.Derivative != 1 {
panic(badDerivOrder)
}
}
if settings.Step != 0 {
step = settings.Step
}
originKnown = settings.OriginKnown
originValue = settings.OriginValue
concurrent = settings.Concurrent
}
evals := len(formula.Stencil) * len(x)
nWorkers := computeWorkers(concurrent, evals)
hasOrigin := usesOrigin(formula.Stencil)
// Copy x in case it is modified during the call.
xcopy := make([]float64, len(x))
if hasOrigin && !originKnown {
copy(xcopy, x)
originValue = f(xcopy)
}
if nWorkers == 1 {
for i := range xcopy {
var deriv float64
for _, pt := range formula.Stencil {
if pt.Loc == 0 {
deriv += pt.Coeff * originValue
continue
}
// Copying the data anew has two benefits. First, it
// avoids floating point issues where adding and then
// subtracting the step don't return to the exact same
// location. Secondly, it protects against the function
// modifying the input data.
copy(xcopy, x)
xcopy[i] += pt.Loc * step
deriv += pt.Coeff * f(xcopy)
}
dst[i] = deriv / step
}
return dst
}
sendChan := make(chan fdrun, evals)
ansChan := make(chan fdrun, evals)
quit := make(chan struct{})
defer close(quit)
// Launch workers. Workers receive an index and a step, and compute the answer.
for i := 0; i < nWorkers; i++ {
go func(sendChan <-chan fdrun, ansChan chan<- fdrun, quit <-chan struct{}) {
xcopy := make([]float64, len(x))
for {
select {
case <-quit:
return
case run := <-sendChan:
// See above comment on the copy.
copy(xcopy, x)
xcopy[run.idx] += run.pt.Loc * step
run.result = f(xcopy)
ansChan <- run
}
}
}(sendChan, ansChan, quit)
}
// Launch the distributor. Distributor sends the cases to be computed.
go func(sendChan chan<- fdrun, ansChan chan<- fdrun) {
for i := range x {
for _, pt := range formula.Stencil {
if pt.Loc == 0 {
// Answer already known. Send the answer on the answer channel.
ansChan <- fdrun{
idx: i,
pt: pt,
result: originValue,
}
continue
}
// Answer not known, send the answer to be computed.
sendChan <- fdrun{
idx: i,
pt: pt,
}
}
}
}(sendChan, ansChan)
for i := range dst {
dst[i] = 0
}
// Read in all of the results.
for i := 0; i < evals; i++ {
run := <-ansChan
dst[run.idx] += run.pt.Coeff * run.result
}
floats.Scale(1/step, dst)
return dst
}
type fdrun struct {
idx int
pt Point
result float64
}