/
operatorPointwise_binary_const.go
150 lines (129 loc) · 3.18 KB
/
operatorPointwise_binary_const.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
package gorgonia
import "github.com/chewxy/gorgonia/tensor"
// Package-level handles on the dense tensor operations, converted to the
// local function types declared in this file. The original note on this
// group reads "scalar-tensor float64 and vice versa".

// Arithmetic operations.
var (
	tadd = denseBinOp(tensor.Add)
	tsub = denseBinOp(tensor.Sub)
	tmul = denseBinOp(tensor.Mul)
	tdiv = denseBinOp(tensor.Div)
	tpow = denseBinOp(tensor.Pow)
)

// Comparison operations.
var (
	tlt  = denseCmpOp(tensor.Lt)
	tgt  = denseCmpOp(tensor.Gt)
	tlte = denseCmpOp(tensor.Lte)
	tgte = denseCmpOp(tensor.Gte)
	teq  = denseCmpOp(tensor.ElEq)
	tne  = denseCmpOp(tensor.ElNe)
)
type (
	// denseBinOp is the function type that the arithmetic tensor operations
	// (tensor.Add, tensor.Sub, ...) are converted to. It takes two operands
	// plus optional tensor.FuncOpt values and returns the resulting tensor
	// or an error.
	denseBinOp func(a, b interface{}, opts ...tensor.FuncOpt) (tensor.Tensor, error)

	// denseCmpOp is the function type that the comparison tensor operations
	// (tensor.Lt, tensor.Gt, ...) are converted to. Its signature is the
	// same as denseBinOp's; the distinct type keeps the two families apart.
	denseCmpOp func(a, b interface{}, opts ...tensor.FuncOpt) (tensor.Tensor, error)
)
// ʘBinaryOperatorType enumerates the supported binary operators.
// Its values index the fixed-size lookup tables in this file (for example
// ʘBinOpStrs and ʘBinOpCommutative), each of which is declared with exactly
// maxʘBinaryOpType entries.
type ʘBinaryOperatorType byte
// The declaration order below is significant: the lookup tables in this file
// are filled positionally, so every table must list its entries in this order.
const (
// arithmetic operators
addOpType ʘBinaryOperatorType = iota
subOpType
mulOpType
divOpType
powOpType
// comparison operators
ltOpType
gtOpType
lteOpType
gteOpType
eqOpType
neOpType
maxʘBinaryOpType // sentinel: one past the last valid operator; also the length of the lookup tables
)
// String returns the display symbol of the operator as listed in ʘBinOpStrs.
//
// A value that lies outside the table (anything at or beyond the table's
// length) yields a fixed placeholder instead of panicking with an
// index-out-of-range error, so that formatting or logging an invalid value
// never crashes the caller.
func (b ʘBinaryOperatorType) String() string {
	// The underlying type is byte, so only the upper bound needs checking.
	if int(b) >= len(ʘBinOpStrs) {
		return "unsupported BinOp"
	}
	return ʘBinOpStrs[b]
}
// ʘBinOpStrs holds the display symbol of every binary operator, one entry per
// operator in declaration order. Treat it as read-only.
var ʘBinOpStrs = [maxʘBinaryOpType]string{
	// arithmetic: add, sub, mul, div, pow
	"+", "-", "⊙", "÷", "^",
	// comparison: lt, gt, lte, gte, eq, ne
	"<", ">", "<=", ">=", "==", "!=",
}
// ʘBinOpCommutative records, per operator and in declaration order, whether
// swapping the two operands leaves the result unchanged. Treat it as
// read-only.
var ʘBinOpCommutative = [maxʘBinaryOpType]bool{
	true,  // add
	false, // sub
	true,  // mul
	false, // div
	false, // pow
	false, // lt
	false, // gt
	false, // lte
	false, // gte
	true,  // eq
	true,  // ne
}
// ʘBinOpDiffExprs holds one function per binary operator, in operator
// declaration order. Each takes the nodes x, y, z and gradZ and returns a
// Nodes value or an error. All comparison operators share nondiffBinOpExpr.
// NOTE(review): judging by the names these build gradient expressions —
// confirm against the definitions of the listed functions.
var ʘBinOpDiffExprs = [maxʘBinaryOpType]func(x, y, z, gradZ *Node) (Nodes, error){
	// arithmetic: add, sub, mul, div, pow
	addDiffExpr,
	subDiffExpr,
	hadamardProdDiffExpr,
	hadamardDivDiffExpr,
	hadamardPowDiffExpr,

	// comparison: lt, gt, lte, gte, eq, ne
	nondiffBinOpExpr,
	nondiffBinOpExpr,
	nondiffBinOpExpr,
	nondiffBinOpExpr,
	nondiffBinOpExpr,
	nondiffBinOpExpr,
}
// ʘBinOpDiffFns holds one function per binary operator, in operator
// declaration order. Each takes the nodes x, y and z and returns an error.
// All comparison operators share nondiffBinOp.
// NOTE(review): judging by the names these compute gradients in place —
// confirm against the definitions of the listed functions.
var ʘBinOpDiffFns = [maxʘBinaryOpType]func(x, y, z *Node) error{
	// arithmetic: add, sub, mul, div, pow
	addDiff,
	subDiff,
	hadamardProdDiff,
	hadamardDivDiff,
	hadamardPowDiff,

	// comparison: lt, gt, lte, gte, eq, ne
	nondiffBinOp,
	nondiffBinOp,
	nondiffBinOp,
	nondiffBinOp,
	nondiffBinOp,
	nondiffBinOp,
}
// isCommutative reports whether the operator gives the same result when its
// operands are swapped.
//
// Addition is commutative because a + b == b + a holds for every a and b.
// Subtraction is not: a - b may happen to equal b - a for particular values,
// but that is not guaranteed.
//
// It panics when called on a value that is not a valid operator
// (b >= maxʘBinaryOpType).
func (b ʘBinaryOperatorType) isCommutative() bool {
	if b < maxʘBinaryOpType {
		return ʘBinOpCommutative[b]
	}
	panic("isCommutative() for unsupported BinOp undefined")
}
// diffWRT reports, for each of the two operands, whether the operator is
// differentiable with respect to it. Arithmetic operators yield
// {true, true}; all other operators yield {false, false}.
//
// It panics if inputs is not 2.
func (b ʘBinaryOperatorType) diffWRT(inputs int) []bool {
	if inputs != 2 {
		panic("binary operator only supports 2 inputs")
	}

	// Both operands always share the same answer.
	differentiable := b.isArith()
	return []bool{differentiable, differentiable}
}
// isArith reports whether the operator is one of the arithmetic operators
// (add, sub, mul, div, pow). Every other value, including the comparison
// operators and out-of-range values, yields false.
func (b ʘBinaryOperatorType) isArith() bool {
	switch b {
	case addOpType, subOpType, mulOpType, divOpType, powOpType:
		return true
	}
	// Single fall-through result; the original had an extra, unreachable
	// return after a switch whose every branch already returned.
	return false
}
// binOps holds, per operator and in declaration order, a pointer to the
// dense arithmetic operation for that operator. The comparison operators
// have no arithmetic operation, so their slots are nil.
var binOps = [maxʘBinaryOpType]*denseBinOp{
	// arithmetic: add, sub, mul, div, pow
	&tadd, &tsub, &tmul, &tdiv, &tpow,
	// comparison: lt, gt, lte, gte, eq, ne
	nil, nil, nil, nil, nil, nil,
}
// cmpOps holds, per operator and in declaration order, a pointer to the
// dense comparison operation for that operator. The arithmetic operators
// have no comparison operation, so their slots are nil.
var cmpOps = [maxʘBinaryOpType]*denseCmpOp{
	// arithmetic: add, sub, mul, div, pow
	nil, nil, nil, nil, nil,
	// comparison: lt, gt, lte, gte, eq, ne
	&tlt, &tgt, &tlte, &tgte, &teq, &tne,
}