Skip to content

Commit

Permalink
Merge pull request #42 from MichaelMure/walk-functional-opts
Browse files Browse the repository at this point in the history
rework the graph walking functions with functional options
  • Loading branch information
Stebalien committed Jul 22, 2019
2 parents 625228c + 27ea6f8 commit 89dd2ad
Show file tree
Hide file tree
Showing 2 changed files with 154 additions and 47 deletions.
190 changes: 147 additions & 43 deletions merkledag.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func FetchGraph(ctx context.Context, root cid.Cid, serv ipld.DAGService) error {
}

// FetchGraphWithDepthLimit fetches all nodes that are children to the given
// node down to the given depth. maxDetph=0 means "only fetch root",
// node down to the given depth. maxDepth=0 means "only fetch root",
// maxDepth=1 means "fetch root and its direct children" and so on...
// maxDepth=-1 means unlimited.
func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, serv ipld.DAGService) error {
Expand Down Expand Up @@ -195,9 +195,10 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s
return false
}

// If we have a ProgressTracker, we wrap the visit function to handle it
v, _ := ctx.Value(progressContextKey).(*ProgressTracker)
if v == nil {
return WalkParallelDepth(ctx, GetLinksDirect(ng), root, 0, visit)
return WalkDepth(ctx, GetLinksDirect(ng), root, visit, Concurrent(), WithRoot())
}

visitProgress := func(c cid.Cid, depth int) bool {
Expand All @@ -207,7 +208,7 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s
}
return false
}
return WalkParallelDepth(ctx, GetLinksDirect(ng), root, 0, visitProgress)
return WalkDepth(ctx, GetLinksDirect(ng), root, visitProgress, Concurrent(), WithRoot())
}

// GetMany gets many nodes from the DAG at once.
Expand Down Expand Up @@ -281,30 +282,143 @@ func GetLinksWithDAG(ng ipld.NodeGetter) GetLinks {
}
}

// defaultConcurrentFetch is the default maximum number of concurrent fetches
// that 'fetchNodes' will start at a time
const defaultConcurrentFetch = 32

// walkOptions represent the parameters of a graph walking algorithm
type walkOptions struct {
WithRoot bool
Concurrency int
ErrorHandler func(c cid.Cid, err error) error
}

// WalkOption is a setter for walkOptions
type WalkOption func(*walkOptions)

func (wo *walkOptions) addHandler(handler func(c cid.Cid, err error) error) {
if wo.ErrorHandler != nil {
wo.ErrorHandler = func(c cid.Cid, err error) error {
return handler(c, wo.ErrorHandler(c, err))
}
} else {
wo.ErrorHandler = handler
}
}

// WithRoot is a WalkOption indicating that the root node should be visited
func WithRoot() WalkOption {
return func(walkOptions *walkOptions) {
walkOptions.WithRoot = true
}
}

// Concurrent is a WalkOption indicating that node fetching should be done in
// parallel, with the default concurrency factor.
// NOTE: When using that option, the walk order is *not* guarantee.
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func Concurrent() WalkOption {
return func(walkOptions *walkOptions) {
walkOptions.Concurrency = defaultConcurrentFetch
}
}

// Concurrency is a WalkOption indicating that node fetching should be done in
// parallel, with a specific concurrency factor.
// NOTE: When using that option, the walk order is *not* guarantee.
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func Concurrency(worker int) WalkOption {
return func(walkOptions *walkOptions) {
walkOptions.Concurrency = worker
}
}

// IgnoreErrors is a WalkOption indicating that the walk should attempt to
// continue even when an error occur.
func IgnoreErrors() WalkOption {
return func(walkOptions *walkOptions) {
walkOptions.addHandler(func(c cid.Cid, err error) error {
return nil
})
}
}

// IgnoreMissing is a WalkOption indicating that the walk should continue when
// a node is missing.
func IgnoreMissing() WalkOption {
return func(walkOptions *walkOptions) {
walkOptions.addHandler(func(c cid.Cid, err error) error {
if err == ipld.ErrNotFound {
return nil
}
return err
})
}
}

// OnMissing is a WalkOption adding a callback that will be triggered on a missing
// node.
func OnMissing(callback func(c cid.Cid)) WalkOption {
return func(walkOptions *walkOptions) {
walkOptions.addHandler(func(c cid.Cid, err error) error {
if err == ipld.ErrNotFound {
callback(c)
}
return err
})
}
}

// OnError is a WalkOption adding a custom error handler.
// If this handler return a nil error, the walk will continue.
func OnError(handler func(c cid.Cid, err error) error) WalkOption {
return func(walkOptions *walkOptions) {
walkOptions.addHandler(handler)
}
}

// WalkGraph will walk the dag in order (depth first) starting at the given root.
func Walk(ctx context.Context, getLinks GetLinks, root cid.Cid, visit func(cid.Cid) bool) error {
func Walk(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid) bool, options ...WalkOption) error {
visitDepth := func(c cid.Cid, depth int) bool {
return visit(c)
}

return WalkDepth(ctx, getLinks, root, 0, visitDepth)
return WalkDepth(ctx, getLinks, c, visitDepth, options...)
}

