Misc Fixes and Optimizations
* Misc bugfixes, small features, etc.

* Added a concat op and a Concat function to Gorgonia, with simple tests (a usage sketch follows the file summary below). Coverage will be increased in the future.

* Fixed a small bug with sliceOp's shape inference.

* Fixed up tests for Slice(), which in turn led to the discovery of various bugs in Repeat, Slice, etc. Cleaned them up as well.

* Added UnsafeLet and documented its use.

* Slightly optimized FlatIterator

* Updated sliceOp so that the Op is *sliceOp (a pointer) rather than sliceOp (a value). This fixes a number of training issues when used with UnsafeLet.

* Added tests for said special case

* Fixed a bug in the tensors' comparison API

* Added tests for Gt and fixed #69 

* Cleaned up the Dropout() function to fix bugs that occurred when Float32 was used.
Fixed bugs with PreAllocDoer, IncrDoer, and UnsafeDoer to handle returning the same types.
chewxy committed Dec 19, 2016
1 parent 1478282 commit 2dce241
Showing 41 changed files with 1,305 additions and 685 deletions.
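
As a usage reference for the new Concat mentioned above, here is a minimal sketch. The Concat(axis, nodes...) signature and the expected output shape are assumptions based on this commit's description and the package's general API, so treat it as illustrative rather than authoritative:

package main

import (
	"fmt"
	"log"

	G "github.com/chewxy/gorgonia"
)

func main() {
	g := G.NewGraph()
	// two 2x3 matrices to be joined along axis 0
	a := G.NewMatrix(g, G.Float64, G.WithShape(2, 3), G.WithName("a"))
	b := G.NewMatrix(g, G.Float64, G.WithShape(2, 3), G.WithName("b"))

	c, err := G.Concat(0, a, b) // assumed signature: Concat(axis, nodes...)
	if err != nil {
		log.Fatal(err)
	}
	fmt.Println(c.Shape()) // expected: (4, 3)
}
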
2 changes: 1 addition & 1 deletion const.go
@@ -71,7 +71,7 @@ const (
repFail = "Failed to repeat Tensor along %d %d times"
reshapeFail = "Failed to reshape Tensor into %v. DataSize was: %d"
sliceFail = "Failed to slice Tensor with %v"
execFail = "Failed to execute %v"
execFail = "Failed to execute %v in node %v"
autodiffFail = "Failed to differentiate %v"
undefinedOnShape = "%v undefined on shape %v"
)
10 changes: 5 additions & 5 deletions debug.go
@@ -16,11 +16,11 @@ const DEBUG = true
// I use these instead of, say, Delve because most of the time the larger picture has to be known. Delve tends to give small-picture views
var (
compileDev = false
shapeInferenceDev = true
typeSystemDev = true
symdiffDev = true
autodiffDev = true
machineDev = true
shapeInferenceDev = false
typeSystemDev = false
symdiffDev = false
autodiffDev = false
machineDev = false
stabilizationDev = false
solverDev = false
)
13 changes: 4 additions & 9 deletions differentiation.go
@@ -205,6 +205,7 @@ func Backpropagate(outputs, gradOutputs, wrt Nodes) (retVal Nodes, err error) {
// when iterating through the nodes in reverse topological order
nodeGradMap := make(map[*Node]Nodes)
for i, n := range outputs {
symdiffLogf("Adding outputs for %x", n.ID())
nodeGradMap[n] = Nodes{gradOutputs[i]}
}

@@ -218,13 +219,8 @@ func Backpropagate(outputs, gradOutputs, wrt Nodes) (retVal Nodes, err error) {
symdiffLogf("Sorted: %d", sortedNodes)
symdiffLogf("nodeGradMap: %+#d", FmtNodeMap(nodeGradMap))
enterLoggingContext()
// for i := len(sortedNodes) - 1; i >= 0; i-- {
// node := sortedNodes[i]

for _, node := range sortedNodes {
// if !activeNodes.Contains(node) {
// autodiffLogf("skipping %d", node.ID())
// continue
// }
if _, ok := activeNodes[node]; !ok {
symdiffLogf("skipping %x", node.ID())
continue
@@ -282,9 +278,9 @@ func Backpropagate(outputs, gradOutputs, wrt Nodes) (retVal Nodes, err error) {
symdiffLogf("op: %v || optype: %v || node: %v || Children: %#Y || Grad: %v", node.op, node.op.Type(), node.t, node.children, gradNode)
if childrenGrads, err = op.SymDiff(node.children, node, gradNode); err != nil {
return nil, errors.Wrapf(err, "SymDiff for %v. OpType: %v. Node Type: %v. Children: %#v. Grad: %v", node.op, node.op.Type(), node.t, node.children, gradNode)

}
symdiffLogf("Derived(%d): %d", len(childrenGrads), childrenGrads)

symdiffLogf("Derived(%d): %P", len(childrenGrads), childrenGrads)
leaveLoggingContext()

diffs := node.diffWRT()
@@ -294,7 +290,6 @@ func Backpropagate(outputs, gradOutputs, wrt Nodes) (retVal Nodes, err error) {
childGrad := childrenGrads[i]

if differentiable {
// node.derives = append(node.derives, childGrad)
childGrad.setGroup(gradClust)
if grads, ok := nodeGradMap[child]; ok {
grads = append(grads, childGrad)
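
The Backpropagate hunks above all orbit one structure: nodeGradMap, which gathers every gradient contribution a node receives while the graph is walked in reverse topological order; the contributions are summed when the node itself is processed. A toy sketch of that accumulation pattern, with deliberately simplified types (not the library's actual code):

package sketch

// node stands in for *Node; gradients are plain float64s here.
type node struct{ id int }

// accumulate walks nodes in reverse topological order and sums the
// contributions each node has collected from its downstream uses.
func accumulate(order []*node, contrib map[*node][]float64) map[*node]float64 {
	grads := make(map[*node]float64)
	for _, n := range order {
		var sum float64
		for _, g := range contrib[n] {
			sum += g
		}
		grads[n] = sum
	}
	return grads
}
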
19 changes: 13 additions & 6 deletions dual.go
@@ -55,8 +55,10 @@ func (dv *dualValue) String() string {
func (dv *dualValue) sanity() error {
// check that d and v are the same type

if !TypeOf(dv.Value).Eq(TypeOf(dv.d)) {
return errors.New("DualValues do not have the same types")
dvv := TypeOf(dv.Value)
dvd := TypeOf(dv.d)
if !dvv.Eq(dvd) {
return errors.Errorf("DualValues do not have the same types: %v and %v", dvv, dvd)
}

// TODO: check that the shapes are the same
@@ -194,20 +196,25 @@ func dvBindVar(op Op, inputs []*dualValue) (retVal *dualValue, err error) {
}
}

// TODO: test vecvecdot dvBind0

// doesn't alloc a dualValue; reuses whatever is there, and zeroes out the deriv
func dvBind0(op Op, retVal *dualValue, inputs []*dualValue) (err error) {
prealloc := retVal.Value

vals := idValue(inputs)

var ret Value
if pd, ok := op.(UsePreallocDoer); ok {
ret, err = pd.UsePreallocDo(prealloc, vals...)
} else {
if ret, err = op.Do(vals...); err != nil {
return errors.Wrap(err, opDoFail)
if err == nil {
goto next
}
}
if ret, err = op.Do(vals...); err != nil {
return errors.Wrap(err, opDoFail)
}

next:
if err != nil {
return
}
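
The goto in the reworked dvBind0 above exists so that a failing UsePreallocDo falls back to the plain Do path instead of aborting. A standalone sketch of the same control flow, with hypothetical minimal types standing in for the library's Op and Value:

package sketch

// Op is a stand-in for the library's Op interface.
type Op interface {
	Do(vals ...float64) (float64, error)
}

// UsePreallocDoer mirrors the optional interface dvBind0 checks for.
type UsePreallocDoer interface {
	UsePreallocDo(prealloc float64, vals ...float64) (float64, error)
}

// run prefers the preallocated path and retries with Do on failure,
// which is what the `goto next` in dvBind0 achieves.
func run(op Op, prealloc float64, vals ...float64) (float64, error) {
	if pd, ok := op.(UsePreallocDoer); ok {
		if ret, err := pd.UsePreallocDo(prealloc, vals...); err == nil {
			return ret, nil
		}
	}
	return op.Do(vals...)
}
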
40 changes: 32 additions & 8 deletions gorgonia.go
@@ -230,19 +230,43 @@ func Grad(cost *Node, WRTs ...*Node) (retVal []*Node, err error) {
// Let binds a Value to a node that is a variable. A variable is represented as a *Node with no Op.
// It is equivalent to :
// x = 2
func Let(n *Node, be interface{}) (err error) {
func Let(n *Node, be interface{}) error {
if !n.isInput() {
return errors.New("Cannot bind a value to a non input node")
}

var val Value
if val, _, _, err = anyToValue(be); err != nil {
return errors.Wrapf(err, anyToValueFail, be, be)
}
return UnsafeLet(n, be)
}

// TODO: runtime type checking
n.bind(val)
return
// UnsafeLet binds a Value to any node, not just a variable node. This means that you can use it to change any node's value while the graph is running. UNSAFE!
//
// Additional notes: if `be` is a types.Slice, and the node's op is a sliceOp or sliceIncrOp, the op's slice will be replaced with the new slice.
func UnsafeLet(n *Node, be interface{}) error {
switch v := be.(type) {
case types.Slice:
switch so := n.op.(type) {
case *sliceOp:
so.Slice = v
n.op = so
case sliceIncrOp:
so.Slice = v
n.op = so
default:
return errors.Errorf("Trying to Let() a node with a slice. Node's op is %v, not sliceOp", n.op)
}

case Value:
n.bind(v)
default:
var val Value
var err error
if val, _, _, err = anyToValue(be); err != nil {
return errors.Wrapf(err, anyToValueFail, be, be)
}

n.bind(val)
}
return nil
}

// Set is the equivalent of doing this:
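
A hedged usage sketch of UnsafeLet's slice-replacement path documented above: once the Op is a *sliceOp, the slice it holds can be swapped at runtime without rebuilding the graph. The S() slice constructor is an assumption here; substitute whatever helper the package actually exposes for building a types.Slice:

package main

import (
	"log"

	G "github.com/chewxy/gorgonia"
)

func main() {
	g := G.NewGraph()
	x := G.NewMatrix(g, G.Float64, G.WithShape(3, 2), G.WithName("x"))

	row, err := G.Slice(x, G.S(0)) // row's op is a *sliceOp selecting row 0
	if err != nil {
		log.Fatal(err)
	}

	// ... compile and run the graph ...

	// Re-point the same node at row 1; legal because *sliceOp (a pointer)
	// is now the Op, so the slice stored in it can be mutated in place.
	if err := G.UnsafeLet(row, G.S(1)); err != nil {
		log.Fatal(err)
	}
}
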
43 changes: 18 additions & 25 deletions nn.go
@@ -1,10 +1,6 @@
package gorgonia

import (
"fmt"

"github.com/pkg/errors"
)
import "github.com/pkg/errors"

// BinaryXent is a convenience function for doing binary crossentropy stuff.
// The formula is as below:
@@ -67,37 +63,34 @@ func Dropout(x *Node, prob float64) (retVal *Node, err error) {
return x, nil
}

var low, high float64
if prob < 0 {
low = prob
high = -prob
} else {
low = -prob
high = prob
}

var dt Dtype
if dt, err = dtypeOf(x.t); err != nil {
return nil, errors.Wrap(err, dtypeOfFail)
}

m := UniformRandomNode(x.g, dt, low, high, x.shape...)
if retVal, err = Mul(x, m); err != nil {
return nil, errors.Wrap(err, mulFail)
}

var v Value
var opp, pr Value // opp = 1/p
switch dt {
case Float64:
v, _ = anyToScalar(1.0 / prob)
opp, _ = anyToScalar(1.0 / prob)
pr, _ = anyToScalar(prob)
case Float32:
v, _ = anyToScalar(float32(1.0 / prob))
opp, _ = anyToScalar(float32(1.0 / prob))
pr, _ = anyToScalar(float32(prob))
default:
// TODO: use errors package for this panic?
panic(fmt.Sprintf("Dtype %v not yet implemented for dropout", dt))
return nil, errors.Errorf(nyiTypeFail, "Dropout()", dt)
}

c := NewConstant(v)
p := NewConstant(pr)
c := NewConstant(opp)

m := UniformRandomNode(x.g, dt, 0, 1, x.shape...)
if retVal, err = Gt(m, p, true); err != nil {
return nil, errors.Wrap(err, "Greater Than failed")
}

if retVal, err = HadamardProd(x, retVal); err != nil {
return nil, errors.Wrap(err, mulFail)
}

return HadamardDiv(retVal, c)
}
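
The reworked Dropout above now builds its mask by sampling uniformly on [0, 1), comparing against the probability with Gt(m, p, true), and rescaling by the 1/prob constant. A plain-Go sketch of the same arithmetic on a float64 slice, to make the masking concrete (illustrative only, not the library's code):

package main

import (
	"fmt"
	"math/rand"
)

// dropout mirrors the diff: keep x[i] where u > prob (the Gt mask),
// then divide by opp = 1/prob, as the final HadamardDiv does.
func dropout(x []float64, prob float64) []float64 {
	opp := 1.0 / prob
	out := make([]float64, len(x))
	for i, v := range x {
		if rand.Float64() > prob {
			out[i] = v / opp
		}
	}
	return out
}

func main() {
	fmt.Println(dropout([]float64{1, 2, 3, 4}, 0.5))
}
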
10 changes: 8 additions & 2 deletions node.go
@@ -201,9 +201,10 @@ func (n *Node) ID() int { return int(uintptr(unsafe.Pointer(n))) }

// helper functions to help compilation process
func (n *Node) isArg() bool { return n.op == nil }
func (n *Node) isInput() bool { return n.isArg() && !n.isStmt }
func (n *Node) isInput() bool { return (n.isArg() || n.isRandom()) && !n.isStmt }
func (n *Node) isMutable() bool { return !n.isInput() && n.op.ReturnsPtr() }
func (n *Node) isConstant() bool { _, ok := n.op.(constant); return ok }
func (n *Node) isRandom() bool { _, ok := n.op.(randomOp); return ok }

func (n *Node) isRoot() bool {
if n.g == nil {
@@ -298,7 +299,6 @@ func (n *Node) Grad() (Value, error) {
return dv.d, nil
}
if n.deriv != nil {
logf("Getting from n.deriv")
return n.deriv.Value(), nil
}

@@ -475,6 +475,12 @@ func (n *Node) bind(v Value) error {
if vdv == dv {
return nil
}
if n.isRandom() {
// then simply replace the value in it
dv.Value = vdv.Value
return nil
}

panic("Undefined behaviour") // no seriously there literally is no defined behaviour of what should the right thing be. I'll come back to this TODO.
}
dv.Value = v
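
Two of the node.go changes above work together: isInput() now also returns true for random nodes, and bind() simply swaps the stored Value when the node is random. The practical effect is that a random node can be pinned to a fixed value, e.g. for reproducible tests. A hedged fragment (the fixed value and how it is constructed are assumptions):

package main

import (
	"log"

	G "github.com/chewxy/gorgonia"
)

func main() {
	g := G.NewGraph()
	m := G.UniformRandomNode(g, G.Float64, 0, 1, 2, 2) // backed by a randomOp

	// Previously Let() rejected this because m.op != nil; with the new
	// isInput() it is accepted, and bind() replaces the stored value.
	var fixed G.Value // assumption: populate with a Value of matching shape and dtype
	if err := G.Let(m, fixed); err != nil {
		log.Fatal(err) // will fire until fixed is actually populated
	}
}
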
14 changes: 14 additions & 0 deletions op.go
@@ -8,12 +8,14 @@ import (

"github.com/chewxy/gorgonia/tensor/types"
"github.com/chewxy/hm"
"github.com/pkg/errors"
)

type DimSizer interface {
DimSize(int) (int, error)
}

// ShapesToDimSizers is a convenience function to convert a slice of types.Shape to a slice of DimSizer
func ShapesToDimSizers(shapes []types.Shape) []DimSizer {
retVal := make([]DimSizer, len(shapes))
for i, s := range shapes {
@@ -22,6 +24,18 @@ func ShapesToDimSizers(shapes []types.Shape) []DimSizer {
return retVal
}

// DimSizersToShapes is a convenience function to convert a slice of DimSizer to a slice of types.Shape. It will return an error if any of them isn't a types.Shape
func DimSizersToShapes(ds []DimSizer) ([]types.Shape, error) {
retVal := make([]types.Shape, len(ds))
var ok bool
for i, d := range ds {
if retVal[i], ok = d.(types.Shape); !ok {
return nil, errors.Errorf("Dimsizer %d is not a Shape.", i)
}
}
return retVal, nil
}

// An Op is a symbolic representation of an operation
// Think of them as functions, taking an input (or multiple), and outputting something
//
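
DimSizersToShapes above is the inverse of ShapesToDimSizers for the common case where every DimSizer is in fact a types.Shape. A small round-trip sketch:

package main

import (
	"fmt"
	"log"

	G "github.com/chewxy/gorgonia"
	"github.com/chewxy/gorgonia/tensor/types"
)

func main() {
	shapes := []types.Shape{{2, 3}, {4, 5}}

	ds := G.ShapesToDimSizers(shapes) // each element is still a types.Shape
	back, err := G.DimSizersToShapes(ds)
	if err != nil {
		log.Fatal(err) // only fires if some DimSizer is not a types.Shape
	}
	fmt.Println(back)
}
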
26 changes: 16 additions & 10 deletions op_math.go
@@ -243,7 +243,10 @@ func (op elemBinOp) DoDiff(inputs Nodes, output *Node) (err error) {

b := op.ʘBinaryOperator.binOpType()
if err = ʘBinOpDiffFns[b](inputs[0], inputs[1], output); err != nil {
return errors.Wrapf(err, autodiffFail, b)
if _, ok := err.(AutoDiffError); !ok {
return errors.Wrapf(err, autodiffFail, b)
}
err = nil
}

//handle scalar gradients
@@ -309,8 +312,8 @@ func (op elemBinOp) UsePreallocDo(prealloc Value, inputs ...Value) (retVal Value, err error) {
return op.Do(inputs...)
}

if pd, ok := op.ʘBinaryOperator.(UsePreallocDoer); ok {
return pd.UsePreallocDo(prealloc, inputs...)
if pd, ok := op.ʘBinaryOperator.(usePreallocDoerBinOp); ok {
return pd.UsePreallocDo(prealloc, op.retSame, inputs...)
}

return op.Do(inputs...)
@@ -322,8 +325,8 @@ func (op elemBinOp) UnsafeDo(inputs ...Value) (retVal Value, err error) {
return op.Do(inputs...)
}

if ud, ok := op.ʘBinaryOperator.(UnsafeDoer); ok {
return ud.UnsafeDo(inputs...)
if ud, ok := op.ʘBinaryOperator.(unsafeDoerBinOp); ok {
return ud.UnsafeDo(op.retSame, inputs...)
}
return op.Do(inputs...)
}
@@ -344,12 +347,15 @@ func (op elemBinOp) IncrDo(incr Value, inputs ...Value) (err error) {
return
}

if id, ok := op.ʘBinaryOperator.(IncrDoer); ok {
return id.IncrDo(incr, inputs...)
if id, ok := op.ʘBinaryOperator.(incrDoerBinOp); ok {
return id.IncrDo(incr, op.retSame, inputs...)
}

panic("unreachable")
}

func (op elemBinOp) String() string { return fmt.Sprintf("%v %t", op.ʘBinaryOperator, op.retSame) }

// Fulfils the BinaryOp interface
func (op elemBinOp) IsBinary() bool { return true }

@@ -567,16 +573,16 @@ func (op linAlgBinOp) InferShape(inputs ...DimSizer) (retVal types.Shape, err error) {
switch op.āBinaryOperator {
case matMulOperator:
if op.transA {
x = transpose(x)
x = transpose2D(x)
}
if op.transB {
y = transpose(y)
y = transpose2D(y)
}

retVal = types.Shape{x[0], y[1]}
case matVecMulOperator:
if op.transA {
x = transpose(x)
x = transpose2D(x)
}

if x[0] != y[0] && x[1] != y[0] {
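
The retSame flag now threaded through the unexported binOp interfaces above is the same flag users pass to the comparison functions: with retSame = true, a comparison such as Gt returns 0/1 values in the operands' dtype instead of booleans, which is what the new Dropout() mask depends on. A short hedged fragment, assuming x, m, and p are *Node values built as in the nn.go diff:

import G "github.com/chewxy/gorgonia"

// maskAndApply builds a same-dtype 0/1 mask and applies it elementwise.
func maskAndApply(x, m, p *G.Node) (*G.Node, error) {
	// retSame = true keeps the mask in Float64/Float32, so it can feed
	// directly into HadamardProd.
	mask, err := G.Gt(m, p, true)
	if err != nil {
		return nil, err
	}
	return G.HadamardProd(x, mask)
}
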
4 changes: 2 additions & 2 deletions op_nn.go
@@ -62,8 +62,6 @@ func (op randomOp) Type() hm.Type {
}

func (op randomOp) InferShape(...DimSizer) (types.Shape, error) { return op.shape, nil }
func (op randomOp) DiffWRT(i int) []bool { r := make([]bool, i); return r }
func (op randomOp) SymDiff(Nodes, *Node, *Node) (Nodes, error) { return nil, nondiffErr(op) }

func (op randomOp) Do(...Value) (retVal Value, err error) {
if op.shape.IsScalar() {
@@ -114,6 +112,7 @@ func (op randomOp) Do(...Value) (retVal Value, err error) {
backing := Binomial64(op.a, op.b, op.shape...)
retVal = tf64.NewTensor(tf64.WithBacking(backing), tf64.WithShape(op.shape...))
}
return
case Float32:
switch op.which {
case uniform:
@@ -126,6 +125,7 @@ func (op randomOp) Do(...Value) (retVal Value, err error) {
backing := Binomial32(op.a, op.b, op.shape...)
retVal = tf32.NewTensor(tf32.WithBacking(backing), tf32.WithShape(op.shape...))
}
return
default:
return nil, errors.Errorf(nyiFail, "randomOp.do() for non-scalar", op.dt)
}
(Diffs for the remaining 31 changed files are not shown here.)
