From 1a906a619306f47a0fa168dcdfc7208678faa54c Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Sun, 17 Mar 2024 22:14:48 -0700 Subject: [PATCH] Apply the collector simplifications to the Group too. Removing the goroutine also lets us get rid of the tick-tock dance with sync.Once values. Nice. Co-Authored-By: David Anderson --- taskgroup.go | 76 ++++++++++++++++++---------------------------------- 1 file changed, 26 insertions(+), 50 deletions(-) diff --git a/taskgroup.go b/taskgroup.go index e71d676..c283c34 100644 --- a/taskgroup.go +++ b/taskgroup.go @@ -10,24 +10,21 @@ import "sync" // tasks are collected and reported by the group. type Task func() error -// A Group manages a collection of cooperating goroutines. New tasks are added -// to the group with the Go method. Call the Wait method to wait for the tasks -// to complete. A zero value is ready for use, but must not be copied after its +// A Group manages a collection of cooperating goroutines. Add new tasks to +// the group with the Go method. Call the Wait method to wait for the tasks to +// complete. A zero value is ready for use, but must not be copied after its // first use. // // The group collects any errors returned by the tasks in the group. The first // non-nil error reported by any task (and not otherwise filtered) is returned // from the Wait method. type Group struct { - wg sync.WaitGroup // counter for active goroutines - err error // error returned from Wait - - setup sync.Once // set up and start the error collector - reset sync.Once // stop the error collector and set err - + wg sync.WaitGroup // counter for active goroutines onError func(error) error // called each time a task returns non-nil - errc chan<- error // errors generated by goroutines - edone chan struct{} // signals error completion + + μ sync.Mutex // guards the fields below + setup sync.Once // set up and start the error collector + err error // error returned from Wait } // New constructs a new empty group. If ef != nil, it is called for each error @@ -42,54 +39,27 @@ func New(ef ErrorFunc) *Group { return &Group{onError: ef} } func (g *Group) Go(task Task) *Group { g.wg.Add(1) g.init() - errc := g.errc go func() { defer g.wg.Done() if err := task(); err != nil { - errc <- err + g.handleError(err) } }() return g } -func (g *Group) init() { - // The first time a task is added to an otherwise clear group, set up the - // error collector goroutine. We don't do this in the constructor so that - // an unused group can be abandoned without orphaning a goroutine. - g.setup.Do(func() { - if g.onError == nil { - g.onError = func(e error) error { return e } - } - g.err = nil - g.edone = make(chan struct{}) - g.reset = sync.Once{} - - errc := make(chan error) - g.errc = errc - go func() { - defer close(g.edone) - for err := range errc { - e := g.onError(err) - if e != nil && g.err == nil { - g.err = e // capture the first error always - } - } - }() - }) +func (g *Group) handleError(err error) { + g.μ.Lock() + defer g.μ.Unlock() + if g.onError != nil { + err = g.onError(err) + } + if err != nil && g.err == nil { + g.err = err // capture the first error always + } } -func (g *Group) cleanup() { - g.reset.Do(func() { - g.wg.Wait() - if g.errc == nil { - return - } - close(g.errc) - <-g.edone - g.errc = nil - g.setup = sync.Once{} - }) -} +func (g *Group) init() { g.setup.Do(func() { g.err = nil }) } // Wait blocks until all the goroutines currently active in the group have // returned, and all reported errors have been delivered to the callback. @@ -100,7 +70,13 @@ func (g *Group) cleanup() { // sync.WaitGroup, new tasks can be added to the group only if there is at // least one task active that started before all active Wait calls. Once all // Wait calls have returned, the group is ready for reuse. -func (g *Group) Wait() error { g.cleanup(); return g.err } +func (g *Group) Wait() error { + g.wg.Wait() + g.μ.Lock() + defer g.μ.Unlock() + g.setup = sync.Once{} + return g.err +} // An ErrorFunc is called by a group each time a task reports an error. Its // return value replaces the reported error, so the ErrorFunc can filter or