Skip to content

Commit

Permalink
Merge pull request #335 from gorgonia/vm-refacto
Browse files Browse the repository at this point in the history
Concurrency bug fixed
  • Loading branch information
owulveryck committed Oct 10, 2019
2 parents 1694575 + ac2d4cf commit d650d30
Show file tree
Hide file tree
Showing 2 changed files with 123 additions and 101 deletions.
74 changes: 74 additions & 0 deletions x/vm/chandb.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
package xvm

import "gorgonia.org/gorgonia"

type chanDB struct {
// map[tail][head]
dico map[int64]map[int64]chan gorgonia.Value
// map[head][tail]
reverseDico map[int64]map[int64]chan gorgonia.Value
inputNodeID int64
outputNodeID int64
}

func (c *chanDB) closeAll() {
for i := range c.dico {
for j := range c.dico[i] {
close(c.dico[i][j])
}
}
}

// upsert the channel to the DB, if id already exists it is overwritten
func (c *chanDB) upsert(channel chan gorgonia.Value, tail, head int64) {
if _, ok := c.dico[tail]; !ok {
c.dico[tail] = make(map[int64]chan gorgonia.Value, 0)
}
if _, ok := c.reverseDico[head]; !ok {
c.reverseDico[head] = make(map[int64]chan gorgonia.Value, 0)
}
c.dico[tail][head] = channel
c.reverseDico[head][tail] = channel
}

func newChanDB() *chanDB {
return &chanDB{
dico: make(map[int64]map[int64]chan gorgonia.Value, 0),
reverseDico: make(map[int64]map[int64]chan gorgonia.Value, 0),
inputNodeID: -1,
outputNodeID: -2,
}
}

func (c *chanDB) getAllFromTail(tail int64) []<-chan gorgonia.Value {
edges, ok := c.dico[tail]
if !ok {
return nil
}
output := make([]<-chan gorgonia.Value, 0, len(edges))
for _, edge := range edges {
output = append(output, edge)
}
return output
}

func (c *chanDB) getAllFromHead(head int64) []chan<- gorgonia.Value {
edges, ok := c.reverseDico[head]
if !ok {
return nil
}
output := make([]chan<- gorgonia.Value, 0, len(edges))
for _, edge := range edges {
output = append(output, edge)
}
return output
}

func (c *chanDB) getChan(tail, head int64) (chan gorgonia.Value, bool) {
v, ok := c.dico[tail][head]
return v, ok
}

func (c *chanDB) len() int {
return len(c.dico)
}
150 changes: 49 additions & 101 deletions x/vm/vm_gomachine.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,11 @@ package xvm

import (
"log"
"sync"

"gorgonia.org/gorgonia"
)

const (
inputNode int64 = -1
outputNode int64 = -2
)

// GoMachine is a computation VM for Gorgonia.
// Every edge of the graph is associated with a channel of Value.
// The channels are identified by two IDs, tail and head, which are the IDs of the starting node and the ending node.
Expand All @@ -27,99 +23,12 @@ type GoMachine struct {
db *chanDB
}

type chanDB struct {
// map[tail][head]
dico map[int64]map[int64]chan gorgonia.Value
// map[head][tail]
reverseDico map[int64]map[int64]chan gorgonia.Value
}

func (c *chanDB) closeAll() {
for i := range c.dico {
for j := range c.dico[i] {
close(c.dico[i][j])
}
}
}

// upsert the channel to the DB, if id already exists it is overwritten
func (c *chanDB) upsert(channel chan gorgonia.Value, tail, head int64) {
if _, ok := c.dico[tail]; !ok {
c.dico[tail] = make(map[int64]chan gorgonia.Value, 0)
}
if _, ok := c.reverseDico[head]; !ok {
c.reverseDico[head] = make(map[int64]chan gorgonia.Value, 0)
}
c.dico[tail][head] = channel
c.reverseDico[head][tail] = channel
}

func newChanDB() *chanDB {
return &chanDB{
dico: make(map[int64]map[int64]chan gorgonia.Value, 0),
reverseDico: make(map[int64]map[int64]chan gorgonia.Value, 0),
}
}

func (c *chanDB) getAllFromTail(tail int64) []<-chan gorgonia.Value {
edges, ok := c.dico[tail]
if !ok {
return nil
}
output := make([]<-chan gorgonia.Value, 0, len(edges))
for _, edge := range edges {
output = append(output, edge)
}
return output
}

func (c *chanDB) getAllFromHead(head int64) []chan<- gorgonia.Value {
edges, ok := c.reverseDico[head]
if !ok {
return nil
}
output := make([]chan<- gorgonia.Value, 0, len(edges))
for _, edge := range edges {
output = append(output, edge)
}
return output
}

