Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

rework the graph walking functions with functional options #42

Merged
merged 3 commits into from
Jul 22, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
190 changes: 147 additions & 43 deletions merkledag.go
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ func FetchGraph(ctx context.Context, root cid.Cid, serv ipld.DAGService) error {
}

// FetchGraphWithDepthLimit fetches all nodes that are children to the given
// node down to the given depth. maxDetph=0 means "only fetch root",
// node down to the given depth. maxDepth=0 means "only fetch root",
// maxDepth=1 means "fetch root and its direct children" and so on...
// maxDepth=-1 means unlimited.
func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, serv ipld.DAGService) error {
Expand Down Expand Up @@ -195,9 +195,10 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s
return false
}

// If we have a ProgressTracker, we wrap the visit function to handle it
v, _ := ctx.Value(progressContextKey).(*ProgressTracker)
if v == nil {
return WalkParallelDepth(ctx, GetLinksDirect(ng), root, 0, visit)
return WalkDepth(ctx, GetLinksDirect(ng), root, visit, Concurrent(), WithRoot())
}

visitProgress := func(c cid.Cid, depth int) bool {
Expand All @@ -207,7 +208,7 @@ func FetchGraphWithDepthLimit(ctx context.Context, root cid.Cid, depthLim int, s
}
return false
}
return WalkParallelDepth(ctx, GetLinksDirect(ng), root, 0, visitProgress)
return WalkDepth(ctx, GetLinksDirect(ng), root, visitProgress, Concurrent(), WithRoot())
}

// GetMany gets many nodes from the DAG at once.
Expand Down Expand Up @@ -281,30 +282,143 @@ func GetLinksWithDAG(ng ipld.NodeGetter) GetLinks {
}
}

// defaultConcurrentFetch is the concurrency factor installed by the Concurrent
// WalkOption, i.e. the default number of fetch workers a parallel walk starts.
const defaultConcurrentFetch = 32

// walkOptions holds the parameters of a graph walking algorithm. WalkDepth
// starts from the zero value and lets every WalkOption it receives mutate it.
type walkOptions struct {
// WithRoot reports whether the root node (depth 0) is passed to the visit
// function. When false the root's links are still fetched and walked.
WithRoot bool
// Concurrency is the number of fetch workers. WalkDepth only selects the
// parallel walk when this value is greater than 1.
Concurrency int
// ErrorHandler, when non-nil, is called with a cid and the error returned by
// getLinks. The error it returns replaces the original one; returning nil
// lets the walk continue.
ErrorHandler func(c cid.Cid, err error) error
}

// WalkOption is a functional option for the graph walking functions: a setter
// that mutates the walkOptions it is given. Options are applied in the order
// they are passed to WalkDepth.
type WalkOption func(*walkOptions)

// addHandler installs handler as the walk's error handler, chaining it after
// any handler that is already installed.
//
// When a previous handler exists, it runs first and its result is fed to
// handler, so handler may be called with a nil error if the previous handler
// already cleared it.
//
// The previous handler is captured in a local variable before wo.ErrorHandler
// is reassigned. Reading wo.ErrorHandler from inside the new closure would
// resolve to the closure itself at call time and recurse without bound.
func (wo *walkOptions) addHandler(handler func(c cid.Cid, err error) error) {
	previous := wo.ErrorHandler
	if previous == nil {
		wo.ErrorHandler = handler
		return
	}

	wo.ErrorHandler = func(c cid.Cid, err error) error {
		return handler(c, previous(c, err))
	}
}

// WithRoot returns a WalkOption requesting that the root node itself be
// passed to the visit function, in addition to its descendants.
func WithRoot() WalkOption {
	return func(opts *walkOptions) {
		opts.WithRoot = true
	}
}

// Concurrent returns a WalkOption asking for nodes to be fetched in parallel
// using the default concurrency factor (defaultConcurrentFetch).
// NOTE: With this option the walk order is *not* guaranteed.
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func Concurrent() WalkOption {
	return Concurrency(defaultConcurrentFetch)
}

