Skip to content
Permalink
Browse files

Fix reduction bugs (#357)

* Fixed bug for Sum

* added tests
  • Loading branch information
chewxy committed Dec 30, 2019
1 parent a8bd935 commit 1e989517bd089c8ee8c382595ead9881f34e8a83
Showing with 35 additions and 12 deletions.
  1. +9 −11 op_reduction.go
  2. +23 −0 op_reduction_test.go
  3. +2 −0 operations.go
  4. +1 −1 shape_test.go
@@ -43,18 +43,13 @@ func reductionType(d int, along []int) hm.Type {
return hm.NewFnType(t, retType)
}

func reductionInferShape(along []int, inputs ...DimSizer) (tensor.Shape, error) {
if len(inputs) != 1 {
return nil, fmt.Errorf("len(dimsizers)!=1")
}
func reductionInferShape(along []int, in tensor.Shape) (tensor.Shape, error) {
if len(along) == 0 {
return tensor.ScalarShape(), nil
}
in := inputs[0].(tensor.Shape)
shape := make(tensor.Shape, len(in))
copy(shape, in)
shape := in.Clone()
for _, d := range along {
if d >= len(shape) {
if d >= shape.Dims() {
return nil, fmt.Errorf("shape error, along %d is not a valid axis for shape %v", d, in)
}
shape[d] = 1
@@ -126,9 +121,11 @@ func (op maxOp) Type() hm.Type {
return reductionType(op.d, op.along)
}

// InferShape infers the output shape of a maxOp. Its purpose is to fulfil the
// Op interface. Exactly one input shape is expected, and it must be a
// tensor.Shape; the reduction is performed along op.along.
func (op maxOp) InferShape(dimsizers ...DimSizer) (tensor.Shape, error) {
	if len(dimsizers) != 1 {
		return nil, errors.Errorf("maxOp only takes one input shape to infer")
	}
	return reductionInferShape(op.along, dimsizers[0].(tensor.Shape))
}
// DiffWRT reports that maxOp is differentiable with respect to its sole input.
func (op maxOp) DiffWRT(i int) []bool {
	return []bool{true}
}

@@ -224,8 +221,9 @@ func (op sumOp) Type() hm.Type {
return reductionType(op.d, op.along)
}

// InferShape infers the shape of a sumOp. Its purpose is to fulfil the Op
// interface. Exactly one input shape is expected, and it must be a
// tensor.Shape; the reduction is performed along op.along.
func (op sumOp) InferShape(inputs ...DimSizer) (shape tensor.Shape, err error) {
	// Guard the arity before indexing, mirroring maxOp.InferShape; a bare
	// inputs[0] would panic when the caller passes the wrong number of shapes.
	if len(inputs) != 1 {
		return nil, errors.Errorf("sumOp only takes one input shape to infer")
	}
	return reductionInferShape(op.along, inputs[0].(tensor.Shape))
}

// DiffWRT reports that sumOp is differentiable with respect to its sole input.
func (op sumOp) DiffWRT(i int) []bool {
	return []bool{true}
}
@@ -29,7 +29,30 @@ func TestSumOpGrad(t *testing.T) {
assert.Nil(err)
assert.Equal(1, len(grads))
t.Logf("%v", grads[0])
}

// TestSumOpFakeVec checks that Sum reduces "fake" vectors — matrices shaped
// (n,1) or (1,n) — to scalars, and that summing a (2,1) matrix along axis 1
// yields a flat vector of shape (2) rather than a (2,1) matrix.
func TestSumOpFakeVec(t *testing.T) {
	g := NewGraph()

	xv := tensor.New(tensor.WithBacking([]float64{1, 2}), tensor.WithShape(2, 1))
	yv := tensor.New(tensor.WithBacking([]float64{10, 20}), tensor.WithShape(1, 2))
	x := NewMatrix(g, Float64, WithName("x"), WithShape(2, 1), WithValue(xv))
	y := NewMatrix(g, Float64, WithName("y"), WithShape(1, 2), WithValue(yv))

	// Check every error instead of discarding it; a failed Sum would
	// otherwise surface only as a confusing nil-pointer panic below.
	sx, err := Sum(x)
	if err != nil {
		t.Fatalf("Sum(x): %v", err)
	}
	sy, err := Sum(y)
	if err != nil {
		t.Fatalf("Sum(y): %v", err)
	}

	assert.True(t, sx.Shape().Eq(tensor.ScalarShape()))
	assert.True(t, sy.Shape().Eq(tensor.ScalarShape()))

	sx2, err := Sum(x, 1)
	if err != nil {
		t.Fatalf("Sum(x, 1): %v", err)
	}
	assert.True(t, sx2.Shape().Eq(tensor.Shape{2}))

	vm := NewTapeMachine(g)
	defer vm.Close()
	if err := vm.RunAll(); err != nil {
		t.Fatalf("RunAll: %v", err)
	}

	assert.Equal(t, 3.0, sx.Value().Data(), "Expected sx to be 3.0")
	assert.Equal(t, 30.0, sy.Value().Data(), "Expected sy to be 30.0")
	assert.Equal(t, []float64{1, 2}, sx2.Value().Data(), "sx2 should be a flat array")
}

func TestSumOpDiff(t *testing.T) {
@@ -327,8 +327,10 @@ func Sum(a *Node, along ...int) (retVal *Node, err error) {
switch {
case a.IsRowVec():
along = []int{1}
dims = 1
case a.IsColVec(), a.IsVector():
along = []int{0}
dims = 1
default:
along = intRange(0, dims)
}
@@ -7,7 +7,7 @@ import (
"gorgonia.org/tensor"
)

func Example_KeepDims() {
func Example_keepDims() {
g := NewGraph()
a := NodeFromAny(g, tensor.New(tensor.WithShape(2, 3), tensor.WithBacking([]float64{1, 2, 3, 4, 5, 6})))
m1, _ := Mean(a, 1)

0 comments on commit 1e98951

Please sign in to comment.
You can’t perform that action at this time.