-
Notifications
You must be signed in to change notification settings - Fork 1
/
fullyconnected.go
72 lines (59 loc) · 1.77 KB
/
fullyconnected.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
package gnn
import (
"math"
"github.com/jacsmith21/gnn/mat"
"github.com/jacsmith21/gnn/rander"
"github.com/jacsmith21/gnn/vec"
)
// FullyConnected is a dense network layer whose output for an input batch a
// is Weights·a with Biases added to every column.
type FullyConnected struct {
	// Weights has one row per output and one column per input.
	Weights mat.Matrix
	// Biases holds one bias term per output row of Weights.
	Biases vec.Vector

	// activations is a private copy of the most recent Forward input,
	// retained so the weight gradient can be formed during BackProp.
	activations mat.Matrix
	// sampleCount is the number of columns (samples) in that cached input.
	sampleCount int
}
// NewFC creates a new fully connected layer mapping in inputs to out outputs.
//
// The weight matrix is out×in, filled by rander.Rand with rander.Gaussian and
// then scaled by 1/sqrt(in) so the initial weights shrink as the fan-in grows.
// The biases are whatever vec.Make(out) produces (presumably zeros — TODO
// confirm against the vec package).
//
// It panics if in or out is not positive: such a layer would have empty
// weights or biases, which InitFC also rejects, and in == 0 would make the
// scale factor +Inf.
func NewFC(in, out int) *FullyConnected {
	if in <= 0 || out <= 0 {
		panic("the input and output sizes must be positive")
	}
	fc := &FullyConnected{
		Weights: mat.Make(out, in),
		Biases:  vec.Make(out),
	}
	rander.Rand(fc.Weights, rander.Gaussian)
	fc.Weights.Scale(1 / math.Sqrt(float64(in)))
	return fc
}
// InitFC wraps caller-supplied weights and biases in a FullyConnected layer.
// The values are stored as given, without copying. It panics if the weights
// have no columns or the biases have no elements.
func InitFC(weights mat.Matrix, biases vec.Vector) *FullyConnected {
	// Case expressions are evaluated left to right and stop at the first
	// match, so biases.Len() is only consulted when the weights are non-empty.
	switch {
	case weights.ColCount() == 0, biases.Len() == 0:
		panic("the weights and biases cannot be empty")
	}
	layer := new(FullyConnected)
	layer.Weights = weights
	layer.Biases = biases
	return layer
}
// Forward returns Weights·a with Biases added to each column of the product.
// As a side effect it stores a copy of a, and a's column count as the sample
// count, for use by the next BackProp.
func (f *FullyConnected) Forward(a mat.Matrix) mat.Matrix {
	// Copy rather than alias so later changes to a by the caller cannot
	// alter the gradient computed in BackProp.
	cached := mat.Copy(a)
	f.activations = cached
	f.sampleCount = cached.ColCount()

	out := mat.Mul(f.Weights, a)
	out.AddCol(f.Biases)
	return out
}
// BackProp takes dz, the gradient flowing into this layer's output, and
// returns Wᵀ·dz, the gradient with respect to the layer's input. It then
// applies one parameter update driven by t.
func (f *FullyConnected) BackProp(t Trainer, dz mat.Matrix) mat.Matrix {
	// The input gradient must be built from the weights as they were before
	// this step, so it is computed before updateParams modifies them.
	inputGrad := mat.Mul(mat.Transpose(f.Weights), dz)
	f.updateParams(t, dz)
	return inputGrad
}
// updateParams applies one gradient-descent step to Weights and Biases.
//
// dz is the gradient for the batch cached by the most recent Forward call.
// The weight gradient is dz·aᵀ, where a is the cached input. The bias
// gradient is mat.SumCols(dz), which must yield one entry per output because
// it is subtracted from Biases. Both gradients are multiplied by
// learningRate/sampleCount before being subtracted.
func (f *FullyConnected) updateParams(t Trainer, dz mat.Matrix) {
	// An empty batch contributes no gradient. Without this guard the scale
	// below would be +Inf, and the all-zero gradients would turn into
	// 0*Inf = NaN, corrupting every weight and bias. (The zero value also
	// means a BackProp without a preceding Forward leaves the parameters alone.)
	if f.sampleCount == 0 {
		return
	}
	// Use the non-mutating Transpose so the cached activations keep their
	// original orientation; transposing the cache in place would hand a
	// wrongly-shaped matrix to any later reader of it.
	dw := mat.Mul(dz, mat.Transpose(f.activations))
	db := mat.SumCols(dz)
	scale := (1. / float64(f.sampleCount)) * float64(t.LearningRate)
	dw.Scale(scale)
	db.Scale(scale)
	f.Weights.Sub(dw)
	f.Biases.Sub(db)
}