diff --git a/do_all.go b/do_all.go index 31614de..687fb90 100644 --- a/do_all.go +++ b/do_all.go @@ -2,8 +2,6 @@ package workgroup import "errors" -type void = struct{} - // DoAll starts n concurrent workers (or GOMAXPROCS workers if n < 1) // and processes each initial input as a task. // Errors returned by a task do not halt execution, @@ -11,15 +9,28 @@ type void = struct{} // If a task panics during execution, // the panic will be caught and rethrown in the main Goroutine. func DoAll[Input any](n int, items []Input, task func(Input) error) error { + var recovered any errs := make([]error, 0, len(items)) - DoTasks(n, func(in Input) (void, error) { - return void{}, task(in) - }, func(_ Input, _ void, err error) ([]Input, bool) { + runner := func(in Input) (r any, err error) { + defer func() { + r = recover() + }() + err = task(in) + return + } + manager := func(_ Input, r any, err error) ([]Input, bool) { + if r != nil { + recovered = r + } if err != nil { errs = append(errs, err) } return nil, true - }, items...) + } + DoTasks(n, runner, manager, items...) + if recovered != nil { + panic(recovered) + } return errors.Join(errs...) } diff --git a/do_all_test.go b/do_all_test.go new file mode 100644 index 0000000..d02f72f --- /dev/null +++ b/do_all_test.go @@ -0,0 +1,7 @@ +package workgroup_test + +import "testing" + +func TestDoAll_err(t *testing.T) { + +} diff --git a/do_tasks_test.go b/do_tasks_test.go new file mode 100644 index 0000000..0482d07 --- /dev/null +++ b/do_tasks_test.go @@ -0,0 +1,48 @@ +package workgroup_test + +import ( + "errors" + "fmt" + "testing" + "time" + + "github.com/carlmjohnson/workgroup" +) + +func TestDoTasks_drainage(t *testing.T) { + const sleepTime = 10 * time.Millisecond + b := false + task := func(n int) (int, error) { + if n == 1 { + return 0, errors.New("text string") + } + time.Sleep(sleepTime) + b = true + return 0, nil + } + start := time.Now() + m := map[int]struct { + int + error + }{} + manager := func(in, out int, err error) ([]int, bool) { + m[in] = struct { + int + error + }{out, err} + if err != nil { + return nil, false + } + return nil, true + } + workgroup.DoTasks(5, task, manager, 0, 1) + if s := fmt.Sprint(m); s != "map[1:text string]" { + t.Fatal(s) + } + if time.Since(start) < sleepTime { + t.Fatal("didn't sleep enough") + } + if !b { + t.Fatal("didn't finish") + } +} diff --git a/do_test.go b/panic_test.go similarity index 59% rename from do_test.go rename to panic_test.go index 736c500..a8728d6 100644 --- a/do_test.go +++ b/panic_test.go @@ -1,15 +1,21 @@ package workgroup_test import ( - "errors" "fmt" "sync/atomic" "testing" - "time" "github.com/carlmjohnson/workgroup" ) +func try(f func()) (r any) { + defer func() { + r = recover() + }() + f() + return +} + func TestDoTasks_panic(t *testing.T) { task := func(n int) (int, error) { if n == 3 { @@ -22,13 +28,9 @@ func TestDoTasks_panic(t *testing.T) { triples = append(triples, triple) return nil, true } - var r any - func() { - defer func() { - r = recover() - }() + r := try(func() { workgroup.DoTasks(1, task, manager, 1, 2, 3, 4) - }() + }) if r == nil { t.Fatal("should have panicked") } @@ -44,12 +46,8 @@ func TestDoAll_panic(t *testing.T) { var ( n atomic.Int64 err error - r any ) - func() { - defer func() { - r = recover() - }() + r := try(func() { err = workgroup.DoAll(1, []int64{1, 2, 3}, func(delta int64) error { if delta == 2 { @@ -58,7 +56,7 @@ func TestDoAll_panic(t *testing.T) { n.Add(delta) return nil }) - }() + }) if err != nil { t.Fatal("should have panicked") } @@ -68,7 +66,7 @@ func TestDoAll_panic(t *testing.T) { if r != "boom" { t.Fatal(r) } - if n.Load() != 1 { + if n.Load() != 4 { t.Fatal(n.Load()) } } @@ -77,10 +75,8 @@ func TestDo_panic(t *testing.T) { var ( n atomic.Int64 err error - r any ) - func() { - defer func() { r = recover() }() + r := try(func() { err = workgroup.Do(1, func() error { n.Add(1) @@ -93,7 +89,7 @@ func TestDo_panic(t *testing.T) { n.Add(1) return nil }) - }() + }) if err != nil { t.Fatal("should have panicked") } @@ -103,45 +99,7 @@ func TestDo_panic(t *testing.T) { if r != "boom" { t.Fatal(r) } - if n.Load() != 1 { + if n.Load() != 2 { t.Fatal(n.Load()) } } - -func TestDoTasks_drainage(t *testing.T) { - const sleepTime = 10 * time.Millisecond - b := false - task := func(n int) (int, error) { - if n == 1 { - return 0, errors.New("text string") - } - time.Sleep(sleepTime) - b = true - return 0, nil - } - start := time.Now() - m := map[int]struct { - int - error - }{} - manager := func(in, out int, err error) ([]int, bool) { - m[in] = struct { - int - error - }{out, err} - if err != nil { - return nil, false - } - return nil, true - } - workgroup.DoTasks(5, task, manager, 0, 1) - if s := fmt.Sprint(m); s != "map[1:text string]" { - t.Fatal(s) - } - if time.Since(start) < sleepTime { - t.Fatal("didn't sleep enough") - } - if !b { - t.Fatal("didn't finish") - } -}