Skip to content

Commit

Permalink
Fixed a tiny nonconsequential bug in Grad(). Added more tests for Sum
Browse files Browse the repository at this point in the history
  • Loading branch information
chewxy committed Nov 28, 2016
1 parent cf80bd6 commit 7b264e5
Show file tree
Hide file tree
Showing 3 changed files with 124 additions and 113 deletions.
10 changes: 4 additions & 6 deletions gorgonia.go
Original file line number Diff line number Diff line change
Expand Up @@ -201,16 +201,16 @@ func Grad(cost *Node, WRTs ...*Node) (retVal []*Node, err error) {

for i, n := range WRTs {
if !n.isInput() {
errors.Wrapf(err, "Can only differentiate with regards to input nodes. Node %d isn't an input", i)
// return
err = errors.Errorf("Can only differentiate with regards to input nodes. %dth Node %v isn't an input", i, n)
return nil, err
}
}

var dt Dtype
var ok bool
if dt, ok = cost.t.(Dtype); !ok {
errors.Wrap(err, "Expected a scalar dtype for cost")
// return
err = errors.Wrap(err, "Expected a scalar dtype for cost")
return
}

var gradOut *Node
Expand All @@ -236,8 +236,6 @@ func Let(n *Node, be interface{}) (err error) {
}

var val Value
// var t hm.Type
// var dt Dtype
if val, _, _, err = anyToValue(be); err != nil {
return errors.Wrapf(err, anyToValueFail, be, be)
}
Expand Down
201 changes: 101 additions & 100 deletions op_reduction_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,119 +35,115 @@ func TestSumOpDiff(t *testing.T) {
var g, g2 *ExprGraph
var x, y, z, a, b, c *Node
// var x, y, a, b *Node
var xG, aG, bG Value
var xG, yG, aG, bG Value
// var xG, aG Value
var prog *program
var locMap map[*Node]register
var m *tapeMachine
var m2 *lispMachine
var err error

/*
// Basic Test case: a vector is summed
g = NewGraph()
x = NewVector(g, Float64, WithName("x"), WithShape(5), WithInit(RangedFrom(0)))
y = Must(Sum(x))
WithName("y")(y)
Grad(y, x)
prog, locMap, err = Compile(g)
if err != nil {
t.Error(err)
}
ioutil.WriteFile("SumOp.dot", []byte(g.ToDot()), 0644)
m = NewTapeMachine(prog, locMap)
err = m.RunAll()
if err != nil {
t.Error(err)
}
g2 = NewGraph()
a = NewVector(g2, Float64, WithShape(5), WithInit(RangedFrom(0)))
b = Must(Sum(a))
m2 = NewLispMachine(g2)
m2.doWatchAll()
err = m2.RunAll()
if err != nil {
t.Error(err)
}
if aG, err = a.Grad(); err != nil {
t.Error(err)
}
if xG, err = x.Grad(); err != nil {
t.Error(err)
}
assert.True(ValueEq(x.Value(), a.Value()))
assert.True(ValueEq(xG, aG))
assert.True(ValueEq(y.Value(), b.Value()))
// long standing bug: sometimes the derivation will get executed in the machine first
// for example, the deriv of y is 1, and occasionally, the machine will choose to
// execute const 1 into register 0
// It would then fail to bind to y's boundTo, because at that point in time, y is still unknown.
// assert.Equal(y.Grad(), b.Grad())
// Slightly more advanced test case: A matrix is summed
g = NewGraph()
x = NewMatrix(g, Float64, WithName("x"), WithShape(11, 7), WithInit(RangedFrom(0)))
y = Must(Sum(x))
WithName("y")(y)
Grad(y, x)
// var prog *program
prog, locMap, err = Compile(g)
if err != nil {
t.Error(err)
}
m = NewTapeMachine(prog, locMap)
err = m.RunAll()
if err != nil {
t.Error(err)
}
g2 = NewGraph()
a = NewMatrix(g2, Float64, WithName("x"), WithShape(11, 7), WithInit(RangedFrom(0)))
b = Must(Sum(a))
m2 = NewLispMachine(g2)
err = m2.RunAll()
if err != nil {
t.Error(err)
}
if aG, err = a.Grad(); err != nil {
t.Error(err)
}
if xG, err = x.Grad(); err != nil {
t.Error(err)
}
assert.Equal(x.Value(), a.Value())
assert.Equal(xG, aG)
assert.Equal(y.Value(), b.Value())
*/
// Basic Test case: a vector is summed

g = NewGraph()
x = NewVector(g, Float64, WithName("x"), WithShape(5), WithInit(RangedFrom(0)))
y = Must(Sum(x))
WithName("y")(y)

Grad(y, x)

prog, locMap, err = Compile(g)
if err != nil {
t.Error(err)
}

ioutil.WriteFile("SumOp.dot", []byte(g.ToDot()), 0644)

m = NewTapeMachine(prog, locMap)
err = m.RunAll()
if err != nil {
t.Error(err)
}

g2 = NewGraph()
a = NewVector(g2, Float64, WithShape(5), WithInit(RangedFrom(0)))
b = Must(Sum(a))

m2 = NewLispMachine(g2)
m2.doWatchAll()
err = m2.RunAll()
if err != nil {
t.Error(err)
}

if aG, err = a.Grad(); err != nil {
t.Error(err)
}

if xG, err = x.Grad(); err != nil {
t.Error(err)
}

assert.True(ValueEq(x.Value(), a.Value()))
assert.True(ValueEq(xG, aG))
assert.True(ValueEq(y.Value(), b.Value()))

// long standing bug: sometimes the derivation will get executed in the machine first
// for example, the deriv of y is 1, and occasionally, the machine will choose to
// execute const 1 into register 0
// It would then fail to bind to y's boundTo, because at that point in time, y is still unknown.

// assert.Equal(y.Grad(), b.Grad())

// Slightly more advanced test case: A matrix is summed
g = NewGraph()
x = NewMatrix(g, Float64, WithName("x"), WithShape(11, 7), WithInit(RangedFrom(0)))
y = Must(Sum(x))
WithName("y")(y)

Grad(y, x)
// var prog *program
prog, locMap, err = Compile(g)
if err != nil {
t.Error(err)
}

m = NewTapeMachine(prog, locMap)
err = m.RunAll()
if err != nil {
t.Error(err)
}

g2 = NewGraph()
a = NewMatrix(g2, Float64, WithName("x"), WithShape(11, 7), WithInit(RangedFrom(0)))
b = Must(Sum(a))

m2 = NewLispMachine(g2)
err = m2.RunAll()
if err != nil {
t.Error(err)
}

if aG, err = a.Grad(); err != nil {
t.Error(err)
}

if xG, err = x.Grad(); err != nil {
t.Error(err)
}

assert.Equal(x.Value(), a.Value())
assert.Equal(xG, aG)
assert.Equal(y.Value(), b.Value())

/* Sum is not the root node */

g = NewGraph()
x = NewMatrix(g, Float64, WithName("x"), WithShape(11, 7), WithInit(RangedFrom(0)))
y = Must(Sum(x))
z = Must(Add(y, twof64))

var grads Nodes
grads, err = Grad(z, x, y)
_, err = Grad(z, x)
if err != nil {
t.Fatal(err)
}
Expand All @@ -156,6 +152,7 @@ func TestSumOpDiff(t *testing.T) {
if err != nil {
t.Error(err)
}
ioutil.WriteFile("Blah.dot", []byte(g.ToDot()), 0644)

m = NewTapeMachine(prog, locMap)
err = m.RunAll()
Expand Down Expand Up @@ -187,10 +184,14 @@ func TestSumOpDiff(t *testing.T) {
t.Error(err)
}

if yG, err = b.Grad(); err != nil {
t.Error(err)
}

assert.Equal(x.Value(), a.Value())
assert.Equal(xG, aG)
assert.Equal(y.Value(), b.Value())
assert.Equal(grads[1].Value(), bG)
assert.Equal(yG, bG)
assert.Equal(z.Value(), c.Value())

}
26 changes: 19 additions & 7 deletions operations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -167,7 +167,9 @@ func TestSoftMax(t *testing.T) {
logsm := Must(Neg(Must(Log(sm))))
cost := Must(Slice(logsm, S(2)))

grads, _ := Grad(cost, sm)
if _, err := Grad(cost, x); err != nil {
t.Error(err)
}
prog, locMap, err := Compile(g)
if err != nil {
t.Error(err)
Expand All @@ -178,9 +180,13 @@ func TestSoftMax(t *testing.T) {
if err != nil {
t.Error(err)
}
var smg Value
smg, err = sm.Grad()
if err != nil {

var smg, xG Value
if smg, err = sm.Grad(); err != nil {
t.Error(err)
}

if xG, err = x.Grad(); err != nil {
t.Error(err)
}

Expand All @@ -199,11 +205,17 @@ func TestSoftMax(t *testing.T) {
t.Error(err)
}

smg, err = sm2.Grad()
if err != nil {
var sm2g, x2G Value
if sm2g, err = sm2.Grad(); err != nil {
t.Error(err)
}

if x2G, err = x2.Grad(); err != nil {
t.Error(err)
}
assert.Equal(smg, grads[0].Value())

assert.Equal(smg, sm2g)
assert.Equal(xG, x2G)
}

func TestSlice(t *testing.T) {
Expand Down

0 comments on commit 7b264e5

Please sign in to comment.