func (c *chanDB) getChan(tail, head int64) (chan gorgonia.Value, bool) {
v, ok := c.dico[tail][head]
return v, ok
}

func (c *chanDB) len() int {
return len(c.dico)
}

// RunAll triggers all the goroutines and wait for the all the output channel to be filled with a value.
//
// Caution: there is no safety mechanism, and this method would never return (deadlock) in some circumstances.
func (g *GoMachine) RunAll() error {
g.populateChanDB()
nodesIt := g.g.Nodes()
if g.db.len() == 0 {
edgesIt := g.g.Edges()
for edgesIt.Next() {
currentEdge := edgesIt.Edge()
head := currentEdge.From().ID()
tail := currentEdge.To().ID()
g.db.upsert(make(chan gorgonia.Value, 0), tail, head)
}
for nodesIt.Next() {
currentNode := nodesIt.Node().(*gorgonia.Node)
if g.g.From(currentNode.ID()).Len() == 0 {
// Node is an input
g.db.upsert(make(chan gorgonia.Value, 0), currentNode.ID(), inputNode)
}
if g.g.To(currentNode.ID()).Len() == 0 {
// Node is an output
g.db.upsert(make(chan gorgonia.Value, 0), outputNode, currentNode.ID())
}
}
nodesIt.Reset()
}
for nodesIt.Next() {
currentNode := nodesIt.Node().(*gorgonia.Node)
// run all the nodes carrying an Op inside a go-routine
Expand All @@ -136,16 +45,16 @@ func (g *GoMachine) RunAll() error {
log.Fatal("chan edge not found")
}
}
go g.opWorker(currentNode, inputC, outputC)
go opWorker(currentNode, inputC, outputC)
// Send the input to the self nodes...
case currentNode.Op() == nil && currentNode.Value() != nil:
go g.valueFeeder(currentNode, outputC)
go valueFeeder(currentNode, outputC)
default:
log.Fatal("Yerk?")
}
}
// wait for all values to be computed
for _, outputC := range g.db.getAllFromTail(outputNode) {
for _, outputC := range g.db.getAllFromTail(g.db.outputNodeID) {
<-outputC
}
return nil
Expand All @@ -172,23 +81,62 @@ func NewGoMachine(g *gorgonia.ExprGraph) *GoMachine {
}
}

func (g *GoMachine) opWorker(n *gorgonia.Node, inputC []<-chan gorgonia.Value, outputC []chan<- gorgonia.Value) {
func opWorker(n *gorgonia.Node, inputC []<-chan gorgonia.Value, outputC []chan<- gorgonia.Value) {
vals := make([]gorgonia.Value, len(inputC))
var wg sync.WaitGroup
wg.Add(len(inputC))
for i := range inputC {
vals[i] = <-inputC[i]
go func(i int, vals []gorgonia.Value, inputC []<-chan gorgonia.Value) {
vals[i] = <-inputC[i]
wg.Done()
}(i, vals, inputC)
}
wg.Wait()
output, err := n.Op().Do(vals...)
if err != nil {
log.Fatal(err)
}
gorgonia.UnsafeLet(n, output)
wg.Add(len(outputC))
for i := range outputC {
outputC[i] <- output
go func(i int, outputC []chan<- gorgonia.Value) {
outputC[i] <- output
wg.Done()
}(i, outputC)
}
wg.Wait()
}

func (g *GoMachine) valueFeeder(n *gorgonia.Node, feedC []chan<- gorgonia.Value) {
func valueFeeder(n *gorgonia.Node, feedC []chan<- gorgonia.Value) {
var wg sync.WaitGroup
wg.Add(len(feedC))
for i := range feedC {
feedC[i] <- n.Value()
go func(i int, feedC []chan<- gorgonia.Value) {
feedC[i] <- n.Value()
}(i, feedC)
}
wg.Wait()
}

func (g *GoMachine) populateChanDB() error {
edgesIt := g.g.Edges()
for edgesIt.Next() {
currentEdge := edgesIt.Edge()
head := currentEdge.From().ID()
tail := currentEdge.To().ID()
g.db.upsert(make(chan gorgonia.Value, 0), tail, head)
}
nodesIt := g.g.Nodes()
for nodesIt.Next() {
currentNode := nodesIt.Node().(*gorgonia.Node)
if g.g.From(currentNode.ID()).Len() == 0 {
// Node is an input
g.db.upsert(make(chan gorgonia.Value, 0), currentNode.ID(), g.db.inputNodeID)
}
if g.g.To(currentNode.ID()).Len() == 0 {
// Node is an output
g.db.upsert(make(chan gorgonia.Value, 0), g.db.outputNodeID, currentNode.ID())
}
}
return nil
}

0 comments on commit d650d30

Please sign in to comment.