Skip to content
Permalink
Browse files

gorgonia: update for changes in gonum/graph iterator API

  • Loading branch information
kortschak committed Sep 28, 2018
1 parent 8cd0746 commit 99d21e568d9913e34737aea3957b41d273514adf
Showing with 24 additions and 18 deletions.
  1. +1 −1 compile.go
  2. +1 −1 const.go
  3. +7 −4 differentiation.go
  4. +7 −6 graph.go
  5. +4 −4 graph_test.go
  6. +2 −1 node.go
  7. +2 −1 regalloc.go
@@ -18,7 +18,7 @@ func Compile(g *ExprGraph) (prog *program, locMap map[*Node]register, err error)
enterLogScope()
defer leaveLogScope()

if len(g.Nodes()) == 0 {
if g.Nodes().Len() == 0 {
err = errors.Errorf("Cannot compile an empty graph")
return
}
@@ -34,7 +34,7 @@ const (

// error messages
sortFail = "Failed to sort"
cloneFail = "Failed to carry clone()"
cloneFail = "Failed to carry clone(%v)"
clone0Fail = "Failed to carry clone0()"
nyiTypeFail = "%s not yet implemented for %T"
nyiFail = "%s not yet implemented for %v"
@@ -1,6 +1,9 @@
package gorgonia

import "github.com/pkg/errors"
import (
"github.com/pkg/errors"
"gonum.org/v1/gonum/graph"
)

/*
This file holds code for symbolic differentiation.
@@ -100,7 +103,7 @@ func backwardDiffAnalysis(wrt, sortedNodes Nodes) (retVal NodeSet, err error) {
}
g := n.g
for _, child := range n.children {
parents := g.To(child.ID())
parents := graph.NodesOf(g.To(child.ID()))

symdiffLogf("parents of %v: %v", child, graphNodeToNode(parents))
if len(parents) == 1 && len(child.children) > 0 {
@@ -156,8 +159,8 @@ func Backpropagate(outputs, gradOutputs, wrt Nodes) (retVal Nodes, err error) {
for i := 0; i < len(g.AllNodes()); i++ {
n := g.AllNodes()[i]

fr := len(g.From(n.ID()))
to := len(g.To(n.ID()))
fr := g.From(n.ID()).Len()
to := g.To(n.ID()).Len()

if fr == 0 && to == 0 && !n.isConstant() && !n.isInput() {
g.RemoveNode(n)
@@ -6,6 +6,7 @@ import (

"github.com/awalterschulze/gographviz"
"gonum.org/v1/gonum/graph"
"gonum.org/v1/gonum/graph/iterator"
)

// ExprGraph is a data structure for a directed acyclic graph (of expressions). This structure is the main entry point
@@ -520,22 +521,22 @@ func (g *ExprGraph) Has(nodeid int64) bool {
}

// Nodes returns all the nodes in the graph.
func (g *ExprGraph) Nodes() []graph.Node {
func (g *ExprGraph) Nodes() graph.Nodes {
// nodes := make([]graph.Node, len(g.from))
ns := g.AllNodes()

nodes := nodeToGraphNode(ns)
return nodes
return iterator.NewOrderedNodes(nodes)
}

// AllNodes is like Nodes, but returns Nodes instead of []graph.Node.
// Nodes() has been reserved for the graph.Directed interface, so this one is named AllNodes instead
func (g *ExprGraph) AllNodes() Nodes { return g.all }

// From returns all nodes in g that can be reached directly from n.
func (g *ExprGraph) From(nodeid int64) []graph.Node {
func (g *ExprGraph) From(nodeid int64) graph.Nodes {
if n := g.node(nodeid); n != nil {
return nodeToGraphNode(n.children)
return iterator.NewOrderedNodes(nodeToGraphNode(n.children))
}
return nil
}
@@ -583,7 +584,7 @@ func (g *ExprGraph) HasEdgeFromTo(u, v int64) bool {
}

// To returns all nodes in g that can reach directly to n.
func (g *ExprGraph) To(nid int64) []graph.Node {
func (g *ExprGraph) To(nid int64) graph.Nodes {
n := g.node(nid)
if n == nil {
return nil
@@ -592,7 +593,7 @@ func (g *ExprGraph) To(nid int64) []graph.Node {
ns := g.to[n]
ns = ns.Set()
g.to[n] = ns
return nodeToGraphNode(ns)
return iterator.NewOrderedNodes(nodeToGraphNode(ns))
}

// subgraph is basically a subset of nodes. This is useful for compiling sub sections of the graph
@@ -45,18 +45,18 @@ func TestGraphBasics(t *testing.T) {
assert.Equal(correctTo, g.to[x])

correctTo = Nodes{xy}
assert.Equal(correctTo, graphNodeToNode(g.To(y.ID())))
assert.Equal(correctTo, graphNodeToNode(g.To(x.ID())))
assert.Equal(correctTo, graphNodeToNode(graph.NodesOf(g.To(y.ID()))))
assert.Equal(correctTo, graphNodeToNode(graph.NodesOf(g.To(x.ID()))))

assert.Equal(3, len(g.Nodes()))
assert.Equal(3, g.Nodes().Len())

// Now, time to deal with constants
xy1 := Must(Add(xy, onef64))
assert.Nil(onef64.g)
assert.Equal(g, xy1.g)

var containsOne bool
for _, node := range g.Nodes() {
for _, node := range graph.NodesOf(g.Nodes()) {
n := node.(*Node)
if n.Hashcode() == onef64.Hashcode() {
containsOne = true
@@ -11,6 +11,7 @@ import (
"github.com/awalterschulze/gographviz"
"github.com/chewxy/hm"
"github.com/pkg/errors"
"gonum.org/v1/gonum/graph"
"gorgonia.org/tensor"
)

@@ -537,7 +538,7 @@ func (n *Node) RestrictedToDot(up, down int) string {
origLen := len(upQ)
for i := 0; i < origLen; i++ {
qn := upQ[i]
toQN := graphNodeToNode(g.To(qn.ID()))
toQN := graphNodeToNode(graph.NodesOf(g.To(qn.ID())))
upQ = append(upQ, toQN...)
ns = append(ns, toQN...)
}
@@ -4,6 +4,7 @@ import (
"fmt"

"github.com/xtgo/set"
"gonum.org/v1/gonum/graph"
)

// this file holds all the code that relates to register allocation
@@ -184,7 +185,7 @@ func (ra *regalloc) allocMutableOp(node *Node, nInterv *interval) {
compileLogf("Read %v", reads)

var letStmts Nodes
for _, parent := range node.g.To(node.ID()) {
for _, parent := range graph.NodesOf(node.g.To(node.ID())) {
n := parent.(*Node)
compileLogf("Parent: %v | %T", n, n.op)
if n.isStmt {

0 comments on commit 99d21e5

Please sign in to comment.
You can’t perform that action at this time.