forked from sugarme/gotch
-
Notifications
You must be signed in to change notification settings - Fork 0
/
sparse.go
54 lines (45 loc) · 1.45 KB
/
sparse.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
package nn
// Sparse layers
import (
"github.com/nullbull/gotch/ts"
)
// EmbeddingConfig bundles the options used to construct and run an Embedding.
//
// NewEmbedding reads WsInit to initialize the weight. Embedding.Forward and
// Embedding.ForwardT pass PaddingIdx, ScaleGradByFreq and Sparse through to
// ts.MustEmbedding.
type EmbeddingConfig struct {
	// Sparse and ScaleGradByFreq are forwarded verbatim to ts.MustEmbedding.
	// NOTE(review): presumably they mirror PyTorch's `sparse` (sparse weight
	// gradients) and `scale_grad_by_freq` options — confirm against ts.
	Sparse, ScaleGradByFreq bool

	// WsInit initializes the embedding weight created by NewEmbedding.
	WsInit Init

	// PaddingIdx is forwarded to ts.MustEmbedding. DefaultEmbeddingConfig sets
	// it to -1, which presumably means "no padding index" — TODO confirm.
	PaddingIdx int64
}
// DefaultEmbeddingConfig returns a freshly allocated EmbeddingConfig with the
// following values:
//   - Sparse and ScaleGradByFreq are false.
//   - WsInit is NewRandnInit(0.0, 1.0), presumably a normal initializer with
//     mean 0 and standard deviation 1 (TODO confirm).
//   - PaddingIdx is -1.
//
// Each call returns a new value, so callers may modify it freely.
func DefaultEmbeddingConfig() *EmbeddingConfig {
	cfg := new(EmbeddingConfig)
	// Sparse and ScaleGradByFreq are left at their zero value (false).
	cfg.WsInit = NewRandnInit(0.0, 1.0)
	cfg.PaddingIdx = -1
	return cfg
}
// Embedding is a lookup-table layer that stores one vector per index. It is
// commonly used to hold word embeddings.
//
// Construct it with NewEmbedding. Both Forward and ForwardT pass the input
// indices, the weight and the stored config to ts.MustEmbedding.
type Embedding struct {
	// Ws is the embedding table. NewEmbedding creates it as the variable
	// "weight" with shape [numEmbeddings, embeddingDim].
	Ws *ts.Tensor

	// config holds the options supplied at construction. It is consulted on
	// every forward pass.
	config *EmbeddingConfig
}
// NewEmbedding creates an Embedding whose weight is registered under vs as the
// variable "weight", with shape [numEmbeddings, embeddingDim] and initialized
// by config.WsInit. vs.MustNewVar presumably panics if the variable cannot be
// created (Must-prefix convention).
//
// If config is nil, DefaultEmbeddingConfig() is used. Previously a nil config
// caused a nil-pointer panic when config.WsInit was read.
func NewEmbedding(vs *Path, numEmbeddings int64, embeddingDim int64, config *EmbeddingConfig) *Embedding {
	if config == nil {
		config = DefaultEmbeddingConfig()
	}

	ws := vs.MustNewVar("weight", []int64{numEmbeddings, embeddingDim}, config.WsInit)
	return &Embedding{
		Ws:     ws,
		config: config,
	}
}
// Implement Module, ModuleT interfaces for Embedding:
// =========================================
// Forward implements the Module interface for Embedding. It passes the weight
// table and the index tensor xs to ts.MustEmbedding, together with the
// layer's PaddingIdx, ScaleGradByFreq and Sparse options.
//
// As the Must prefix suggests, ts.MustEmbedding presumably panics instead of
// returning an error — TODO confirm.
func (e *Embedding) Forward(xs *ts.Tensor) *ts.Tensor {
	cfg := e.config
	return ts.MustEmbedding(e.Ws, xs, cfg.PaddingIdx, cfg.ScaleGradByFreq, cfg.Sparse)
}
// ForwardT implements the ModuleT interface for Embedding.
//
// The train flag is ignored: the lookup is identical in training and
// evaluation, so this simply delegates to Forward.
func (e *Embedding) ForwardT(xs *ts.Tensor, train bool) *ts.Tensor {
	return e.Forward(xs)
}