-
Notifications
You must be signed in to change notification settings - Fork 1
/
seq2seq.go
105 lines (90 loc) · 2.56 KB
/
seq2seq.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
package model
import (
"fmt"
randv2 "math/rand/v2"
"github.com/itsubaki/neu/layer"
"github.com/itsubaki/neu/math/matrix"
"github.com/itsubaki/neu/math/rand"
)
// Compile-time assertions: each concrete decoder must implement the
// decoder interface, so a missing or mistyped method breaks the build
// here instead of at an assignment somewhere else in the package.
var (
	_ decoder = (*PeekyDecoder)(nil)
	_ decoder = (*Decoder)(nil)
)
// decoder is the behaviour Seq2Seq needs from its decoding half.
// Both *Decoder and *PeekyDecoder satisfy it.
type decoder interface {
	// Training path.

	// Forward runs the decoder over xs, conditioned on the encoder
	// output h, and returns the per-step scores that Seq2Seq feeds to
	// its softmax-with-loss layer.
	Forward(xs []matrix.Matrix, h matrix.Matrix) []matrix.Matrix
	// Backward takes the gradient of the scores and returns the
	// gradient with respect to h, which Seq2Seq hands to the encoder.
	Backward(dscore []matrix.Matrix) matrix.Matrix

	// Inference path.

	// Generate produces length token IDs starting from startID,
	// conditioned on the encoder output h.
	Generate(h matrix.Matrix, startID, length int) []int

	// Parameter and introspection access.

	Params() []matrix.Matrix
	Grads() []matrix.Matrix
	SetParams(p ...matrix.Matrix)
	Layers() []TimeLayer
	Summary() []string
}
// Seq2Seq is an encoder-decoder sequence model. The Encoder turns the
// input sequence into a hidden state that conditions the Decoder, and
// Softmax converts the decoder's scores into a training loss.
type Seq2Seq struct {
	// Encoder consumes the source sequence and produces the hidden state.
	Encoder *Encoder
	// Decoder is any decoder implementation; NewSeq2Seq installs a
	// *Decoder, and *PeekyDecoder also satisfies the interface.
	Decoder decoder
	// Softmax computes the loss over the decoder scores in Forward and
	// seeds the gradient in Backward.
	Softmax *layer.TimeSoftmaxWithLoss
	// Source is the random source NewSeq2Seq passed to the encoder and
	// decoder when building them.
	Source randv2.Source
}
// NewSeq2Seq builds a Seq2Seq model whose encoder and decoder are both
// configured by c.
//
// s optionally supplies the random source. Only s[0] is used, and that
// same source is handed to the encoder, the decoder and the model's
// Source field. When s is empty, a new source is created with
// rand.NewSource(rand.MustRead()).
func NewSeq2Seq(c *RNNLMConfig, s ...randv2.Source) *Seq2Seq {
	// Choose the source through a local instead of append(s, ...). When a
	// caller passes an empty slice with spare capacity (e.g. buf[:0]...),
	// appending to s would silently overwrite the caller's backing array.
	var src randv2.Source
	if len(s) > 0 {
		src = s[0]
	} else {
		src = rand.NewSource(rand.MustRead())
	}

	return &Seq2Seq{
		Encoder: NewEncoder(c, src),
		Decoder: NewDecoder(c, src),
		Softmax: &layer.TimeSoftmaxWithLoss{},
		Source:  src,
	}
}
// Forward computes the training loss for one batch.
//
// xs is the encoder input sequence and ts is the full target sequence.
// The decoder is teacher-forced: it reads ts without its final step and
// is scored against ts without its first step. For example:
//
//	xs:  ['5', '7', '+', '5', ' ', ' ', ' ']
//	dxs: ['_', '6', '2', ' ']      (ts[:len(ts)-1])
//	dts: ['6', '2', ' ', ' ']      (ts[1:])
//
// The result is whatever TimeSoftmaxWithLoss.Forward returns (a single
// 1x1 loss matrix according to the previous shape note; TODO confirm).
// ts must be non-empty, otherwise the slicing panics.
func (m *Seq2Seq) Forward(xs, ts []matrix.Matrix) []matrix.Matrix {
	last := len(ts) - 1
	decoderIn, decoderTarget := ts[:last], ts[1:]

	hidden := m.Encoder.Forward(xs)
	scores := m.Decoder.Forward(decoderIn, hidden)
	return m.Softmax.Forward(scores, decoderTarget)
}
// Backward back-propagates through the softmax, the decoder and then the
// encoder, in that order. The upstream gradient is seeded with a single
// 1x1 matrix holding 1 (dL/dL). It presumably relies on state cached by
// the preceding Forward call; verify against the layer implementations.
func (m *Seq2Seq) Backward() {
	seed := []matrix.Matrix{{{1}}}
	dHidden := m.Decoder.Backward(m.Softmax.Backward(seed))
	m.Encoder.Backward(dHidden)
}
// Generate encodes xs and returns the length token IDs produced by the
// decoder's Generate, starting from startID and conditioned on the
// encoder's hidden state.
func (m *Seq2Seq) Generate(xs []matrix.Matrix, startID, length int) []int {
	return m.Decoder.Generate(m.Encoder.Forward(xs), startID, length)
}
// Summary returns a human-readable description of the model. It lists the
// model's dynamic type, then the encoder's and decoder's own summary
// lines, and finally the softmax layer's string form.
func (m *Seq2Seq) Summary() []string {
	lines := []string{fmt.Sprintf("%T", m)}
	for _, part := range [][]string{m.Encoder.Summary(), m.Decoder.Summary()} {
		lines = append(lines, part...)
	}
	return append(lines, m.Softmax.String())
}
// Layers returns every time layer of the model in forward order: the
// encoder's layers, then the decoder's, then the softmax-with-loss layer.
// The result is a new slice, so it never aliases the sub-models' slices.
func (m *Seq2Seq) Layers() []TimeLayer {
	enc, dec := m.Encoder.Layers(), m.Decoder.Layers()

	all := make([]TimeLayer, 0, len(enc)+len(dec)+1)
	all = append(all, enc...)
	all = append(all, dec...)
	return append(all, m.Softmax)
}
// Params returns the model parameters grouped by sub-model. Index 0 holds
// the encoder's parameters and index 1 the decoder's, which is the same
// layout SetParams reads.
func (m *Seq2Seq) Params() [][]matrix.Matrix {
	enc := m.Encoder.Params()
	dec := m.Decoder.Params()
	return [][]matrix.Matrix{enc, dec}
}
// Grads returns the gradients grouped in the same layout as Params:
// index 0 for the encoder and index 1 for the decoder.
func (m *Seq2Seq) Grads() [][]matrix.Matrix {
	enc := m.Encoder.Grads()
	dec := m.Decoder.Grads()
	return [][]matrix.Matrix{enc, dec}
}
// SetParams loads parameters in the layout produced by Params: p[0] goes
// to the encoder and p[1] to the decoder. p must contain at least two
// groups; the encoder is updated before p[1] is indexed.
func (m *Seq2Seq) SetParams(p [][]matrix.Matrix) {
	const encoderGroup, decoderGroup = 0, 1
	m.Encoder.SetParams(p[encoderGroup]...)
	m.Decoder.SetParams(p[decoderGroup]...)
}