forked from owulveryck/onnx-go
/
squeeze.go
104 lines (91 loc) · 2 KB
/
squeeze.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
package gorgonnx
import (
"errors"
"fmt"
"github.com/godshen/onnx-go"
"gorgonia.org/gorgonia"
)
// squeeze implements the ONNX Squeeze operator: it reshapes its single input
// so that selected axes of size 1 disappear.
// Specifications: https://github.com/onnx/onnx/blob/master/docs/Operators.md#Squeeze
type squeeze struct {
	// Axes lists the axes to remove. When it is empty, apply removes every
	// axis whose size is 1.
	Axes []int64
}

// newSqueeze builds an unconfigured squeeze operator; its attributes are
// filled in later by init.
func newSqueeze() operator {
	op := new(squeeze)
	return op
}

// init registers the constructor under the ONNX operator name "Squeeze".
func init() {
	register("Squeeze", newSqueeze)
}
// apply adds the Squeeze computation for node ns[0] to the gorgonia graph.
//
// The node must have exactly one child (the input tensor). On success
// n.gorgoniaNode is set to a reshape of that input with the squeezed axes
// removed (see squeezeDims for the exact rules). Any error from the child
// check, from the axes validation, or from gorgonia.Reshape is returned.
func (a *squeeze) apply(g *Graph, ns ...*Node) error {
	n := ns[0]
	children := getOrderedChildren(g.g, n)
	err := checkCondition(children, 1)
	if err != nil {
		return err
	}
	input := children[0].gorgoniaNode
	dims, err := squeezeDims(input.Shape(), a.Axes)
	if err != nil {
		return err
	}
	n.gorgoniaNode, err = gorgonia.Reshape(input, dims)
	return err
}

// squeezeDims computes the output shape of Squeeze for an input of the given
// shape.
//
// If axes is empty, every axis of size 1 is removed, as the spec requires.
// Otherwise only the listed axes are removed. A negative axis counts from the
// back (-1 is the last axis). An error is returned when an axis is out of
// range, appears more than once, or refers to a dimension whose size is not 1.
//
// If every axis would be removed, the result is []int{1} rather than an empty
// shape, so at least one axis is always kept.
func squeezeDims(shape []int, axes []int64) ([]int, error) {
	rank := len(shape)
	remove := make([]bool, rank)
	if len(axes) == 0 {
		// According to the spec, we have to squeeze all axes of single dimensions.
		for i, v := range shape {
			remove[i] = v == 1
		}
	} else {
		for _, ax := range axes {
			// Normalize in int64 so a huge value cannot wrap into range on
			// platforms where int is 32 bits.
			k := ax
			if k < 0 {
				k += int64(rank)
			}
			if k < 0 || k >= int64(rank) {
				return nil, fmt.Errorf("squeeze: axis %v is out of range for a tensor of rank %v", ax, rank)
			}
			// A repeated axis would otherwise make the output rank computation
			// inconsistent with the number of axes actually removed.
			if remove[k] {
				return nil, fmt.Errorf("squeeze: axis %v is specified more than once", ax)
			}
			// If an axis is selected with shape entry not equal to one, an error is raised.
			if shape[k] != 1 {
				return nil, fmt.Errorf("Unable to squeeze an axis whose shape entry is not 1 (got %v instead)", shape[k])
			}
			remove[k] = true
		}
	}
	dims := make([]int, 0, rank)
	for i, v := range shape {
		if !remove[i] {
			dims = append(dims, v)
		}
	}
	// Scalar, we need to keep at least 1 axis (same rule for both branches).
	if len(dims) == 0 {
		dims = []int{1}
	}
	return dims, nil
}
// init configures the operator from the attributes of ONNX node o.
//
// The optional "axes" attribute, when present, must hold an []int64 listing
// the axes to squeeze; any other type is rejected with an error and a.Axes is
// left unchanged. When the attribute is absent (including when o has no
// attributes at all), a.Axes is set to an empty slice, which makes apply
// squeeze every axis of size 1.
func (a *squeeze) init(o onnx.Operation) error {
	// Indexing a nil map is safe in Go and yields the zero value, so a missing
	// Attributes map and a missing "axes" key both land in the branch below.
	axes := o.Attributes["axes"]
	if axes == nil {
		// The Axes attribute is optional.
		a.Axes = []int64{}
		return nil
	}
	v, ok := axes.([]int64)
	if !ok {
		return errors.New("squeeze: axes is not an []int64")
	}
	a.Axes = v
	return nil
}