/
denoisingautoencoder.go
116 lines (92 loc) · 1.99 KB
/
denoisingautoencoder.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
package main
import (
"log"
. "gorgonia.org/gorgonia"
)
// DenoisingAutoencoder is one denoising-autoencoder layer expressed as nodes
// of a gorgonia expression graph. The embedded Neuron carries the encoder
// weights (w) and bias (b); the embedded LayerConfig carries sizing such as
// BatchSize and Inputs.
type DenoisingAutoencoder struct {
	Neuron
	LayerConfig

	// h is the decoder neuron; NewDATiedWeights ties its weights to the
	// transpose of the encoder weights.
	h Neuron
	// af is the activation applied by both Activate and Reconstruct.
	af ActivationFunction

	// input is the clean input node and corruption the binomial random node
	// it is multiplied with. corrupted, output and hiddenOut cache the nodes
	// built by Corrupt, Activate and Reconstruct respectively.
	input, corruption, corrupted, output, hiddenOut *Node

	// g is the graph in which NewDATiedWeights creates the decoder bias and
	// the corruption node.
	g *ExprGraph
}
func NewDATiedWeights(w, b *Node, corruption float64, opts ...LayerConsOpt) *DenoisingAutoencoder {
da := new(DenoisingAutoencoder)
da.af = Sigmoid
for _, opt := range opts {
opt(da)
}
da.w = w
da.b = b
da.h.w = Must(Transpose(w))
log.Printf("w %v da.h.w %v", w.Shape(), da.h.w.Shape())
if da.BatchSize == 1 {
da.h.b = NewVector(da.g, dt, WithShape(da.Inputs), WithInit(Zeroes()))
} else {
da.h.b = NewMatrix(da.g, dt, WithShape(da.BatchSize, da.Inputs), WithInit(Zeroes()))
}
da.corruption = BinomialRandomNode(da.g, dt, 1, corruption)
return da
}
// Activate builds the encoder output af(corrupted·w + b) and caches it in
// l.output; once cached, later calls return that node without touching the
// graph. It reads l.corrupted, so Corrupt is expected to have run first.
// Note that l.output receives whatever af returns, even when af fails.
func (l *DenoisingAutoencoder) Activate() (retVal *Node, err error) {
	if cached := l.output; cached != nil {
		return cached, nil
	}

	product, err := Mul(l.corrupted, l.w)
	if err != nil {
		return nil, err
	}
	affine, err := Add(product, l.b)
	if err != nil {
		return nil, err
	}

	l.output, err = l.af(affine)
	if err != nil {
		return nil, err
	}
	return l.output, nil
}
// Reconstruct decodes the cached encoder output back to input space,
// af(output·h.w + h.b), and stores the result in l.hiddenOut. Despite its
// name, that field holds the reconstruction that Cost compares with the
// clean input. It does nothing if l.hiddenOut is already set, and it reads
// l.output, so Activate is expected to have run first.
func (l *DenoisingAutoencoder) Reconstruct() (err error) {
	if l.hiddenOut != nil {
		return nil
	}

	decoded, err := Mul(l.output, l.h.w)
	if err != nil {
		return err
	}
	shifted, err := Add(decoded, l.h.b)
	if err != nil {
		return err
	}

	// l.hiddenOut receives whatever af returns, even on error.
	l.hiddenOut, err = l.af(shifted)
	if err != nil {
		return err
	}
	return nil
}
// Corrupt builds l.corrupted, the element-wise (Hadamard) product of the
// corruption node and the clean input, on first use. Once l.corrupted is
// set, it returns nil without adding anything to the graph.
func (l *DenoisingAutoencoder) Corrupt() (err error) {
	if l.corrupted == nil {
		l.corrupted, err = HadamardProd(l.corruption, l.input)
	}
	return err
}
// Cost builds the training loss for the layer: the mean binary cross-entropy
// between the reconstruction of the corrupted input and the clean input.
//
// x is used as the clean input when the layer has no input node yet; an
// input node that is already attached is kept. Corrupt, Activate and
// Reconstruct cache their nodes, so only the first call extends the graph.
func (l *DenoisingAutoencoder) Cost(x *Node) (retVal *Node, err error) {
	// Without an input node, Corrupt and BinaryXent below would operate on
	// nil; x is the natural input to use.
	if l.input == nil {
		l.input = x
	}

	if err = l.Corrupt(); err != nil {
		return nil, err
	}
	if _, err = l.Activate(); err != nil {
		return nil, err
	}
	if err = l.Reconstruct(); err != nil {
		return nil, err
	}

	loss, err := BinaryXent(l.hiddenOut, l.input)
	if err != nil {
		return nil, err
	}
	return Mean(loss)
}