Skip to content

Commit

Permalink
Issue92fix (#121)
Browse files Browse the repository at this point in the history
* Fixes #92

* Yesterday's commit didn't actually fix #92 because `UsePreallocDo` was not used. Now that it is used, a whole host of other problems has surfaced

* Fixed so that DoDiff uses PreAllocDo

* Updated some performance related code to make sure the program allocates even less
  • Loading branch information
chewxy committed Jun 15, 2017
1 parent 83efa27 commit a9403d5
Show file tree
Hide file tree
Showing 8 changed files with 92 additions and 23 deletions.
2 changes: 1 addition & 1 deletion collections.go
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ func (ns Nodes) remove(what *Node) Nodes {
}

func (ns Nodes) dimSizers() []DimSizer {
retVal := make([]DimSizer, len(ns))
retVal := borrowDimSizers(len(ns))
for i, n := range ns {
if s, ok := n.op.(sizeOp); ok {
retVal[i] = s
Expand Down
56 changes: 47 additions & 9 deletions op_tensor.go
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,8 @@ func newRepeatOp(along axes, children Nodes) *repeatOp {
arg0Dim: children[0].Dims(),
}

if s, err := retVal.InferShape(children.dimSizers()...); err == nil {
ds := children.dimSizers()
if s, err := retVal.InferShape(ds...); err == nil {
retVal.inputShape = s
if s.IsColVec() {
retVal.d = 1
Expand All @@ -199,6 +200,7 @@ func newRepeatOp(along axes, children Nodes) *repeatOp {
} else {
panic(err)
}
returnDimSizers(ds)

return retVal
}
Expand Down Expand Up @@ -580,16 +582,16 @@ func (op *sliceOp) DoDiff(ctx ExecutionContext, inputs Nodes, output *Node) (err
ydv := output.boundTo.(*dualValue)
incrOp := sliceIncrOp{op}

var d Value
if d, err = incrOp.Do(xdv.Value, ydv.d); err != nil {
// var d Value
if _, err = incrOp.UsePreallocDo(xdv.d, xdv.d, ydv.d); err != nil {
return errors.Wrapf(err, doFail, incrOp)
}

// there is no need to handle scalars, because you can never slice a scalar
add := newElemBinOp(addOpType, inputs[0], output)
if _, err = add.UnsafeDo(xdv.d, d); err != nil {
return errors.Wrapf(err, unsafeDoFail, add)
}
// add := newElemBinOp(addOpType, inputs[0], output)
// if _, err = add.UnsafeDo(xdv.d, d); err != nil {
// return errors.Wrapf(err, unsafeDoFail, add)
// }

return
}
Expand Down Expand Up @@ -722,6 +724,7 @@ func (op sliceIncrOp) SymDiff(inputs Nodes, outputNode, gradNode *Node) (retVal
return nil, errors.Wrap(err, operationError)
}
retVal = Nodes{gradNode, slicedRes}

return
}

Expand Down Expand Up @@ -792,9 +795,44 @@ func (op sliceIncrOp) Do(inputs ...Value) (retVal Value, err error) {
return
}

// func (op sliceIncrOp) usePreallocDoer(prealloc Value, inputs ...Value) (retVal Value, err error) {
// UsePreallocDo performs the slice-increment into the preallocated value,
// avoiding the fresh allocation that Do would incur. prealloc is sliced
// along op's axis and incr (inputs[1]) is added in place to that view, so
// prealloc itself is mutated and returned.
//
// NOTE(review): inputs[0] is consulted only via the arity check; it is
// presumably expected to alias prealloc (see sliceOp.DoDiff) — confirm
// against callers.
func (op sliceIncrOp) UsePreallocDo(prealloc Value, inputs ...Value) (retVal Value, err error) {
	machineLogf("Doing %v", op)
	enterLoggingContext()
	defer leaveLoggingContext()

	if err = checkArity(op, len(inputs)); err != nil {
		return
	}
	incr := inputs[1]

	// Prep the slices: only the axis this op slices along is restricted;
	// every other axis gets a nil (i.e. full-range) slice.
	slices := make([]tensor.Slice, op.d)
	if !op.all() {
		slices[op.along] = op
	}

	switch T := prealloc.(type) {
	case *tensor.Dense:
		var v tensor.Tensor
		if v, err = T.Slice(slices...); err != nil {
			return nil, errors.Wrapf(err, sliceFail, slices)
		}
		// Add incr into the sliced view in place. The error was previously
		// dropped on the floor, which could silently mask shape or dtype
		// mismatches; an unhandled incr type was silently ignored too.
		switch i := incr.(type) {
		case *F64:
			_, err = tensor.Add(v, i.any(), tensor.UseUnsafe())
		case *F32:
			_, err = tensor.Add(v, i.any(), tensor.UseUnsafe())
		case *tensor.Dense:
			_, err = tensor.Add(v, i, tensor.UseUnsafe())
		default:
			return nil, errors.Errorf(nyiFail, "sliceIncrOp.UsePreallocDo", incr)
		}
		if err != nil {
			return nil, errors.Wrapf(err, doFail, op)
		}
		retVal = T
	case Scalar:
		return nil, errors.New("Cannot slice a scalar value")
	default:
		return nil, errors.Errorf(nyiFail, "sliceIncrOp()", prealloc)
	}
	return
}

// OverwritesInput signals that this op writes its result over input 0,
// letting the VM skip allocating a separate output value.
func (op sliceIncrOp) OverwritesInput() int {
	return 0
}

Expand Down
4 changes: 3 additions & 1 deletion operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,16 @@ func applyOp(op Op, children ...*Node) (retVal *Node, err error) {
return
}

ds := Nodes(children).dimSizers()
var s tensor.Shape
if s, err = op.InferShape(Nodes(children).dimSizers()...); err == nil {
if s, err = op.InferShape(ds...); err == nil {
shapeLogf("inferred shape %v", s)
retVal = NewUniqueNode(WithType(retType), WithOp(op), WithChildren(children), In(g), WithShape(s...))
} else {
err = errors.Wrapf(err, "Failed to infer shape. Op: %v", op)
// retVal = newUniqueNode(withType(retType), withOp(op), withChildren(children), withGraph(g))
}
returnDimSizers(ds)
return
}

Expand Down
9 changes: 4 additions & 5 deletions operations_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -235,7 +235,6 @@ var gtTests = []struct {
func TestGt(t *testing.T) {
defer runtime.GC()
for i, gtts := range gtTests {
log.Printf("i %d", i)
// if i != 11 {
// continue
// }
Expand Down Expand Up @@ -464,10 +463,10 @@ func TestSlice(t *testing.T) {

sV := sliced.Value()
if !sts.expected.Eq(sV.Shape()) {
t.Errorf("Test %q. Expected sliced value to have the shape %v. Got %v instead", sts.name, sts.expected, sV.Shape())
t.Errorf("Test %q For TapeMachine. Expected sliced value to have the shape %v. Got %v instead", sts.name, sts.expected, sV.Shape())
}

assert.Equal(t, sts.data, sV.Data(), "Test %q data expected %v, Got %v instead. Formatted:\n %+v", sts.name, sts.data, sV.Data(), sV)
assert.Equal(t, sts.data, sV.Data(), "Test %q For TapeMachine data expected %v, Got %v instead. Formatted:\n %+v", sts.name, sts.data, sV.Data(), sV)

// Test Lisp Machine for equivalence of gradients

Expand All @@ -484,10 +483,10 @@ func TestSlice(t *testing.T) {

s2V := sliced2.Value()
if !sts.expected.Eq(s2V.Shape()) {
t.Errorf("Test %q. Expected sliced value to have the shape %v. Got %v instead", sts.name, sts.expected, s2V.Shape())
t.Errorf("Test %q For LispMachine. Expected sliced value to have the shape %v. Got %v instead", sts.name, sts.expected, s2V.Shape())
}

assert.Equal(t, sts.data, s2V.Data(), "Test %q data expected %v, Got %v instead. Formatted:\n %+v", sts.name, sts.data, s2V.Data(), s2V)
assert.Equal(t, sts.data, s2V.Data(), "Test %q For TapeMachine data expected %v, Got %v instead. Formatted:\n %+v", sts.name, sts.data, s2V.Data(), s2V)

sG, err := sliced.Grad()
if err != nil {
Expand Down
25 changes: 25 additions & 0 deletions perf.go
Original file line number Diff line number Diff line change
Expand Up @@ -71,3 +71,28 @@ func returnValue(v Value) {
returnTensor(t)
}
}

// dimSizerPool recycles []DimSizer scratch slices, keyed by slice length,
// so dimSizers()/InferShape call sites stay allocation-free on hot paths.
// The map itself is not safe for concurrent use, hence the mutex; the
// sync.Pools it holds are individually concurrency-safe.
var (
	dimSizerPoolLock sync.Mutex
	dimSizerPool     = make(map[int]*sync.Pool)
)

// borrowDimSizers returns a []DimSizer of the given size from the pool,
// creating a pool for that size on first use. Hand the slice back with
// returnDimSizers when done.
func borrowDimSizers(size int) []DimSizer {
	dimSizerPoolLock.Lock()
	pool, ok := dimSizerPool[size]
	if !ok {
		// size is a per-call parameter, so the New closure captures it safely.
		pool = &sync.Pool{
			New: func() interface{} { return make([]DimSizer, size) },
		}
		dimSizerPool[size] = pool
	}
	dimSizerPoolLock.Unlock()
	return pool.Get().([]DimSizer)
}

// returnDimSizers clears ds and puts it back into the pool matching its
// capacity. A slice whose capacity has no pool is simply dropped for the
// garbage collector to reclaim.
func returnDimSizers(ds []DimSizer) {
	dimSizerPoolLock.Lock()
	pool, ok := dimSizerPool[cap(ds)]
	dimSizerPoolLock.Unlock()
	if !ok {
		return
	}
	// Nil out entries so a pooled slice does not pin old DimSizers alive.
	for i := range ds {
		ds[i] = nil
	}
	pool.Put(ds)
}
3 changes: 2 additions & 1 deletion typeSystem.go
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,8 @@ func inferNodeType(op Op, children ...*Node) (retVal hm.Type, err error) {
defer hm.ReturnFnType(fnt)
}

argTypes := make(hm.Types, len(children)+1)
argTypes := hm.BorrowTypes(len(children) + 1)
defer hm.ReturnTypes(argTypes)
for i, child := range children {
if argTypes[i], err = inferType(child); err != nil {
return nil, errors.Wrapf(err, "Failed to infer type of %v", child)
Expand Down
8 changes: 6 additions & 2 deletions vm_genera.go
Original file line number Diff line number Diff line change
Expand Up @@ -334,10 +334,10 @@ func (m *lispMachine) forward() (err error) {

m.enterLoggingContext()
for i, child := range children {
m.logf("child %v %v", child, child.Shape())
m.logf("child!! %v %v", child, child.Shape())
if child.Device() == n.Device() {
inputs[i] = child.boundTo.(*dualValue)
continue
// continue
}
// if child.boundTo != nil {
// dv := child.boundTo.(*dualValue)
Expand Down Expand Up @@ -391,6 +391,7 @@ func (m *lispMachine) forward() (err error) {
machineLogf("dvBindVar")
m.logf("dvBindVar")
if output, err = dvBindVar(op, inputs); err != nil {

}
if err = n.bind(output); err != nil {
return errors.Wrap(err, bindFail)
Expand Down Expand Up @@ -511,6 +512,9 @@ func (m *lispMachine) backward() (err error) {
if m.bwd < 0 {
return errors.New("no backprop queue")
}
if m.bwd >= len(m.q) {
return errors.New("Nothing to backprop")
}

instr := m.q[m.bwd]
m.watchedLogf("Differentiating op %v. Output: %v (%x)", instr, instr.output, instr.output.Hashcode())
Expand Down
8 changes: 4 additions & 4 deletions vm_genera_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -188,10 +188,10 @@ func TestLispMachineRepeatedRuns(t *testing.T) {
continue
}

assert.Equal([]float64{1, 4}, gradX.Data())
assert.Equal([]float64{0, 0, 0, 0, 1, 0}, gradY.Data())
assert.Equal([]float64{0, 1, 0}, gradZ.Data())
assert.Equal(1.0, gradC.Data())
assert.Equal([]float64{1, 4}, gradX.Data(), "run %d", i)
assert.Equal([]float64{0, 0, 0, 0, 1, 0}, gradY.Data(), "run %d", i)
assert.Equal([]float64{0, 1, 0}, gradZ.Data(), "run %d", i)
assert.Equal(1.0, gradC.Data(), "run %d", i)

// assert that the data has been unchanged
assert.Equal([]float64{0, 1}, x.Value().Data())
Expand Down

0 comments on commit a9403d5

Please sign in to comment.