Skip to content

Commit

Permalink
Op cleanup (#65)
Browse files Browse the repository at this point in the history
* Gorgonia now uses https://github.com/chewxy/hm as its type syste,
* Values are simplified - they're no longer wrapped in structs. `types.Tensor` are valid Values now
* Op are now easily extensible
  • Loading branch information
chewxy committed Nov 28, 2016
1 parent 644c692 commit cf80bd6
Show file tree
Hide file tree
Showing 100 changed files with 4,520 additions and 4,502 deletions.
13 changes: 13 additions & 0 deletions bench_typesystem_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,13 @@
package gorgonia

import "testing"

func BenchmarkTypeSystem(b *testing.B) {
g := NewGraph()
x := NewTensor(g, Float64, 2, WithName("x"), WithShape(10, 10))
y := NewTensor(g, Float64, 2, WithName("y"), WithShape(10, 10))
op := newEBOByType(addOpType, Float64, Float64)
for i := 0; i < b.N; i++ {
inferNodeType(op, x, y)
}
}
11 changes: 7 additions & 4 deletions broadcast_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package gorgonia

import (
"io/ioutil"
"testing"

tf64 "github.com/chewxy/gorgonia/tensor/f64"
Expand Down Expand Up @@ -47,13 +48,14 @@ func TestBroadcast2(t *testing.T) {
var err error

xT := tf64.NewTensor(tf64.WithShape(2, 3), tf64.WithBacking(tf64.RangeFloat64(0, 6)))
yT := tf64.NewTensor(tf64.WithShape(2, 1), tf64.WithBacking([]float64{100, 200}))
yT := tf64.NewTensor(tf64.WithShape(2), tf64.WithBacking([]float64{100, 200}))

g = NewGraph()
x = NewMatrix(g, Float64, WithShape(2, 3), WithValue(xT), WithName("x"))
y = NewVector(g, Float64, WithShape(2, 1), WithValue(yT), WithName("y"))
y = NewVector(g, Float64, WithShape(2), WithValue(yT), WithName("y"))
z, err = Broadcast(addOpType, x, y, NewBroadcastPattern(nil, []byte{1}))
if err != nil {
ioutil.WriteFile("Broadcast.dot", []byte(g.ToDot()), 0644)
t.Fatal(err)
}

Expand All @@ -65,10 +67,11 @@ func TestBroadcast2(t *testing.T) {

g = NewGraph()
x = NewMatrix(g, Float64, WithShape(2, 3), WithValue(xT), WithName("x"))
y = NewVector(g, Float64, WithShape(2, 1), WithValue(yT), WithName("y"))
y = NewVector(g, Float64, WithShape(2), WithValue(yT), WithName("y"))
z, err = Broadcast(addOpType, y, x, NewBroadcastPattern([]byte{1}, nil))
if err != nil {
t.Fatal(err)
ioutil.WriteFile("Broadcast.dot", []byte(g.ToDot()), 0644)
t.Fatalf("%+v", err)
}

m = NewLispMachine(g, ExecuteFwdOnly())
Expand Down
62 changes: 34 additions & 28 deletions collections.go
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
package gorgonia

import (
"bytes"
"fmt"
"sort"
"unsafe"
Expand Down Expand Up @@ -52,20 +51,20 @@ func (ns Nodes) Contains(want *Node) bool {

// Format implements fmt.Formatter, which allows Nodes to be differently formatted depending on the verbs
func (ns Nodes) Format(s fmt.State, c rune) {
delimiter := ","
delimiter := ", "
if s.Flag(' ') {
delimiter = " "
delimiter = " "
}
if s.Flag('+') {
delimiter = ",\n"
delimiter = ", \n"
}
switch c {
case 'd':
s.Write([]byte("["))
for i, n := range ns {
fmt.Fprintf(s, "%x", n.Hashcode())
if i < len(ns)-1 {
fmt.Fprintf(s, "%s ", delimiter)
fmt.Fprintf(s, "%s", delimiter)
}
}
s.Write([]byte("]"))
Expand All @@ -78,27 +77,26 @@ func (ns Nodes) Format(s fmt.State, c rune) {
fmt.Fprintf(s, "%s", n.Name())
}
if i < len(ns)-1 {
fmt.Fprintf(s, "%s ", delimiter)
fmt.Fprintf(s, "%s", delimiter)
}
}
s.Write([]byte("]"))
case 'Y':
if s.Flag('#') {
s.Write([]byte("["))
for i, n := range ns {
fmt.Fprintf(s, "%v", n.t)
if i < len(ns)-1 {
fmt.Fprintf(s, "%s ", delimiter)
}
s.Write([]byte("["))
for i, n := range ns {
fmt.Fprintf(s, "%v", n.t)
if i < len(ns)-1 {
fmt.Fprintf(s, "%s", delimiter)
}
s.Write([]byte("]"))
}
s.Write([]byte("]"))

case 'P':
s.Write([]byte("["))
for i, n := range ns {
fmt.Fprintf(s, "%p", n)
if i < len(ns)-1 {
fmt.Fprintf(s, "%s ", delimiter)
fmt.Fprintf(s, "%s", delimiter)
}
}
s.Write([]byte("]"))
Expand Down Expand Up @@ -143,6 +141,19 @@ func (ns Nodes) AllSameGraph() bool {
return true
}

func (ns Nodes) Equals(other Nodes) bool {
if len(ns) != len(other) {
return false
}

for _, n := range ns {
if !other.Contains(n) {
return false
}
}
return true
}

func (ns Nodes) mapSet() NodeSet { return NewNodeSet(ns...) }

func (ns Nodes) index(n *Node) int {
Expand Down Expand Up @@ -180,19 +191,14 @@ func (ns Nodes) remove(what *Node) Nodes {
return ns
}

/* TYPES */

type Types []Type

func (ts Types) String() string {
var buf bytes.Buffer
buf.WriteString("[")
for i, t := range ts {
buf.WriteString(t.String())
if i < len(ts)-1 {
buf.WriteString(", ")
func (ns Nodes) dimSizers() []DimSizer {
retVal := make([]DimSizer, len(ns))
for i, n := range ns {
if s, ok := n.op.(sizeOp); ok {
retVal[i] = s
} else {
retVal[i] = n.shape
}
}
buf.WriteString("]")
return buf.String()
return retVal
}
75 changes: 63 additions & 12 deletions collections_test.go
Original file line number Diff line number Diff line change
@@ -1,24 +1,25 @@
package gorgonia

import (
"fmt"
"testing"

"github.com/stretchr/testify/assert"
)

func TestSet(t *testing.T) {
func TestNodes(t *testing.T) {
assert := assert.New(t)
g := NewGraph()
n0 := newNodeFromPool(withGraph(g), WithName("n0"))
n1 := newNodeFromPool(withGraph(g), WithName("n1"))
n2 := newNodeFromPool(withGraph(g), WithName("n2"))
n3 := newNodeFromPool(withGraph(g), WithName("n3"))
n0 := newNode(withGraph(g), WithName("n0"))
n1 := newNode(withGraph(g), WithName("n1"))
n2 := newNode(withGraph(g), WithName("n2"))
n3 := newNode(withGraph(g), WithName("n3"))

// calculate hashcode first
n0.Hashcode()
n1.Hashcode()
n2.Hashcode()
n3.Hashcode()
n0h := n0.Hashcode()
n1h := n1.Hashcode()
n2h := n2.Hashcode()
n3h := n3.Hashcode()
t.Logf("%x, %x, %x, %x", n0.hash, n1.hash, n2.hash, n3.hash)

set := Nodes{n0, n1, n2, n3, n0, n0}
Expand All @@ -30,8 +31,17 @@ func TestSet(t *testing.T) {
}
assert.Equal(len(correct), len(set))

t.Log("Testing intersection")
t.Log("Test add")
set = Nodes{}
set = set.Add(n0)
set = set.Add(n2)
set = set.Add(n0)
set = set.Add(n3)
set = set.Add(n1)
correct = Nodes{n0, n2, n3, n1}
assert.Equal(correct, set)

t.Log("Testing intersection")
set = Nodes{n0, n2, n1, n3} // out of order, on purpose
other := Nodes{n0, n1}
inter := set.Intersect(other)
Expand All @@ -43,8 +53,8 @@ func TestSet(t *testing.T) {
assert.Equal(len(correct), len(inter))

t.Log("Testing difference")
n4 := newNodeFromPool(withGraph(g))
n5 := newNodeFromPool(withGraph(g))
n4 := newNode(withGraph(g))
n5 := newNode(withGraph(g))
set = Nodes{n3, n0, n1, n2}
other = Nodes{n0, n3, n4, n5}

Expand All @@ -54,4 +64,45 @@ func TestSet(t *testing.T) {
assert.Contains(diff, n)
}
assert.Equal(len(correct), len(diff))

t.Log("Testing replace")
set = Nodes{n0, n2, n1, n2, n1} // not yet a set
set = set.replace(n2, n3)
correct = Nodes{n0, n3, n1, n3, n1}
assert.Equal(correct, set)

t.Log("Formatting")
formats := []string{"% v", "%+v", "%d", "%v", "%#v", "%Y", "%P"}
correctFormats := []string{
"[n0 n1 n2 n3]",
`[n0,
n1,
n2,
n3]`,
fmt.Sprintf("[%x, %x, %x, %x]", n0h, n1h, n2h, n3h),
"[n0, n1, n2, n3]",
"[n0 :: <nil>, n1 :: <nil>, n2 :: <nil>, n3 :: <nil>]",
"[<nil>, <nil>, <nil>, <nil>]",
fmt.Sprintf("[%p, %p, %p, %p]", n0, n1, n2, n3),
}

set = Nodes{n0, n1, n2, n3}
for i, f := range formats {
s := fmt.Sprintf(f, set)
if s != correctFormats[i] {
t.Errorf("Format %q. Expected %q. Got %q", f, correctFormats[i], s)
}
}

// corner cases
set = Nodes{}
if set.AllSameGraph() {
t.Error("Empty list of nodes cannot be of the same graph!")
}

nAbnormal := newNode(withGraph(NewGraph()))
set = Nodes{n0, n1, nAbnormal, n2}
if set.AllSameGraph() {
t.Error("One node is in a different graph! This should have returned false")
}
}
10 changes: 5 additions & 5 deletions compile.go
Original file line number Diff line number Diff line change
Expand Up @@ -204,7 +204,7 @@ func codegen(inputs, sorted Nodes, df *dataflow) (prog *program, locationMap map
// if it's not mutable, there is no chance it will be overwritten
if node.isMutable() {
// if the instruction calls an extern (cBLAS or cuBlas), then we should preallocate the vector
if node.op.callsExtern() {
if node.op.CallsExtern() {
compileLogf("calls extern")
instr := newAlloc(node, nInterv.result)
instructions = append(instructions, instr)
Expand All @@ -223,7 +223,7 @@ func codegen(inputs, sorted Nodes, df *dataflow) (prog *program, locationMap map
if instrID, ok := lastWrites[read.id]; ok {
viaticum := instructions[instrID] // ;) - it IS on the way
if instr, ok := viaticum.(execOp); ok {
if instr.op.callsExtern() && !node.op.callsExtern() {
if instr.op.CallsExtern() && !node.op.CallsExtern() {
// the && bit is to make sure that if we have sequential cBLAS/cuBLAS calls,
// we just add it to the batch.
// sequential in this can mean several instructions apart. For example:
Expand All @@ -249,7 +249,7 @@ func codegen(inputs, sorted Nodes, df *dataflow) (prog *program, locationMap map

// check the overwrites - if the overwrite and the resulting register is the same,
// then use unsafe options when available
overwrites := node.op.overwriteInput()
overwrites := node.op.OverwritesInput()
if overwrites >= 0 {
compileLogf("Overwrites %d", overwrites)
overwritten := reads[overwrites]
Expand Down Expand Up @@ -306,12 +306,12 @@ func compileState(w io.Writer, g *ExprGraph, df *dataflow) {

if n.op != nil {
row[1] = fmt.Sprintf("%s", n.op)
overwrites := n.op.overwriteInput()
overwrites := n.op.OverwritesInput()
if overwrites >= 0 {
row[7] = fmt.Sprintf("%d", n.children[overwrites].ID())
}

if n.op.callsExtern() {
if n.op.CallsExtern() {
row[8] = "yes"
}
}
Expand Down
6 changes: 4 additions & 2 deletions const.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,8 @@ const (
sortFail = "Failed to sort"
cloneFail = "Failed to carry clone()"
clone0Fail = "Failed to carry clone0()"
nyiFail = "%s not yet implemented for %T"
nyiTypeFail = "%s not yet implemented for %T"
nyiFail = "%s not yet implemented for %v"
dtypeOfFail = "Failed to carry dtypeOf()"
mulFail = "Failed to carry Mul()"
applyOpFail = "Failed to carry applyOp()"
Expand Down Expand Up @@ -71,7 +72,8 @@ const (
reshapeFail = "Failed to reshape Tensor into %v. DataSize was: %d"
sliceFail = "Failed to slice Tensor with %v"
execFail = "Failed to execute %v"
autodiffFail = "Failed to differentiate %T"
autodiffFail = "Failed to differentiate %v"
undefinedOnShape = "%v undefined on shape %v"
)

var empty struct{}
Expand Down
8 changes: 4 additions & 4 deletions debug.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,10 @@ const DEBUG = true
// I use these instead of say, Delve because most of the time, the larger picture has to be known. Delve tends to give small picture views
var (
compileDev = false
shapeInferenceDev = false
typeSystemDev = false
symdiffDev = false
autodiffDev = false
shapeInferenceDev = true
typeSystemDev = true
symdiffDev = true
autodiffDev = true
machineDev = true
stabilizationDev = false
solverDev = false
Expand Down
19 changes: 17 additions & 2 deletions differentiation.go
Original file line number Diff line number Diff line change
Expand Up @@ -72,7 +72,13 @@ func backwardDiffAnalysis(wrt, sortedNodes Nodes) (retVal NodeSet, err error) {
for i := len(sortedNodes) - 1; i >= 0; i-- {
n := sortedNodes[i]
symdiffLogf("working on %v. Has %d children", n, len(n.children))
diffs := n.diffWRT()

var op SDOp
var ok bool
var diffs []bool
if op, ok = n.op.(SDOp); ok {
diffs = op.DiffWRT(len(n.children))
}

symdiffLogf("differentiable WRT: %v", diffs)
enterLoggingContext()
Expand Down Expand Up @@ -264,10 +270,19 @@ func Backpropagate(outputs, gradOutputs, wrt Nodes) (retVal Nodes, err error) {
if !node.isInput() {
symdiffLogf("differentiating %x (%v)", node.ID(), node.op)
enterLoggingContext()

var op SDOp
var childrenGrads Nodes
var ok bool

if op, ok = node.op.(SDOp); !ok {
// error
}

symdiffLogf("op: %v || optype: %v || node: %v || Children: %#Y || Grad: %v", node.op, node.op.Type(), node.t, node.children, gradNode)
if childrenGrads, err = node.op.SymDiff(node.children, node, gradNode); err != nil {
if childrenGrads, err = op.SymDiff(node.children, node, gradNode); err != nil {
return nil, errors.Wrapf(err, "SymDiff for %v. OpType: %v. Node Type: %v. Children: %#v. Grad: %v", node.op, node.op.Type(), node.t, node.children, gradNode)

}
symdiffLogf("Derived(%d): %d", len(childrenGrads), childrenGrads)
leaveLoggingContext()
Expand Down
Loading

0 comments on commit cf80bd6

Please sign in to comment.