-
Notifications
You must be signed in to change notification settings - Fork 0
/
solver.go
95 lines (78 loc) · 2.14 KB
/
solver.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
package training
import math "github.com/chewxy/math32"
// Solver implements an update rule for training a NN.
type Solver interface {
	// Init allocates any per-weight state; size is the total number of
	// weights in the network.
	Init(size int)
	// Update returns the additive delta for a single weight, given its
	// current value, its gradient, the training iteration, and the
	// weight's index into the solver's state vectors.
	Update(value, gradient float32, iteration, idx int) float32
}
// SGD is stochastic gradient descent with nesterov/momentum.
type SGD struct {
	lr       float32   // base learning rate
	decay    float32   // learning-rate decay applied per iteration
	momentum float32   // momentum coefficient
	nesterov bool      // use the nesterov look-ahead variant when true
	moments  []float32 // per-weight momentum buffer, allocated by Init
}
// NewSGD returns a new SGD solver. A zero lr falls back to 0.01;
// momentum, decay and nesterov are taken as given.
func NewSGD(lr, momentum, decay float32, nesterov bool) *SGD {
	solver := &SGD{
		nesterov: nesterov,
		momentum: momentum,
		decay:    decay,
		lr:       fparam(lr, 0.01),
	}
	return solver
}
// Init allocates the per-weight moment buffer for a network with
// size weights, discarding any previous state.
func (o *SGD) Init(size int) {
	buf := make([]float32, size)
	o.moments = buf
}
// Update returns the update for a given weight.
//
// The effective learning rate is annealed as lr/(1+decay*iteration).
// The stored moment follows classical momentum:
//
//	m = momentum*m - lr*g
//
// With nesterov enabled, the returned step looks ahead along the moment
// (momentum*m - lr*g) while leaving the stored moment intact. The
// previous code assigned the look-ahead back into moments[idx], which
// applied the momentum decay twice per step and corrupted the momentum
// state.
func (o *SGD) Update(value, gradient float32, iteration, idx int) float32 {
	lr := o.lr / (1 + o.decay*float32(iteration))
	o.moments[idx] = o.momentum*o.moments[idx] - lr*gradient
	if o.nesterov {
		// Look-ahead step (Sutskever et al.); not folded back into the
		// stored moment.
		return o.momentum*o.moments[idx] - lr*gradient
	}
	return o.moments[idx]
}
// Adam is an Adam solver.
type Adam struct {
	lr      float32 // base learning rate
	beta    float32 // exponential decay rate for the first-moment estimate
	beta2   float32 // exponential decay rate for the second-moment estimate
	epsilon float32 // small constant guarding against division by zero
	// v holds the second-moment estimates, m the first-moment estimates;
	// both are allocated by Init, one entry per weight.
	v, m []float32
}
// NewAdam returns a new Adam solver. Zero-valued arguments fall back to
// the conventional defaults: lr 0.001, beta 0.9, beta2 0.999, epsilon 1e-8.
func NewAdam(lr, beta, beta2, epsilon float32) *Adam {
	a := new(Adam)
	a.lr = fparam(lr, 0.001)
	a.beta = fparam(beta, 0.9)
	a.beta2 = fparam(beta2, 0.999)
	a.epsilon = fparam(epsilon, 1e-8)
	return a
}
// Init allocates the first- and second-moment buffers for a network
// with size weights, discarding any previous state.
func (o *Adam) Init(size int) {
	o.m = make([]float32, size)
	o.v = make([]float32, size)
}
// Update returns the update for a given weight, following the Adam rule
// with bias-corrected step size.
func (o *Adam) Update(value, gradient float32, t, idx int) float32 {
	step := float32(t)
	// Bias-correction terms.
	// NOTE(review): assumes callers pass t >= 1 — at t == 0 the
	// denominator is zero; confirm against call sites.
	correction2 := math.Sqrt(1.0 - math.Pow(o.beta2, step))
	correction1 := 1.0 - math.Pow(o.beta, step)
	lrt := o.lr * correction2 / correction1
	// Exponential moving averages of the gradient and squared gradient.
	m := o.beta*o.m[idx] + (1.0-o.beta)*gradient
	v := o.beta2*o.v[idx] + (1.0-o.beta2)*math.Pow(gradient, 2.0)
	o.m[idx] = m
	o.v[idx] = v
	return -lrt * (m / (math.Sqrt(v) + o.epsilon))
}
// fparam returns val, substituting fallback when val is unset (zero).
func fparam(val, fallback float32) float32 {
	if val != 0.0 {
		return val
	}
	return fallback
}
// iparam returns val, substituting fallback when val is unset (zero).
func iparam(val, fallback int) int {
	if val != 0 {
		return val
	}
	return fallback
}