Simplify the collection of task values and errors #5

Merged
merged 3 commits into from
Mar 19, 2024
31 changes: 18 additions & 13 deletions README.md
@@ -4,8 +4,8 @@

A `*taskgroup.Group` represents a group of goroutines working on related tasks.
New tasks can be added to the group at will, and the caller can wait until all
tasks are complete. Errors are automatically collected and delivered to a
user-provided callback in a single goroutine. This does not replace the full
tasks are complete. Errors are automatically collected and delivered
synchronously to a user-provided callback. This does not replace the full
generality of Go's built-in features, but it simplifies some of the plumbing
for common concurrent tasks.

@@ -229,7 +229,7 @@ start(task4) // blocks until one of the previous tasks is finished

## Solo Tasks

In some cases it is useful to start a solo background task to handle an
In some cases it is useful to start a single background task to handle an
isolated concern. For example, suppose we want to read a file into a buffer
while we take care of some other work. Rather than creating a whole group for
a single goroutine, we can create a solo task using the `Go` constructor.
@@ -270,13 +270,11 @@ var sum int
c := taskgroup.NewCollector(func(v int) { sum += v })
```

Internally, a `Collector` wraps a solo task and a channel to receive results.

The `Task` and `NoError` methods of the collector `c` can then be used to wrap
a function that reports a value. If the function reports an error, that error
is returned from the task as usual. Otherwise, its non-error value is given to
the callback. As in the above example, calls to the function are serialized so
that it is safe to access state without additional locking:
The `Task`, `NoError`, and `Report` methods of `c` wrap a function that yields
a value into a task. If the function reports an error, that error is returned
from the task as usual. Otherwise, its non-error value is given to the
accumulator callback. As in the above example, calls to the function are
serialized so that it is safe to access state without additional locking:

```go
// Report an error, no value for the collector.
@@ -291,14 +289,21 @@ g.Go(c.Task(func() (int, error) {

// Report a random integer to the collector.
g.Go(c.NoError(func() int { return rand.Intn(1000) }))

// Report multiple values to the collector.
g.Go(c.Report(func(report func(int)) error {
	report(10)
	report(20)
	report(30)
	return nil
}))
```

Once all the tasks are done, call `Wait` to stop the collector and wait for it
to finish:
Once all the tasks derived from the collector are done, it is safe to access
the values accumulated by the callback:

```go
g.Wait() // wait for tasks to finish
c.Wait() // wait for the collector to finish

// Now you can access the values accumulated by c.
fmt.Println(sum)
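The README text above says that values are settled as soon as the tasks derived from the collector have returned. The following is a minimal stand-alone sketch of that idea, not the library's code: `miniCollector`, `task`, and `sumOf` are hypothetical names, and a `sync.WaitGroup` stands in for the task group so the example needs only the standard library.

```go
package main

import (
	"errors"
	"fmt"
	"sync"
)

// miniCollector is a hypothetical stand-in for the collector described
// above: it serializes calls to handle with a mutex.
type miniCollector[T any] struct {
	mu     sync.Mutex
	handle func(T)
}

// report delivers v to the accumulator while holding the lock.
func (c *miniCollector[T]) report(v T) {
	c.mu.Lock()
	defer c.mu.Unlock()
	c.handle(v)
}

// task wraps f so that its value is reported only if f succeeds.
func (c *miniCollector[T]) task(f func() (T, error)) func() error {
	return func() error {
		v, err := f()
		if err != nil {
			return err
		}
		c.report(v)
		return nil
	}
}

// sumOf runs one task per input concurrently. It returns the total of the
// values that were reported and the number of tasks that failed.
func sumOf(inputs []int) (sum, failed int) {
	c := &miniCollector[int]{handle: func(v int) { sum += v }}

	var wg sync.WaitGroup
	var failMu sync.Mutex
	for _, n := range inputs {
		n := n // per-iteration copy for older Go versions
		t := c.task(func() (int, error) {
			if n < 0 {
				return 0, errors.New("negative input")
			}
			return n, nil
		})
		wg.Add(1)
		go func() {
			defer wg.Done()
			if err := t(); err != nil {
				failMu.Lock()
				failed++
				failMu.Unlock()
			}
		}()
	}
	wg.Wait() // once every task has returned, sum is settled
	return sum, failed
}

func main() {
	sum, failed := sumOf([]int{1, 2, 3, -4, 5})
	fmt.Println(sum, failed)
}
```

Because each task calls the accumulator before it returns, waiting for the tasks is enough; no separate wait on the collector is needed in this sketch.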
58 changes: 58 additions & 0 deletions bench_test.go
@@ -0,0 +1,58 @@
package taskgroup_test

import (
	"math/rand"
	"sync"
	"testing"
)

// A very rough benchmark comparing the performance of accumulating values with
// a separate goroutine via a channel, vs. accumulating them directly under a
// lock. The workload here is intentionally minimal, so the benchmark is
// measuring more or less just the overhead.

func BenchmarkChan(b *testing.B) {
	ch := make(chan int)
	done := make(chan struct{})
	var total int
	go func() {
		defer close(done)
		for v := range ch {
			total += v
		}
	}()
	b.ResetTimer() // discount the setup time.

	var wg sync.WaitGroup
	wg.Add(b.N)
	for i := 0; i < b.N; i++ {
		go func() {
			defer wg.Done()
			ch <- rand.Intn(1000)
		}()
	}
	wg.Wait()
	close(ch)
	<-done
}

func BenchmarkLock(b *testing.B) {
	var μ sync.Mutex
	var total int
	report := func(v int) {
		μ.Lock()
		defer μ.Unlock()
		total += v
	}
	b.ResetTimer() // discount the setup time.

	var wg sync.WaitGroup
	wg.Add(b.N)
	for i := 0; i < b.N; i++ {
		go func() {
			defer wg.Done()
			report(rand.Intn(1000))
		}()
	}
	wg.Wait()
}
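The two benchmarks above measure overhead only and discard the totals. As a rough companion sketch, the program below runs the same two accumulation strategies on fixed inputs so that the results can be compared. The names `totalViaChan` and `totalViaLock` are illustrative and do not appear in the pull request.

```go
package main

import (
	"fmt"
	"sync"
)

// totalViaChan sums 0..n-1 by sending each value to a single consumer
// goroutine over an unbuffered channel.
func totalViaChan(n int) int {
	ch := make(chan int)
	done := make(chan struct{})
	var total int
	go func() {
		defer close(done)
		for v := range ch {
			total += v
		}
	}()

	var wg sync.WaitGroup
	wg.Add(n)
	for i := 0; i < n; i++ {
		go func(v int) {
			defer wg.Done()
			ch <- v
		}(i)
	}
	wg.Wait()
	close(ch) // no more senders; let the consumer finish
	<-done    // wait until the consumer has drained the channel
	return total
}

// totalViaLock sums 0..n-1 by adding each value directly under a mutex.
func totalViaLock(n int) int {
	var mu sync.Mutex
	var total int

	var wg sync.WaitGroup
	wg.Add(n)
	for i := 0; i < n; i++ {
		go func(v int) {
			defer wg.Done()
			mu.Lock()
			defer mu.Unlock()
			total += v
		}(i)
	}
	wg.Wait()
	return total
}

func main() {
	fmt.Println(totalViaChan(100), totalViaLock(100))
}
```

Both functions should produce the same total. The channel version needs an extra goroutine and an explicit shutdown step (`close` followed by `<-done`), while the lock version is finished as soon as the producers are.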
75 changes: 46 additions & 29 deletions collector.go
@@ -1,37 +1,35 @@
package taskgroup

import "sync"

// A Collector collects values reported by task functions and delivers them to
// an accumulator function.
type Collector[T any] struct {
ch chan<- T
s *Single[error]
μ sync.Mutex
handle func(T)
}

// report delivers v to the callback under the lock.
func (c *Collector[T]) report(v T) {
c.μ.Lock()
defer c.μ.Unlock()
c.handle(v)
}

// NewCollector creates a new collector that delivers task values to the
// specified accumulator function. The collector serializes calls to value, so
// that it is safe for the function to access shared state without a lock. The
// caller must call Wait when the collector is no longer needed, even if it has
// not been used.
func NewCollector[T any](value func(T)) *Collector[T] {
ch := make(chan T)
s := Go(NoError(func() {
for v := range ch {
value(v)
}
}))
return &Collector[T]{ch: ch, s: s}
}
// that it is safe for the function to access shared state without a lock.
//
// The tasks created from a collector do not return until all the values
// reported by the underlying function have been processed by the accumulator.
func NewCollector[T any](value func(T)) *Collector[T] { return &Collector[T]{handle: value} }

// Wait stops the collector and blocks until it has finished processing.
// It is safe to call Wait multiple times from a single goroutine.
// Note that after Wait has been called, c is no longer valid.
func (c *Collector[T]) Wait() {
if c.ch != nil {
close(c.ch)
c.ch = nil
c.s.Wait()
}
}
// Wait waits until the collector has finished processing.
//
// Deprecated: This method is now a noop; it is safe but unnecessary to call
// it. Once all the tasks created from c have returned, any state accessed by
// the accumulator is settled. Wait may be removed in a future version.
func (c *Collector[T]) Wait() {}

// Task returns a Task wrapping a call to f. If f reports an error, that error
// is propagated as the return value of the task; otherwise, the non-error
@@ -42,21 +40,40 @@ func (c *Collector[T]) Task(f func() (T, error)) Task {
if err != nil {
return err
}
c.ch <- v
c.report(v)
return nil
}
}

// Report returns a task wrapping a call to f, which is passed a function that
// sends results to the accumulator. The report function does not return until
// the accumulator has finished processing the value.
func (c *Collector[T]) Report(f func(report func(T)) error) Task {
return func() error { return f(c.report) }
}

// Stream returns a task wrapping a call to f, which is passed a channel on
// which results can be sent to the accumulator.
// which results can be sent to the accumulator. Each call to Stream starts a
// goroutine to process the values on the channel.
//
// Note: f must not close its argument channel.
// Deprecated: Tasks that wish to deliver multiple values should use Report
// instead, which does not spawn a goroutine. This method may be removed in a
// future version.
func (c *Collector[T]) Stream(f func(chan<- T) error) Task {
return func() error { return f(c.ch) }
return func() error {
ch := make(chan T)
s := Go(NoError(func() {
for v := range ch {
c.report(v)
}
}))
defer func() { close(ch); s.Wait() }()
return f(ch)
}
}

// NoError returns a Task wrapping a call to f. The resulting task reports a
// nil error for all calls.
func (c *Collector[T]) NoError(f func() T) Task {
return NoError(func() { c.ch <- f() })
return NoError(func() { c.report(f()) })
}
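The reworked `Stream` above starts one goroutine per call to drain a private channel, then closes the channel and waits for that goroutine before the task returns. The sketch below isolates that pattern using only the standard library. `streamTo` and `countStreamed` are hypothetical helpers written for illustration, and a plain `done` channel takes the place of the solo task used in the diff.

```go
package main

import (
	"fmt"
	"sync"
)

// streamTo calls f with a fresh channel and forwards every value sent on
// that channel to report. It does not return until all values sent by f
// have been handed to report.
func streamTo[T any](report func(T), f func(chan<- T) error) error {
	ch := make(chan T)
	done := make(chan struct{})
	go func() {
		defer close(done)
		for v := range ch {
			report(v)
		}
	}()
	// Close the channel and wait for the drain goroutine before returning.
	defer func() { close(ch); <-done }()
	return f(ch)
}

// countStreamed runs two streaming producers concurrently and returns how
// many values reached the shared, lock-protected report function.
func countStreamed(perStream int) int {
	var mu sync.Mutex
	count := 0
	report := func(int) {
		mu.Lock()
		defer mu.Unlock()
		count++
	}

	var wg sync.WaitGroup
	for s := 0; s < 2; s++ {
		wg.Add(1)
		go func() {
			defer wg.Done()
			_ = streamTo(report, func(ch chan<- int) error {
				for i := 0; i < perStream; i++ {
					ch <- i
				}
				return nil
			})
		}()
	}
	wg.Wait()
	return count
}

func main() {
	fmt.Println(countStreamed(3))
}
```

Since each call owns its channel, the producer never shares a channel with other streams, and the deferred close and wait ensure the count is complete once the producers have returned.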
22 changes: 9 additions & 13 deletions example_test.go
@@ -189,45 +189,41 @@ func ExampleCollector() {

// Wait for the searchers to finish, then signal the collector to stop.
g.Wait()
c.Wait()

// Now get the final result.
fmt.Println(total)
// Output:
// 325
}

func ExampleCollector_Stream() {
func ExampleCollector_Report() {
type val struct {
who string
v int
}
c := taskgroup.NewCollector(func(z val) { fmt.Println(z.who, z.v) })

err := taskgroup.New(nil).
// The Stream method passes its argument a channel where it may report
// multiple values to the collector.
Go(c.Stream(func(zs chan<- val) error {
// The Report method passes its argument a function to report multiple
// values to the collector.
Go(c.Report(func(report func(v val)) error {
for i := 0; i < 3; i++ {
zs <- val{"even", 2 * i}
report(val{"even", 2 * i})
}
return nil
})).
// Multiple streams are fine.
Go(c.Stream(func(zs chan<- val) error {
// Multiple reporters are fine.
Go(c.Report(func(report func(v val)) error {
for i := 0; i < 3; i++ {
zs <- val{"odd", 2*i + 1}
report(val{"odd", 2*i + 1})
}
// An error reported by a stream is propagated just like any other
// task error.
// An error from a reporter is propagated like any other task error.
return errors.New("no bueno")
})).
Wait()
if err == nil || err.Error() != "no bueno" {
log.Fatalf("Unexpected error: %v", err)
}

c.Wait()
// Unordered output:
// even 0
// odd 1