Skip to content

Commit

Permalink
In repeated operations the VM spends a lot of time on integer equalit…
Browse files Browse the repository at this point in the history
…y in the ExprGraph.Node() method. This fixes it (#285)
  • Loading branch information
chewxy committed Jun 1, 2019
1 parent 0c67abd commit e78451c
Show file tree
Hide file tree
Showing 4 changed files with 26 additions and 8 deletions.
26 changes: 19 additions & 7 deletions graph.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ type ExprGraph struct {

all Nodes

byId map[int64]int
byHash map[uint32]*Node
evac map[uint32]Nodes
to map[*Node]Nodes
Expand All @@ -39,6 +40,7 @@ func WithGraphName(name string) graphconopt {
// NewGraph creates a new graph. Duh
func NewGraph(opts ...graphconopt) *ExprGraph {
g := &ExprGraph{
byId: make(map[int64]int),
byHash: make(map[uint32]*Node),
evac: make(map[uint32]Nodes),
to: make(map[*Node]Nodes),
Expand Down Expand Up @@ -88,6 +90,7 @@ func (g *ExprGraph) Clone() interface{} {
}
}

g2.byId = make(map[int64]int)
g2.byHash = make(map[uint32]*Node)
for k, v := range g.byHash {
g2.byHash[k] = mapping[v]
Expand Down Expand Up @@ -497,17 +500,25 @@ func (g *ExprGraph) removeAllEdgesFrom(n *Node) {
// Node returns the node in the graph with the given ID.
func (g *ExprGraph) Node(id int64) graph.Node {
// n := (*Node)(unsafe.Pointer(uintptr(id)))
for _, n := range g.all {
if n.id == id {
return n
}
}
return nil
// for _, n := range g.all {
// if n.id == id {
// return n
// }
// }
// return nil
return g.node(id)
}

func (g *ExprGraph) node(id int64) *Node {
for _, n := range g.all {
if idx, ok := g.byId[id]; ok {
if idx >= len(g.all) {
return nil
}
return g.all[idx]
}
for i, n := range g.all {
if n.id == id {
g.byId[id] = i
return n
}
}
Expand Down Expand Up @@ -662,6 +673,7 @@ func (g *ExprGraph) subgraph(ns Nodes, findMissing bool, opts ...Nodes) *ExprGra

retVal := &ExprGraph{
all: ns,
byId: make(map[int64]int),
byHash: g.byHash,
evac: g.evac,
to: g.to,
Expand Down
4 changes: 4 additions & 0 deletions nn.go
Original file line number Diff line number Diff line change
Expand Up @@ -278,6 +278,10 @@ func MaxPool2D(x *Node, kernel tensor.Shape, pad, stride []int) (*Node, error) {
return ApplyOp(op, x)
}

func MaxPool1D(x *Node, kernel, pad, stride int) (*Node, error) {
return MaxPool2D(x, tensor.Shape{1, kernel}, []int{0, pad}, []int{1, stride})
}

func BatchNorm(x, scale, bias *Node, momentum, epsilon float64) (retVal, γ, β *Node, op *BatchNormOp, err error) {
dt, err := dtypeOf(x.Type())
if err != nil {
Expand Down
2 changes: 2 additions & 0 deletions op_tensor.go
Original file line number Diff line number Diff line change
Expand Up @@ -489,6 +489,8 @@ type sliceOp struct {
d int // how many dimensions were the original tensor
}

func (op *sliceOp) IsSlice() tensor.Slice { return op.Slice }

func newSliceOp(s tensor.Slice, along, d int) *sliceOp {
return &sliceOp{
Slice: s,
Expand Down
2 changes: 1 addition & 1 deletion operations.go
Original file line number Diff line number Diff line change
Expand Up @@ -83,7 +83,7 @@ func BatchedMatMul(a, b *Node) (retVal *Node, err error) {
// OuterProd returns a Node representing the outer product of two vectors. This function will return an error if both input nodes are not vectors
func OuterProd(a, b *Node) (retVal *Node, err error) {
if !a.IsVector() || !b.IsVector() {
return nil, errors.New("Expected only vectors to be able to do OuterProd") //for now
return nil, errors.Errorf("Expected only vectors to be able to do OuterProd. %v is %v. %v is %v", a, a.Shape(), b, b.Shape()) //for now
}

// TODO: maybe align shapes?
Expand Down

0 comments on commit e78451c

Please sign in to comment.