// Concurrency returns a WalkOption asking for nodes to be fetched in parallel
// by the given number of workers. WalkDepth only switches to the parallel
// walk when worker is greater than 1.
// NOTE: With this option the walk order is *not* guaranteed.
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func Concurrency(worker int) WalkOption {
	return func(opts *walkOptions) {
		opts.Concurrency = worker
	}
}

// IgnoreErrors returns a WalkOption that registers an error handler which
// discards every error, so the walk attempts to continue when an error occurs.
func IgnoreErrors() WalkOption {
	discard := func(cid.Cid, error) error {
		return nil
	}
	return func(opts *walkOptions) {
		opts.addHandler(discard)
	}
}

// IgnoreMissing returns a WalkOption that registers an error handler which
// clears ipld.ErrNotFound, so the walk continues when a node is missing. Any
// other error is passed through unchanged.
func IgnoreMissing() WalkOption {
	skipMissing := func(_ cid.Cid, err error) error {
		if err != ipld.ErrNotFound {
			return err
		}
		return nil
	}
	return func(opts *walkOptions) {
		opts.addHandler(skipMissing)
	}
}

// OnMissing returns a WalkOption that registers an error handler invoking
// callback with the cid whenever the error is ipld.ErrNotFound. The error
// itself is returned unchanged, so this option alone does not let the walk
// continue past a missing node.
func OnMissing(callback func(c cid.Cid)) WalkOption {
	notify := func(c cid.Cid, err error) error {
		missing := err == ipld.ErrNotFound
		if missing {
			callback(c)
		}
		return err
	}
	return func(opts *walkOptions) {
		opts.addHandler(notify)
	}
}

// OnError returns a WalkOption that registers a custom error handler.
// If the handler returns a nil error, the walk will continue.
func OnError(handler func(c cid.Cid, err error) error) WalkOption {
	return func(opts *walkOptions) {
		opts.addHandler(handler)
	}
}

// Walk will walk the dag in order (depth first) starting at the given root.
func Walk(ctx context.Context, getLinks GetLinks, root cid.Cid, visit func(cid.Cid) bool) error {
func Walk(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid) bool, options ...WalkOption) error {
visitDepth := func(c cid.Cid, depth int) bool {
return visit(c)
}

return WalkDepth(ctx, getLinks, root, 0, visitDepth)
return WalkDepth(ctx, getLinks, c, visitDepth, options...)
}

// WalkDepth walks the dag starting at the given root and passes the current
// depth to a given visit function. The visit function can be used to limit DAG
// exploration.
func WalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, depth int, visit func(cid.Cid, int) bool) error {
if !visit(root, depth) {
return nil
func WalkDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid, int) bool, options ...WalkOption) error {
opts := &walkOptions{}
for _, opt := range options {
opt(opts)
}

if opts.Concurrency > 1 {
return parallelWalkDepth(ctx, getLinks, c, visit, opts)
} else {
return sequentialWalkDepth(ctx, getLinks, c, 0, visit, opts)
}
}

func sequentialWalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, depth int, visit func(cid.Cid, int) bool, options *walkOptions) error {
if depth != 0 || options.WithRoot {
if !visit(root, depth) {
return nil
}
}

links, err := getLinks(ctx, root)
if err != nil && options.ErrorHandler != nil {
err = options.ErrorHandler(root, err)
}
if err != nil {
return err
}

for _, lnk := range links {
if err := WalkDepth(ctx, getLinks, lnk.Cid, depth+1, visit); err != nil {
if err := sequentialWalkDepth(ctx, getLinks, lnk.Cid, depth+1, visit, options); err != nil {
return err
}
}
Expand Down Expand Up @@ -337,27 +451,7 @@ func (p *ProgressTracker) Value() int {
return p.Total
}

// FetchGraphConcurrency is total number of concurrent fetches that
// 'fetchNodes' will start at a time
var FetchGraphConcurrency = 32

// WalkParallel is equivalent to Walk *except* that it explores multiple paths
// in parallel.
//
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func WalkParallel(ctx context.Context, getLinks GetLinks, c cid.Cid, visit func(cid.Cid) bool) error {
visitDepth := func(c cid.Cid, depth int) bool {
return visit(c)
}

return WalkParallelDepth(ctx, getLinks, c, 0, visitDepth)
}

// WalkParallelDepth is equivalent to WalkDepth *except* that it fetches
// children in parallel.
//
// NOTE: It *does not* make multiple concurrent calls to the passed `visit` function.
func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startDepth int, visit func(cid.Cid, int) bool) error {
func parallelWalkDepth(ctx context.Context, getLinks GetLinks, root cid.Cid, visit func(cid.Cid, int) bool, options *walkOptions) error {
type cidDepth struct {
cid cid.Cid
depth int
Expand All @@ -372,27 +466,37 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
out := make(chan *linksDepth)
done := make(chan struct{})

var setlk sync.Mutex
var visitlk sync.Mutex
var wg sync.WaitGroup

errChan := make(chan error)
fetchersCtx, cancel := context.WithCancel(ctx)
defer wg.Wait()
defer cancel()
for i := 0; i < FetchGraphConcurrency; i++ {
for i := 0; i < options.Concurrency; i++ {
wg.Add(1)
go func() {
defer wg.Done()
for cdepth := range feed {
ci := cdepth.cid
depth := cdepth.depth

setlk.Lock()
shouldVisit := visit(ci, depth)
setlk.Unlock()
var shouldVisit bool

// bypass the root if needed
if depth != 0 || options.WithRoot {
visitlk.Lock()
shouldVisit = visit(ci, depth)
visitlk.Unlock()
} else {
shouldVisit = true
}

if shouldVisit {
links, err := getLinks(ctx, ci)
if err != nil && options.ErrorHandler != nil {
err = options.ErrorHandler(root, err)
}
if err != nil {
select {
case errChan <- err:
Expand Down Expand Up @@ -422,20 +526,21 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
defer close(feed)

send := feed
var todobuffer []*cidDepth
var todoQueue []*cidDepth
var inProgress int

next := &cidDepth{
cid: c,
depth: startDepth,
cid: root,
depth: 0,
}

for {
select {
case send <- next:
inProgress++
if len(todobuffer) > 0 {
next = todobuffer[0]
todobuffer = todobuffer[1:]
if len(todoQueue) > 0 {
next = todoQueue[0]
todoQueue = todoQueue[1:]
} else {
next = nil
send = nil
Expand All @@ -456,7 +561,7 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
next = cd
send = feed
} else {
todobuffer = append(todobuffer, cd)
todoQueue = append(todoQueue, cd)
}
}
case err := <-errChan:
Expand All @@ -466,7 +571,6 @@ func WalkParallelDepth(ctx context.Context, getLinks GetLinks, c cid.Cid, startD
return ctx.Err()
}
}

}

var _ ipld.LinkGetter = &dagService{}
Expand Down
11 changes: 7 additions & 4 deletions merkledag_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,8 +203,11 @@ func makeTestDAG(t *testing.T, read io.Reader, ds ipld.DAGService) ipld.Node {
// Add a root referencing all created nodes
root := NodeWithData(nil)
for _, n := range nodes {
root.AddNodeLink(n.Cid().String(), n)
err := ds.Add(ctx, n)
err := root.AddNodeLink(n.Cid().String(), n)
if err != nil {
t.Fatal(err)
}
err = ds.Add(ctx, n)
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -383,7 +386,7 @@ func TestFetchGraphWithDepthLimit(t *testing.T) {

}

err = WalkDepth(context.Background(), offlineDS.GetLinks, root.Cid(), 0, visitF)
err = WalkDepth(context.Background(), offlineDS.GetLinks, root.Cid(), visitF, WithRoot())
if err != nil {
t.Fatal(err)
}
Expand Down Expand Up @@ -736,7 +739,7 @@ func TestEnumerateAsyncFailsNotFound(t *testing.T) {
}

cset := cid.NewSet()
err = WalkParallel(ctx, GetLinksDirect(ds), parent.Cid(), cset.Visit)
err = Walk(ctx, GetLinksDirect(ds), parent.Cid(), cset.Visit)
if err == nil {
t.Fatal("this should have failed")
}
Expand Down