// WalkDepth walks the dag starting at the given root and passes the current
// depth to a given visit function. The visit function can be used to limit DAG
// exploration.
func WalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, depth int, visit func(cid.Cid, int) bool) error {
if !visit(root, depth) {
return nil
func WalkDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid, int) bool, options ...WalkOption) error {
opts := &walkOptions{}
for _, opt := range options {
opt(opts)
}

if opts.Concurrency > 1 {
return parallelWalkDepth(ctx, getLinks, c, visit, opts)
} else {
return sequentialWalkDepth(ctx, getLinks, c, 0, visit, opts)
}
}

func sequentialWalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, depth int, visit func(cid.Cid, int) bool, options *walkOptions) error {
if depth != 0 || options.WithRoot {
if !visit(root, depth) {
return nil
}
}

links, err := getLinks(ctx, root)
if err != nil && options.ErrorHandler != nil {
err = options.ErrorHandler(root, err)
}
if err != nil {
return err
}

for _, lnk := range links {
if err := WalkDepth(ctx, getLinks, lnk.Cid, depth+1, visit); err != nil {
if err := sequentialWalkDepth(ctx, getLinks, lnk.Cid, depth+1, visit, options); err != nil {
return err
}
}
Expand Down Expand Up @@ -337,27 +451,7 @@ func (p *ProgressTracker) Value() int {
return p.Total
}

// FetchGraphConcurrency is total number of concurrent fetches that
// 'fetchNodes' will start at a time
var FetchGraphConcurrency = 32

// WalkParallel is equivalent to Walk *except* that it explores multiple paths
// in parallel.
//
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func WalkParallel(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid) bool) error {
visitDepth := func(c cid.Cid, depth int) bool {
return visit(c)
}

return WalkParallelDepth(ctx, getLinks, c, 0, visitDepth)
}

// WalkParallelDepth is equivalent to WalkDepth *except* that it fetches
// children in parallel.
//
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startDepth int, visit func(cid.Cid, int) bool) error {
func parallelWalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, visit func(cid.Cid, int) bool, options *walkOptions) error {
type cidDepth struct {
cid cid.Cid
depth int
Expand All @@ -372,27 +466,37 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
out := make(chan *linksDepth)
done := make(chan struct{})

var setlk sync.Mutex
var visitlk sync.Mutex
var wg sync.WaitGroup

errChan := make(chan error)
fetchersCtx, cancel := context.WithCancel(ctx)
defer wg.Wait()
defer cancel()
for i := 0; i < FetchGraphConcurrency; i++ {
for i := 0; i < options.Concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for cdepth := range feed {
ci := cdepth.cid
depth := cdepth.depth

setlk.Lock()
shouldVisit := visit(ci, depth)
setlk.Unlock()
var shouldVisit bool

// bypass the root if needed
if depth != 0 || options.WithRoot {
visitlk.Lock()
shouldVisit = visit(ci, depth)
visitlk.Unlock()
} else {
shouldVisit = true
}

if shouldVisit {
links, err := getLinks(ctx, ci)
if err != nil && options.ErrorHandler != nil {
err = options.ErrorHandler(root, err)
}
if err != nil {
select {
case errChan <- err:
Expand Down Expand Up @@ -422,20 +526,21 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
defer close(feed)

send := feed
var todobuffer []*cidDepth
var todoQueue []*cidDepth
var inProgress int

next := &cidDepth{
cid: c,
depth: startDepth,
cid: root,
depth: 0,
}

for {
select {
case send <- next:
inProgress++
if len(todobuffer) > 0 {
next = todobuffer[0]
todobuffer = todobuffer[1:]
if len(todoQueue) > 0 {
next = todoQueue[0]
todoQueue = todoQueue[1:]
} else {
next = nil
send = nil
Expand All @@ -456,7 +561,7 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
next = cd
send = feed
} else {
todobuffer = append(todobuffer, cd)
todoQueue = append(todoQueue, cd)
}
}
case err := <-errChan:
Expand All @@ -466,7 +571,6 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
return ctx.Err()
}
}

}

var _ ipld.LinkGetter = &dagService{}
Expand Down
11 changes: 7 additions & 4 deletions merkledag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,11 @@ func makeTestDAG(t *testing.T, read io.Reader, ds ipld.DAGService) ipld.Node {
// Add a root referencing all created nodes
root := NodeWithData(nil)
for _, n := range nodes {
root.AddNodeLink(n.Cid().String(), n)
err := ds.Add(ctx, n)
err := root.AddNodeLink(n.Cid().String(), n)
if err != nil {
t.Fatal(err)
}
err = ds.Add(ctx, n)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -383,7 +386,7 @@ func TestFetchGraphWithDepthLimit(t *testing.T) {

}

err = WalkDepth(context.Background(), offlineDS.GetLinks, root.Cid(), 0, visitF)
err = WalkDepth(context.Background(), offlineDS.GetLinks, root.Cid(), visitF, WithRoot())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -736,7 +739,7 @@ func TestEnumerateAsyncFailsNotFound(t *testing.T) {
}

cset := cid.NewSet()
err = WalkParallel(ctx, GetLinksDirect(ds), parent.Cid(), cset.Visit)
err = Walk(ctx, GetLinksDirect(ds), parent.Cid(), cset.Visit)
if err == nil {
t.Fatal("this should have failed")
}
Expand Down

0 comments on commit 89dd2ad

Please sign in to comment.