diff --git a/collector.go b/collector.go index 6dc9b03..164053d 100644 --- a/collector.go +++ b/collector.go @@ -27,8 +27,8 @@ func NewCollector[T any](value func(T)) *Collector[T] { return &Collector[T]{han // Wait waits until the collector has finished processing. // // Deprecated: This method is now a noop; it is safe but unnecessary to call -// it. The state serviced by c is settled once all the goroutines writing to -// the collector have returned. It may be removed in a future version. +// it. Once all the tasks created from c have returned, any state accessed by +// the accumulator is settled. Wait may be removed in a future version. func (c *Collector[T]) Wait() {} // Task returns a Task wrapping a call to f. If f reports an error, that error diff --git a/taskgroup.go b/taskgroup.go index 8734f8f..236484b 100644 --- a/taskgroup.go +++ b/taskgroup.go @@ -4,30 +4,44 @@ // and respond to task errors. package taskgroup -import "sync" +import ( + "sync" + "sync/atomic" +) // A Task function is the basic unit of work in a Group. Errors reported by // tasks are collected and reported by the group. type Task func() error -// A Group manages a collection of cooperating goroutines. New tasks are added -// to the group with the Go method. Call the Wait method to wait for the tasks -// to complete. A zero value is ready for use, but must not be copied after its +// A Group manages a collection of cooperating goroutines. Add new tasks to +// the group with the Go method. Call the Wait method to wait for the tasks to +// complete. A zero value is ready for use, but must not be copied after its // first use. // // The group collects any errors returned by the tasks in the group. The first // non-nil error reported by any task (and not otherwise filtered) is returned // from the Wait method. type Group struct { - wg sync.WaitGroup // counter for active goroutines - err error // error returned from Wait + wg sync.WaitGroup // counter for active goroutines + onError ErrorFunc // called each time a task returns non-nil - setup sync.Once // set up and start the error collector - reset sync.Once // stop the error collector and set err + // active is nonzero when the group is "active", meaning there has been at + // least one call to Go since the group was created or the last Wait. + active atomic.Uint32 - onError func(error) error // called each time a task returns non-nil - errc chan<- error // errors generated by goroutines - edone chan struct{} // signals error completion + μ sync.Mutex // guards err + err error // error returned from Wait +} + +// activate resets the state of the group and marks it as active. This is +// triggered by adding a goroutine to an empty group. +func (g *Group) activate() { + g.μ.Lock() + defer g.μ.Unlock() + if g.active.Load() == 0 { // still inactive + g.err = nil + g.active.Store(1) + } } // New constructs a new empty group. If ef != nil, it is called for each error @@ -40,72 +54,58 @@ func New(ef ErrorFunc) *Group { return &Group{onError: ef} } // Go runs task in a new goroutine in g, and returns g to permit chaining. func (g *Group) Go(task Task) *Group { + if g.active.Load() == 0 { + g.activate() + } g.wg.Add(1) - g.init() - errc := g.errc go func() { defer g.wg.Done() if err := task(); err != nil { - errc <- err + g.handleError(err) } }() return g } -func (g *Group) init() { - // The first time a task is added to an otherwise clear group, set up the - // error collector goroutine. We don't do this in the constructor so that - // an unused group can be abandoned without orphaning a goroutine. - g.setup.Do(func() { - if g.onError == nil { - g.onError = func(e error) error { return e } - } - g.err = nil - g.edone = make(chan struct{}) - g.reset = sync.Once{} - - errc := make(chan error) - g.errc = errc - go func() { - defer close(g.edone) - for err := range errc { - e := g.onError(err) - if e != nil && g.err == nil { - g.err = e // capture the first error always - } - } - }() - }) -} - -func (g *Group) cleanup() { - g.reset.Do(func() { - g.wg.Wait() - if g.errc == nil { - return - } - close(g.errc) - <-g.edone - g.errc = nil - g.setup = sync.Once{} - }) +func (g *Group) handleError(err error) { + g.μ.Lock() + defer g.μ.Unlock() + e := g.onError.filter(err) + if e != nil && g.err == nil { + g.err = e // capture the first unfiltered error always + } } // Wait blocks until all the goroutines currently active in the group have -// returned, and all reported errors have been delivered to the callback. -// It returns the first non-nil error returned by any of the goroutines in the +// returned, and all reported errors have been delivered to the callback. It +// returns the first non-nil error reported by any of the goroutines in the // group and not filtered by an ErrorFunc. // -// It is safe to call Wait concurrently from multiple goroutines, but as with -// sync.WaitGroup no tasks can be added to g while any call to Wait is in -// progress. Once all Wait calls have returned, the group is ready for reuse. -func (g *Group) Wait() error { g.cleanup(); return g.err } +// As with sync.WaitGroup, new tasks can be added to g during a call to Wait +// only if there was already at least one task active when Wait was called. +// After Wait has returned, the group is ready for reuse. +// +// Wait may be called from at most one goroutine at a time. +func (g *Group) Wait() error { + g.wg.Wait() + g.μ.Lock() + defer g.μ.Unlock() + defer g.active.Store(0) + return g.err +} // An ErrorFunc is called by a group each time a task reports an error. Its // return value replaces the reported error, so the ErrorFunc can filter or // suppress errors by modifying or discarding the input error. type ErrorFunc func(error) error +func (ef ErrorFunc) filter(err error) error { + if ef == nil { + return err + } + return ef(err) +} + // Trigger creates an ErrorFunc that calls f each time a task reports an error. // The resulting ErrorFunc returns task errors unmodified. func Trigger(f func()) ErrorFunc { return func(e error) error { f(); return e } }