forked from aunum/goro
/
flatten.go
76 lines (62 loc) · 1.53 KB
/
flatten.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
package layer
import (
"fmt"
"github.com/aunum/log"
g "github.com/m8u/gorgonia"
)
// Flatten reshapes the incoming tensor to be flat, preserving the batch.
//
// It is a layer Config with no tunable fields: compiling it yields a
// layer that maps a (batch, x...) input to (batch, product(x...)).
type Flatten struct{}
// ApplyDefaults to the flatten layer.
//
// Flatten has no configurable fields, so the config is returned as-is.
func (f Flatten) ApplyDefaults() Config {
	return f
}
// Compile the layer.
//
// The config is wrapped into its runtime form and bound to the given
// expression graph. CompileOpts are accepted for interface parity but
// are not used by this layer.
func (f Flatten) Compile(graph *g.ExprGraph, opts ...CompileOpt) Layer {
	layer := newFlatten(&f)
	layer.graph = graph
	return layer
}
// Clone the config.
//
// Flatten is stateless, so a fresh zero value is an exact copy.
func (f Flatten) Clone() Config {
	var copied Flatten
	return copied
}
// Validate the config.
//
// A flatten layer has nothing to check, so validation always succeeds.
func (f Flatten) Validate() error {
	return nil
}
// flatten is the compiled, runtime form of the Flatten config.
type flatten struct {
	// Embedded config; flatten carries no state of its own beyond the graph.
	*Flatten

	// graph is the expression graph this layer was compiled against.
	graph *g.ExprGraph
}
// newFlatten wraps a Flatten config in its runtime layer form.
// The graph is left unset; Compile assigns it.
func newFlatten(config *Flatten) *flatten {
	layer := new(flatten)
	layer.Flatten = config
	return layer
}
// Fwd is a forward pass through the layer.
//
// The input must have at least two dimensions: a leading batch dimension
// plus one or more data dimensions. All non-batch dimensions are collapsed
// into a single flat dimension, producing a (batch, product) node.
//
// Returns the reshaped node, or an error if the input has too few
// dimensions or the reshape fails.
func (f *flatten) Fwd(x *g.Node) (*g.Node, error) {
	if len(x.Shape()) < 2 {
		// Fixed typo in the error message: "to few" -> "too few".
		return nil, fmt.Errorf("flatten expects input in the shape (batch, x...), too few dimensions in %v", x.Shape())
	}
	batch := x.Shape()[0]
	// Collapse every non-batch dimension into a single flat size.
	product := 1
	for _, d := range x.Shape()[1:] {
		product *= d
	}
	n, err := g.Reshape(x, []int{batch, product})
	if err != nil {
		return nil, err
	}
	log.Debugf("flatten output shape: %v", n.Shape())
	return n, nil
}
// Learnables returns all learnable nodes within this layer.
//
// Flattening is a pure reshape with no trainable parameters, so an
// empty (non-nil) node list is returned.
func (f *flatten) Learnables() g.Nodes {
	none := g.Nodes{}
	return none
}
// Clone the layer.
//
// Only the config is cloned; the graph pointer is intentionally not
// copied — the clone must be re-bound via Compile before use.
//
// Bug fix: Flatten.Clone returns a Flatten VALUE boxed in Config, so the
// original assertion .(*Flatten) could never succeed and panicked at
// runtime. Assert to the value type and take its address instead.
func (f *flatten) Clone() Layer {
	config := f.Flatten.Clone().(Flatten)
	return &flatten{Flatten: &config}
}
// Graph returns the graph for this layer.
//
// This is the expression graph assigned at Compile time; it may be nil
// for a layer that was never compiled.
func (f *flatten) Graph() *g.ExprGraph {
	graph := f.graph
	return graph
}