Make reductionInferShape conservative to fix #384 (#411)
* Make reductionInferShape conservative to fix #384

The reductionInferShape function currently doesn't respect the along parameter: it aggressively squeezes all dimensions of size 1. This not only affects normal tensor operations, but can also break the backprop (autodiff) algorithm when the network contains a BroadcastAdd, resulting in a crash when Grad() is called.

The change strictly respects the along parameter (illustrated in the sketch below), e.g.:
(100, 1) reduced along 0 yields shape (1) instead of ()
(1, 64, 1, 64) reduced along 3 yields (1, 64, 1)
(64, 1, 3, 2) reduced along (2, 3) yields (64, 1).

Fixed unit tests.

* Remove inconsistent dimension for Sum op

After changing reductionType to subtract len(along) from the reduction op's dimension, SumOp's dimension needs to be adjusted to match; otherwise SymDiff will crash for Sum in calcBroadcastShape.
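For illustration, a minimal sketch of the new shape inference seen through the public API. This is a hypothetical example, not part of the commit; it uses the same constructors as the tests in this commit, with import paths assumed:

package main

import (
	"fmt"

	. "gorgonia.org/gorgonia"
	"gorgonia.org/tensor"
)

func main() {
	g := NewGraph()
	// A (100, 1) tensor summed along axis 0 now infers shape (1), not ().
	x := NewTensor(g, tensor.Float64, 2, WithShape(100, 1), WithInit(RangedFrom(0)))
	s := Must(Sum(x, 0))
	fmt.Println(s.Shape()) // (1)
}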
wzzhu committed Jun 15, 2020
1 parent 0153235 commit 15014b3
Showing 6 changed files with 37 additions and 44 deletions.
4 changes: 2 additions & 2 deletions differentiation.go
@@ -31,7 +31,7 @@ func forwardDiffAnalysis(outputs, sortedNodes Nodes) (retVal NodeSet, err error)
// diffSet := outputs.Set()
diffSet := outputs.mapSet()

symdiffLogf("Diff Set: %d", diffSet)
symdiffLogf("Diff Set: %v", diffSet)
symdiffLogf("%d", sortedNodes)
// for i := len(sortedNodes) - 1; i >= 0; i-- {
// n := sortedNodes[i]
@@ -216,7 +216,7 @@ func Backpropagate(outputs, gradOutputs, wrt Nodes) (retVal Nodes, err error) {
// "pullback" function to backpropagate derivatives
activeNodes := affectsOutput.Intersect(affectedByOutput)

symdiffLogf("Active: %d", activeNodes)
symdiffLogf("Active: %v", activeNodes)

symdiffLogf("Sorted: %d", sortedNodes)
symdiffLogf("nodeGradMap: %+#d", FmtNodeMap(nodeGradMap))
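The format-verb fix above matters because diffSet and activeNodes are sets, not integers; with %d, fmt annotates every element with an error. A quick illustration of the difference, using plain fmt outside this codebase:

s := []string{"a", "b"}
fmt.Printf("%d\n", s) // [%!d(string=a) %!d(string=b)]
fmt.Printf("%v\n", s) // [a b]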
9 changes: 7 additions & 2 deletions go.sum
@@ -18,6 +18,8 @@ github.com/cznic/strutil v0.0.0-20181122101858-275e90344537/go.mod h1:AHHPPPXTw0
github.com/cznic/xc v0.0.0-20181122101856-45b06973881e/go.mod h1:3oFoiOvCDBYH+swwf5+k/woVmWy7h1Fcyu8Qig/jjX0=
github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
github.com/delaneyj/cogent v0.0.0-20180619184653-2fcea326194c h1:UliKg7JACWAXDW7yFdms6lLwOLK7H3uId3NG5z4f378=
github.com/delaneyj/cogent v0.0.0-20180619184653-2fcea326194c/go.mod h1:hL/k6TDIq37bqQ6sySYVYw+Idnv0JkVmKsmedD5AduQ=
github.com/fatih/color v1.7.0 h1:DkWD4oS2D8LGGgTQ6IvwJJXSL5Vp2ffcQg58nFV38Ys=
github.com/fatih/color v1.7.0/go.mod h1:Zm6kSWBoL9eyXnKyktHP6abPY2pDugNf5KwzbycvMj4=
github.com/fogleman/gg v1.2.1-0.20190220221249-0403632d5b90/go.mod h1:R/bRT+9gY/C5z7JzPU0zXsXHKM4/ayA+zqcVNZzPa1k=
@@ -50,15 +52,20 @@ github.com/mattn/go-isatty v0.0.8 h1:HLtExJ+uU2HOZ+wI0Tt5DtUDrx8yhUqDcp7fYERX4CE
github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s=
github.com/mattn/go-runewidth v0.0.4 h1:2BvfKmzob6Bmd4YsL0zygOqfdFnK7GR4QL06Do4/p7Y=
github.com/mattn/go-runewidth v0.0.4/go.mod h1:LwmH8dsx7+W8Uxz3IHJYH5QSwggIsqBzpuz5H//U1FU=
github.com/mattn/gorgonia-cfd5423e2acc2f8c2b86 v0.0.0-20200313070349-288c2a647837 h1:/7GLXOx1Cd15DDfNpIZguExr6Ui5e2vKVbCf8x52ls0=
github.com/mattn/gorgonia-cfd5423e2acc2f8c2b86 v0.0.0-20200313070349-288c2a647837/go.mod h1:MGXCds9oIEtiTo7SSDV2qlEYxIFO0LdSOf4BlNJYr34=
github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I=
github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pkg/errors v0.9.1 h1:FEBLx1zS214owpjy7qsBeixbURkuhQAwrK5UwLGTwt4=
github.com/pkg/errors v0.9.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0=
github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
github.com/remyoudompheng/bigfft v0.0.0-20190728182440-6a916e37a237/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
github.com/seehuhn/mt19937 v0.0.0-20191220121156-d07252b9f9df h1:rhEzo7J+sDOLI5NulkwtescnyYMSt4J5mkxDMgQRjN4=
github.com/seehuhn/mt19937 v0.0.0-20191220121156-d07252b9f9df/go.mod h1:w+IAy13Luqfsp+plFpT1RiqauADylJKmpkrWFwpjbsc=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
github.com/stretchr/testify v1.1.4/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.2.1/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
github.com/stretchr/testify v1.4.0 h1:2E4SXV/wtOkTonXsotYi4li6zVWxYlZuYNCXe9XRJyk=
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
@@ -114,9 +121,7 @@ gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie
gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gorgonia.org/cu v0.9.0-beta h1:s4WQ6fiAGoErwIiXWHRB6Y9ydkx1vTTPwhWzoEZVePc=
gorgonia.org/cu v0.9.0-beta/go.mod h1:RPEPIfaxxqUmeRe7T1T8a0NER+KxBI2McoLEXhP1Vd8=
gorgonia.org/cu v0.9.3 h1:IkxE4NWXuZHqr8AnmgoB8WNQPZeD6u0EJNxYjDC0YgY=
gorgonia.org/cu v0.9.3/go.mod h1:LgyAYDkN7HWhh8orGnCY2R8pP9PYbO44ivEbLMatkVU=
gorgonia.org/dawson v1.1.0 h1:o7+eJ3SKi9sheH19lpOat//tDbg0Y+M9iY/lH79VHqY=
gorgonia.org/dawson v1.1.0/go.mod h1:Px1mcziba8YUBIDsbzGwbKJ11uIblv/zkln4jNrZ9Ws=
gorgonia.org/dawson v1.2.0 h1:hJ/aofhfkReSnJdSMDzypRZ/oWDL1TmeYOauBnXKdFw=
gorgonia.org/dawson v1.2.0/go.mod h1:Px1mcziba8YUBIDsbzGwbKJ11uIblv/zkln4jNrZ9Ws=
10 changes: 6 additions & 4 deletions node.go
@@ -197,10 +197,11 @@ func WithShape(shp ...int) NodeConsOpt {
s := tensor.Shape(tensor.BorrowInts(len(shp)))
copy(s, shp)
f := func(n *Node) {
if n.t == nil && n.shape == nil {
n.shape = s
return
}
nd := n.Dims()
// if nd == 1 && s.IsVector() {
// goto safe
// }
isVec := s.IsColVec() || s.IsRowVec()
acceptVec := (isVec && (nd == 1))
sameDims := nd == s.Dims()
@@ -209,7 +210,6 @@ func WithShape(shp ...int) NodeConsOpt {
if !acceptVec && !sameDims && !acceptScalar {
panic(fmt.Sprintf("Node %v, has %d dimensions(Shape: %v). Input shape is %v, which has %d dimensions", n, n.Dims(), n.shape, s, s.Dims()))
}
// safe:
n.shape = s
}
return f
@@ -258,6 +258,8 @@ func newNode(opts ...NodeConsOpt) *Node {
n := borrowNode()
n.dataOn = CPU
n.id = -1
n.t = nil
n.shape = nil

for _, opt := range opts {
opt(n)
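The explicit reset matters because newNode draws from a pool via borrowNode, so a recycled node could otherwise carry a stale type or shape into WithShape's new early return. A sketch of the path this enables; newNode is unexported, so the direct call is shown for illustration only:

// Hypothetical: a fresh node with neither type nor shape attached.
// WithShape now records the shape directly instead of comparing it
// against the node's existing dimensions.
n := newNode(WithShape(100, 1))
fmt.Println(n.shape) // (100, 1)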
21 changes: 8 additions & 13 deletions op_reduction.go
@@ -20,7 +20,7 @@ import (

func reductionType(d int, along []int) hm.Type {
a := hm.TypeVariable('a')
t := makeTensorType(d, a)
t := makeTensorType(d-len(along), a)

axes := make(map[int]bool)
for _, axis := range along {
@@ -52,24 +52,19 @@ func reductionInferShape(along []int, in tensor.Shape) (tensor.Shape, error) {
if d >= shape.Dims() {
return nil, fmt.Errorf("shape error, along %d is not a valid axis for shape %v", d, in)
}
shape[d] = 1
shape[d] = 0
}
// special cases: if all dimensions are 1 -> ScalarShape, if exactly one dimension is != 1 -> vector
vecD := 0
numNot1 := 0

var dims []int
for _, d := range shape {
if d != 1 {
vecD = d
numNot1++
if numNot1 > 1 {
return shape, nil
}
if d != 0 {
dims = append(dims, d)
}
}
if numNot1 == 0 {
if len(dims) == 0 {
return tensor.ScalarShape(), nil
}
return tensor.Shape{vecD}, nil
return tensor.Shape(dims), nil
}

func reductionDo(op Op, s string, f func(*tensor.Dense, ...int) (*tensor.Dense, error), along []int, inputs ...Value) (retVal Value, err error) {
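For reference, the rule the rewritten function implements, shown as hypothetical calls to the unexported reductionInferShape (error returns elided):

// Drop exactly the axes listed in along; squeeze nothing else.
reductionInferShape([]int{0}, tensor.Shape{100, 1})         // -> (1)
reductionInferShape([]int{3}, tensor.Shape{1, 64, 1, 64})   // -> (1, 64, 1)
reductionInferShape([]int{2, 3}, tensor.Shape{64, 1, 3, 2}) // -> (64, 1)
reductionInferShape([]int{0, 1}, tensor.Shape{3, 4})        // -> (), i.e. ScalarShape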
26 changes: 13 additions & 13 deletions op_reduction_test.go
@@ -254,7 +254,7 @@ func TestMaxOp(t *testing.T) {
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Max,
along: []int{0},
wantShape: []int{1, 2, 2, 2},
wantShape: []int{2, 2, 2},
wantData: []float32{9, 10, 11, 12, 13, 14, 15, 16},
},
{
@@ -263,7 +263,7 @@
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Max,
along: []int{1},
wantShape: []int{2, 1, 2, 2},
wantShape: []int{2, 2, 2},
wantData: []float32{5, 6, 7, 8, 13, 14, 15, 16},
},
{
@@ -272,7 +272,7 @@
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Max,
along: []int{2},
wantShape: []int{2, 2, 1, 2},
wantShape: []int{2, 2, 2},
wantData: []float32{3, 4, 7, 8, 11, 12, 15, 16},
},
{
@@ -281,7 +281,7 @@
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Max,
along: []int{3},
wantShape: []int{2, 2, 2, 1},
wantShape: []int{2, 2, 2},
wantData: []float32{2, 4, 6, 8, 10, 12, 14, 16},
},
{
@@ -290,7 +290,7 @@
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Max,
along: []int{1, 3},
wantShape: []int{2, 1, 2, 1},
wantShape: []int{2, 2},
wantData: []float32{6, 8, 14, 16},
},
{
@@ -342,7 +342,7 @@ func TestSumOp(t *testing.T) {
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Sum,
along: []int{0},
wantShape: []int{1, 2, 2, 2},
wantShape: []int{2, 2, 2},
wantData: []float32{10, 12, 14, 16, 18, 20, 22, 24},
},
{
@@ -351,7 +351,7 @@
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Sum,
along: []int{1},
wantShape: []int{2, 1, 2, 2},
wantShape: []int{2, 2, 2},
wantData: []float32{6, 8, 10, 12, 22, 24, 26, 28},
},
{
@@ -360,7 +360,7 @@
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Sum,
along: []int{2},
wantShape: []int{2, 2, 1, 2},
wantShape: []int{2, 2, 2},
wantData: []float32{4, 6, 12, 14, 20, 22, 28, 30},
},
{
@@ -369,7 +369,7 @@
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Sum,
along: []int{3},
wantShape: []int{2, 2, 2, 1},
wantShape: []int{2, 2, 2},
wantData: []float32{3, 7, 11, 15, 19, 23, 27, 31},
},
{
@@ -378,7 +378,7 @@
inData: []float32{1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16},
op: Sum,
along: []int{1, 3},
wantShape: []int{2, 1, 2, 1},
wantShape: []int{2, 2},
wantData: []float32{14, 22, 46, 54},
},
{
@@ -562,12 +562,12 @@ func TestFollowupOp(t *testing.T) {
Xn := NewTensor(g, tensor.Float64, 4, WithShape(2, 2, 2, 2), WithInit(RangedFrom(1)))
mx := Must(Max(Xn, 1, 2))
sx := Must(Sum(Xn, 1, 2))
y := NewTensor(g, tensor.Float64, 4, WithShape(2, 1, 1, 2), WithInit(RangedFrom(1)))
y := NewTensor(g, tensor.Float64, 2, WithShape(2, 2), WithInit(RangedFrom(1)))

amx := Must(Add(mx, y))
asx := Must(Add(sx, y))
assert.Equal(t, amx.Shape(), tensor.Shape{2, 1, 1, 2})
assert.Equal(t, asx.Shape(), tensor.Shape{2, 1, 1, 2})
assert.Equal(t, amx.Shape(), tensor.Shape{2, 2})
assert.Equal(t, asx.Shape(), tensor.Shape{2, 2})
vm := NewTapeMachine(g)
defer vm.Close()
err := vm.RunAll()
11 changes: 1 addition & 10 deletions operations.go
@@ -323,16 +323,7 @@ func Sum(a *Node, along ...int) (retVal *Node, err error) {

dims := a.Dims()
if len(along) == 0 {
switch {
case a.IsRowVec():
along = []int{1}
dims = 1
case a.IsColVec(), a.IsVector():
along = []int{0}
dims = 1
default:
along = intRange(0, dims)
}
along = intRange(0, dims)
}

op := newSumOp(along, a.shape, dims)
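With the special cases gone, Sum called without axes reduces over every axis uniformly, for vectors and higher-rank tensors alike. A minimal sketch, using the same constructors as the tests above:

g := NewGraph()
x := NewTensor(g, tensor.Float64, 2, WithShape(3, 4), WithInit(RangedFrom(1)))
total := Must(Sum(x))      // along defaults to intRange(0, 2) == {0, 1}
fmt.Println(total.Shape()) // (), i.e. a scalar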
