Skip to content
Permalink
Browse files

Op cleanup (#65)

* Gorgonia now uses https://github.com/chewxy/hm as its type syste,
* Values are simplified - they're no longer wrapped in structs. `types.Tensor` are valid Values now
* Op are now easily extensible
  • Loading branch information
chewxy committed Nov 28, 2016
1 parent 644c692 commit cf80bd61c3cd77c186ce98e577f9fb8f4ffebcf9
Showing with 4,520 additions and 4,502 deletions.
  1. +13 −0 bench_typesystem_test.go
  2. +7 −4 broadcast_test.go
  3. +34 −28 collections.go
  4. +63 −12 collections_test.go
  5. +5 −5 compile.go
  6. +4 −2 const.go
  7. +4 −4 debug.go
  8. +17 −2 differentiation.go
  9. +11 −8 differentiation_test.go
  10. +61 −134 dual.go
  11. +53 −0 dual_test.go
  12. +42 −16 equalities.go
  13. +92 −0 examples/logisticregression/io.go
  14. +22 −9 examples/logisticregression/main.go
  15. +2 −3 examples/stacked autoencoder/main.go
  16. +2 −2 examples/stacked autoencoder/neuron.go
  17. +1 −1 examples/stacked autoencoder/stackedDA.go
  18. +28 −46 gorgonia.go
  19. +9 −10 gorgonia_test.go
  20. +2 −2 graph_test.go
  21. +0 −21 math.go
  22. +8 −0 math_fast.go
  23. +21 −0 math_nooptim.go
  24. +2 −2 nn.go
  25. +53 −45 node.go
  26. +154 −0 node_test.go
  27. +57 −40 op.go
  28. +13 −10 op_infidel.go
  29. +139 −179 op_math.go
  30. +28 −31 op_nn.go
  31. +78 −71 op_reduction.go
  32. +99 −92 op_reduction_test.go
  33. +193 −217 op_tensor.go
  34. +146 −156 op_tensor_test.go
  35. +2 −2 op_test.go
  36. +12 −6 operations.go
  37. +18 −5 operations_test.go
  38. +5 −2 operatorLinAlg.go
  39. +15 −13 operatorLinAlg_const.go
  40. +372 −115 operatorPointwise_binary.go
  41. +183 −341 operatorPointwise_binary_test.go
  42. +19 −16 operatorPointwise_unary.go
  43. +32 −32 operatorPointwise_unary_test.go
  44. +1 −1 opt.go
  45. +18 −106 perf.go
  46. +2 −2 regalloc.go
  47. +724 −792 solvers.go
  48. +106 −17 solvers_test.go
  49. +3 −4 stabilization_test.go
  50. +2 −2 templates.go
  51. +7 −44 tensor/b/matop.go
  52. +31 −9 tensor/b/tensor.go
  53. +4 −4 tensor/f32/arith_linalg_norms.go
  54. +1 −1 tensor/f32/arith_linalg_norms_test.go
  55. +1 −1 tensor/f32/arith_linalg_svd.go
  56. +1 −1 tensor/f32/compat.go
  57. +7 −44 tensor/f32/matop.go
  58. +0 −31 tensor/f32/matop_test.go
  59. +31 −9 tensor/f32/tensor.go
  60. +1 −1 tensor/f32/test_test.go
  61. +1 −1 tensor/f32/utils.go
  62. +4 −4 tensor/f64/arith_linalg_norms.go
  63. +1 −1 tensor/f64/arith_linalg_norms_test.go
  64. +1 −1 tensor/f64/arith_linalg_svd.go
  65. +1 −1 tensor/f64/compat.go
  66. +7 −44 tensor/f64/matop.go
  67. +0 −31 tensor/f64/matop_test.go
  68. +31 −9 tensor/f64/tensor.go
  69. +1 −1 tensor/f64/utils.go
  70. +1 −1 tensor/i/compat.go
  71. +7 −44 tensor/i/matop.go
  72. +0 −31 tensor/i/matop_test.go
  73. +31 −9 tensor/i/tensor.go
  74. +1 −1 tensor/i/utils.go
  75. +64 −0 tensor/tensor.go
  76. +63 −79 tensor/types/accesspattern.go
  77. +37 −8 tensor/types/accesspattern_test.go
  78. +2 −3 tensor/types/perf.go
  79. +61 −5 tensor/types/shape.go
  80. +120 −0 tensor/types/shape_test.go
  81. +6 −0 tensor/types/types.go
  82. +85 −25 testsetup_test.go
  83. +65 −80 type.go
  84. +0 −71 typeAnalysis.go
  85. +0 −20 typeAtomic.go
  86. +0 −169 typeSet.go
  87. +79 −277 typeSystem.go
  88. +152 −187 typeSystem_test.go
  89. +159 −0 type_test.go
  90. +0 −44 typeclass.go
  91. +0 −174 typeclassSet.go
  92. +70 −134 utils.go
  93. +5 −196 values.go
  94. +155 −0 values_primitives.go
  95. +149 −0 values_utils.go
  96. +44 −0 values_utils_test.go
  97. +13 −9 vm_genera.go
  98. +4 −4 vm_genera_test.go
  99. +39 −32 vm_tape.go
  100. +0 −53 vm_tape_test.go
@@ -0,0 +1,13 @@
package gorgonia

import "testing"

func BenchmarkTypeSystem(b *testing.B) {
g := NewGraph()
x := NewTensor(g, Float64, 2, WithName("x"), WithShape(10, 10))
y := NewTensor(g, Float64, 2, WithName("y"), WithShape(10, 10))
op := newEBOByType(addOpType, Float64, Float64)
for i := 0; i < b.N; i++ {
inferNodeType(op, x, y)
}
}
@@ -1,6 +1,7 @@
package gorgonia

import (
"io/ioutil"
"testing"

tf64 "github.com/chewxy/gorgonia/tensor/f64"
@@ -47,13 +48,14 @@ func TestBroadcast2(t *testing.T) {
var err error

xT := tf64.NewTensor(tf64.WithShape(2, 3), tf64.WithBacking(tf64.RangeFloat64(0, 6)))
yT := tf64.NewTensor(tf64.WithShape(2, 1), tf64.WithBacking([]float64{100, 200}))
yT := tf64.NewTensor(tf64.WithShape(2), tf64.WithBacking([]float64{100, 200}))

g = NewGraph()
x = NewMatrix(g, Float64, WithShape(2, 3), WithValue(xT), WithName("x"))
y = NewVector(g, Float64, WithShape(2, 1), WithValue(yT), WithName("y"))
y = NewVector(g, Float64, WithShape(2), WithValue(yT), WithName("y"))
z, err = Broadcast(addOpType, x, y, NewBroadcastPattern(nil, []byte{1}))
if err != nil {
ioutil.WriteFile("Broadcast.dot", []byte(g.ToDot()), 0644)
t.Fatal(err)
}

@@ -65,10 +67,11 @@ func TestBroadcast2(t *testing.T) {

g = NewGraph()
x = NewMatrix(g, Float64, WithShape(2, 3), WithValue(xT), WithName("x"))
y = NewVector(g, Float64, WithShape(2, 1), WithValue(yT), WithName("y"))
y = NewVector(g, Float64, WithShape(2), WithValue(yT), WithName("y"))
z, err = Broadcast(addOpType, y, x, NewBroadcastPattern([]byte{1}, nil))
if err != nil {
t.Fatal(err)
ioutil.WriteFile("Broadcast.dot", []byte(g.ToDot()), 0644)
t.Fatalf("%+v", err)
}

m = NewLispMachine(g, ExecuteFwdOnly())
@@ -1,7 +1,6 @@
package gorgonia

import (
"bytes"
"fmt"
"sort"
"unsafe"
@@ -52,20 +51,20 @@ func (ns Nodes) Contains(want *Node) bool {

// Format implements fmt.Formatter, which allows Nodes to be differently formatted depending on the verbs
func (ns Nodes) Format(s fmt.State, c rune) {
delimiter := ","
delimiter := ", "
if s.Flag(' ') {
delimiter = " "
delimiter = " "
}
if s.Flag('+') {
delimiter = ",\n"
delimiter = ", \n"
}
switch c {
case 'd':
s.Write([]byte("["))
for i, n := range ns {
fmt.Fprintf(s, "%x", n.Hashcode())
if i < len(ns)-1 {
fmt.Fprintf(s, "%s ", delimiter)
fmt.Fprintf(s, "%s", delimiter)
}
}
s.Write([]byte("]"))
@@ -78,27 +77,26 @@ func (ns Nodes) Format(s fmt.State, c rune) {
fmt.Fprintf(s, "%s", n.Name())
}
if i < len(ns)-1 {
fmt.Fprintf(s, "%s ", delimiter)
fmt.Fprintf(s, "%s", delimiter)
}
}
s.Write([]byte("]"))
case 'Y':
if s.Flag('#') {
s.Write([]byte("["))
for i, n := range ns {
fmt.Fprintf(s, "%v", n.t)
if i < len(ns)-1 {
fmt.Fprintf(s, "%s ", delimiter)
}
s.Write([]byte("["))
for i, n := range ns {
fmt.Fprintf(s, "%v", n.t)
if i < len(ns)-1 {
fmt.Fprintf(s, "%s", delimiter)
}
s.Write([]byte("]"))
}
s.Write([]byte("]"))

case 'P':
s.Write([]byte("["))
for i, n := range ns {
fmt.Fprintf(s, "%p", n)
if i < len(ns)-1 {
fmt.Fprintf(s, "%s ", delimiter)
fmt.Fprintf(s, "%s", delimiter)
}
}
s.Write([]byte("]"))
@@ -143,6 +141,19 @@ func (ns Nodes) AllSameGraph() bool {
return true
}

func (ns Nodes) Equals(other Nodes) bool {
if len(ns) != len(other) {
return false
}

for _, n := range ns {
if !other.Contains(n) {
return false
}
}
return true
}

func (ns Nodes) mapSet() NodeSet { return NewNodeSet(ns...) }

func (ns Nodes) index(n *Node) int {
@@ -180,19 +191,14 @@ func (ns Nodes) remove(what *Node) Nodes {
return ns
}

/* TYPES */

type Types []Type

func (ts Types) String() string {
var buf bytes.Buffer
buf.WriteString("[")
for i, t := range ts {
buf.WriteString(t.String())
if i < len(ts)-1 {
buf.WriteString(", ")
func (ns Nodes) dimSizers() []DimSizer {
retVal := make([]DimSizer, len(ns))
for i, n := range ns {
if s, ok := n.op.(sizeOp); ok {
retVal[i] = s
} else {
retVal[i] = n.shape
}
}
buf.WriteString("]")
return buf.String()
return retVal
}
@@ -1,24 +1,25 @@
package gorgonia

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestSet(t *testing.T) {
func TestNodes(t *testing.T) {
assert := assert.New(t)
g := NewGraph()
n0 := newNodeFromPool(withGraph(g), WithName("n0"))
n1 := newNodeFromPool(withGraph(g), WithName("n1"))
n2 := newNodeFromPool(withGraph(g), WithName("n2"))
n3 := newNodeFromPool(withGraph(g), WithName("n3"))
n0 := newNode(withGraph(g), WithName("n0"))
n1 := newNode(withGraph(g), WithName("n1"))
n2 := newNode(withGraph(g), WithName("n2"))
n3 := newNode(withGraph(g), WithName("n3"))

// calculate hashcode first
n0.Hashcode()
n1.Hashcode()
n2.Hashcode()
n3.Hashcode()
n0h := n0.Hashcode()
n1h := n1.Hashcode()
n2h := n2.Hashcode()
n3h := n3.Hashcode()
t.Logf("%x, %x, %x, %x", n0.hash, n1.hash, n2.hash, n3.hash)

set := Nodes{n0, n1, n2, n3, n0, n0}
@@ -30,8 +31,17 @@ func TestSet(t *testing.T) {
}
assert.Equal(len(correct), len(set))

t.Log("Testing intersection")
t.Log("Test add")
set = Nodes{}
set = set.Add(n0)
set = set.Add(n2)
set = set.Add(n0)
set = set.Add(n3)
set = set.Add(n1)
correct = Nodes{n0, n2, n3, n1}
assert.Equal(correct, set)

t.Log("Testing intersection")
set = Nodes{n0, n2, n1, n3} // out of order, on purpose
other := Nodes{n0, n1}
inter := set.Intersect(other)
@@ -43,8 +53,8 @@ func TestSet(t *testing.T) {
assert.Equal(len(correct), len(inter))

t.Log("Testing difference")
n4 := newNodeFromPool(withGraph(g))
n5 := newNodeFromPool(withGraph(g))
n4 := newNode(withGraph(g))
n5 := newNode(withGraph(g))
set = Nodes{n3, n0, n1, n2}
other = Nodes{n0, n3, n4, n5}

@@ -54,4 +64,45 @@ func TestSet(t *testing.T) {
assert.Contains(diff, n)
}
assert.Equal(len(correct), len(diff))

t.Log("Testing replace")
set = Nodes{n0, n2, n1, n2, n1} // not yet a set
set = set.replace(n2, n3)
correct = Nodes{n0, n3, n1, n3, n1}
assert.Equal(correct, set)

t.Log("Formatting")
formats := []string{"% v", "%+v", "%d", "%v", "%#v", "%Y", "%P"}
correctFormats := []string{
"[n0 n1 n2 n3]",
`[n0,
n1,
n2,
n3]`,
fmt.Sprintf("[%x, %x, %x, %x]", n0h, n1h, n2h, n3h),
"[n0, n1, n2, n3]",
"[n0 :: <nil>, n1 :: <nil>, n2 :: <nil>, n3 :: <nil>]",
"[<nil>, <nil>, <nil>, <nil>]",
fmt.Sprintf("[%p, %p, %p, %p]", n0, n1, n2, n3),
}

set = Nodes{n0, n1, n2, n3}
for i, f := range formats {
s := fmt.Sprintf(f, set)
if s != correctFormats[i] {
t.Errorf("Format %q. Expected %q. Got %q", f, correctFormats[i], s)
}
}

// corner cases
set = Nodes{}
if set.AllSameGraph() {
t.Error("Empty list of nodes cannot be of the same graph!")
}

nAbnormal := newNode(withGraph(NewGraph()))
set = Nodes{n0, n1, nAbnormal, n2}
if set.AllSameGraph() {
t.Error("One node is in a different graph! This should have returned false")
}
}
@@ -204,7 +204,7 @@ func codegen(inputs, sorted Nodes, df *dataflow) (prog *program, locationMap map
// if it's not mutable, there is no chance it will be overwritten
if node.isMutable() {
// if the instruction calls an extern (cBLAS or cuBlas), then we should preallocate the vector
if node.op.callsExtern() {
if node.op.CallsExtern() {
compileLogf("calls extern")
instr := newAlloc(node, nInterv.result)
instructions = append(instructions, instr)
@@ -223,7 +223,7 @@ func codegen(inputs, sorted Nodes, df *dataflow) (prog *program, locationMap map
if instrID, ok := lastWrites[read.id]; ok {
viaticum := instructions[instrID] // ;) - it IS on the way
if instr, ok := viaticum.(execOp); ok {
if instr.op.callsExtern() && !node.op.callsExtern() {
if instr.op.CallsExtern() && !node.op.CallsExtern() {
// the && bit is to make sure that if we have sequential cBLAS/cuBLAS calls,
// we just add it to the batch.
// sequential in this can mean several instructions apart. For example:
@@ -249,7 +249,7 @@ func codegen(inputs, sorted Nodes, df *dataflow) (prog *program, locationMap map

// check the overwrites - if the overwrite and the resulting register is the same,
// then use unsafe options when available
overwrites := node.op.overwriteInput()
overwrites := node.op.OverwritesInput()
if overwrites >= 0 {
compileLogf("Overwrites %d", overwrites)
overwritten := reads[overwrites]
@@ -306,12 +306,12 @@ func compileState(w io.Writer, g *ExprGraph, df *dataflow) {

if n.op != nil {
row[1] = fmt.Sprintf("%s", n.op)
overwrites := n.op.overwriteInput()
overwrites := n.op.OverwritesInput()
if overwrites >= 0 {
row[7] = fmt.Sprintf("%d", n.children[overwrites].ID())
}

if n.op.callsExtern() {
if n.op.CallsExtern() {
row[8] = "yes"
}
}
@@ -39,7 +39,8 @@ const (
sortFail = "Failed to sort"
cloneFail = "Failed to carry clone()"
clone0Fail = "Failed to carry clone0()"
nyiFail = "%s not yet implemented for %T"
nyiTypeFail = "%s not yet implemented for %T"
nyiFail = "%s not yet implemented for %v"
dtypeOfFail = "Failed to carry dtypeOf()"
mulFail = "Failed to carry Mul()"
applyOpFail = "Failed to carry applyOp()"
@@ -71,7 +72,8 @@ const (
reshapeFail = "Failed to reshape Tensor into %v. DataSize was: %d"
sliceFail = "Failed to slice Tensor with %v"
execFail = "Failed to execute %v"
autodiffFail = "Failed to differentiate %T"
autodiffFail = "Failed to differentiate %v"
undefinedOnShape = "%v undefined on shape %v"
)

var empty struct{}
@@ -16,10 +16,10 @@ const DEBUG = true
// I use these instead of say, Delve because most of the time, the larger picture has to be known. Delve tends to give small picture views
var (
compileDev = false
shapeInferenceDev = false
typeSystemDev = false
symdiffDev = false
autodiffDev = false
shapeInferenceDev = true
typeSystemDev = true
symdiffDev = true
autodiffDev = true
machineDev = true
stabilizationDev = false
solverDev = false
@@ -72,7 +72,13 @@ func backwardDiffAnalysis(wrt, sortedNodes Nodes) (retVal NodeSet, err error) {
for i := len(sortedNodes) - 1; i >= 0; i-- {
n := sortedNodes[i]
symdiffLogf("working on %v. Has %d children", n, len(n.children))
diffs := n.diffWRT()

var op SDOp
var ok bool
var diffs []bool
if op, ok = n.op.(SDOp); ok {
diffs = op.DiffWRT(len(n.children))
}

symdiffLogf("differentiable WRT: %v", diffs)
enterLoggingContext()
@@ -264,10 +270,19 @@ func Backpropagate(outputs, gradOutputs, wrt Nodes) (retVal Nodes, err error) {
if !node.isInput() {
symdiffLogf("differentiating %x (%v)", node.ID(), node.op)
enterLoggingContext()

var op SDOp
var childrenGrads Nodes
var ok bool

if op, ok = node.op.(SDOp); !ok {
// error
}

symdiffLogf("op: %v || optype: %v || node: %v || Children: %#Y || Grad: %v", node.op, node.op.Type(), node.t, node.children, gradNode)
if childrenGrads, err = node.op.SymDiff(node.children, node, gradNode); err != nil {
if childrenGrads, err = op.SymDiff(node.children, node, gradNode); err != nil {
return nil, errors.Wrapf(err, "SymDiff for %v. OpType: %v. Node Type: %v. Children: %#v. Grad: %v", node.op, node.op.Type(), node.t, node.children, gradNode)

}
symdiffLogf("Derived(%d): %d", len(childrenGrads), childrenGrads)
leaveLoggingContext()

0 comments on commit cf80bd6

Please sign in to comment.
You can’t perform that action at this time.