-
Notifications
You must be signed in to change notification settings - Fork 462
/
taskfn.go
182 lines (152 loc) · 4.02 KB
/
taskfn.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
// SPDX-FileCopyrightText: 2024 SAP SE or an SAP affiliate company and Gardener contributors
//
// SPDX-License-Identifier: Apache-2.0
package flow
import (
"context"
"sync"
"time"
"github.com/hashicorp/go-multierror"
"github.com/gardener/gardener/pkg/utils/retry"
)
var (
	// ContextWithTimeout is context.WithTimeout. Exposed for testing: every timeout-bound
	// context derived in this file is created through this variable so tests can replace it.
	ContextWithTimeout = context.WithTimeout
)

// TaskFn is a payload function of a task.
type TaskFn func(ctx context.Context) error

// RecoverFn is a function that can recover an error. It receives the error of a failed
// task and returns the error (if any) that should be reported in its place.
type RecoverFn func(ctx context.Context, err error) error

// Timeout returns a TaskFn that is bound to a context which times out.
//
// The derived context is obtained from ContextWithTimeout (rather than calling
// context.WithTimeout directly) so that the test hook is honoured, and it is cancelled
// as soon as t returns. The error returned by t is passed through unchanged.
func (t TaskFn) Timeout(timeout time.Duration) TaskFn {
	return func(ctx context.Context) error {
		ctx, cancel := ContextWithTimeout(ctx, timeout)
		defer cancel()

		return t(ctx)
	}
}
// RetryUntilTimeout returns a TaskFn that is retried until the timeout is reached.
//
// The retry loop (retry.Until, polling every interval) runs under a context derived via
// ContextWithTimeout, consistent with Timeout, and that context is cancelled on return.
// Each failure of t is wrapped with retry.MinorError and each success returns retry.Ok;
// presumably minor errors are retried by the retry package — see its docs to confirm.
func (t TaskFn) RetryUntilTimeout(interval, timeout time.Duration) TaskFn {
	return func(ctx context.Context) error {
		ctx, cancel := ContextWithTimeout(ctx, timeout)
		defer cancel()

		// Unnamed results: the previous named (done, err) were never used and err was
		// shadowed by the inner declaration below.
		return retry.Until(ctx, interval, func(ctx context.Context) (bool, error) {
			if err := t(ctx); err != nil {
				return retry.MinorError(err)
			}
			return retry.Ok()
		})
	}
}
// ToRecoverFn adapts t into a RecoverFn. The error handed to the resulting RecoverFn
// is discarded; it simply runs t with the supplied context and returns t's result.
func (t TaskFn) ToRecoverFn() RecoverFn {
	var adapted RecoverFn = func(ctx context.Context, _ error) error { return t(ctx) }
	return adapted
}
// Recover wraps t so that a failure can be handled by recoverFn.
//
// On success nil is returned. On failure, if ctx is already done (ctx.Err() != nil)
// the original error is returned and recoverFn is not invoked; otherwise the result
// of recoverFn(ctx, err) is returned.
func (t TaskFn) Recover(recoverFn RecoverFn) TaskFn {
	return func(ctx context.Context) error {
		taskErr := t(ctx)
		switch {
		case taskErr == nil:
			return nil
		case ctx.Err() != nil:
			// Don't attempt recovery on a cancelled or expired context.
			return taskErr
		default:
			return recoverFn(ctx, taskErr)
		}
	}
}
// Sequential returns a TaskFn that executes fns one after another in the given order.
//
// It stops at the first TaskFn that returns an error and returns that error. After each
// TaskFn (including the last) the context is checked too: once ctx is done, ctx.Err()
// is returned and no further TaskFn is started.
func Sequential(fns ...TaskFn) TaskFn {
	return func(ctx context.Context) error {
		for i := range fns {
			if err := fns[i](ctx); err != nil {
				return err
			}
			if ctxErr := ctx.Err(); ctxErr != nil {
				return ctxErr
			}
		}
		return nil
	}
}
// ParallelN returns a function that runs the given TaskFns in parallel by spawning N workers,
// collecting their errors in a multierror. If N <= 0, then N will be defaulted to len(fns).
// N is also capped at len(fns): more workers than tasks could never all be busy, so the
// surplus goroutines would only be started to exit immediately.
//
// The returned function blocks until every TaskFn has finished. It returns nil if all of
// them succeeded; otherwise the non-nil errors are combined with multierror.Append in the
// order they complete, which is not deterministic.
func ParallelN(n int, fns ...TaskFn) TaskFn {
	workers := n
	if workers <= 0 || workers > len(fns) {
		workers = len(fns)
	}

	return func(ctx context.Context) error {
		var (
			wg     sync.WaitGroup
			fnsCh  = make(chan TaskFn)
			errCh  = make(chan error)
			result error
		)

		// Start the worker pool; each worker pulls tasks until fnsCh is closed.
		wg.Add(workers)
		for i := 0; i < workers; i++ {
			go func() {
				defer wg.Done()
				for fn := range fnsCh {
					errCh <- fn(ctx)
				}
			}()
		}

		// Feed the tasks to the workers, then signal that no more will come.
		go func() {
			defer close(fnsCh)
			for _, fn := range fns {
				fnsCh <- fn
			}
		}()

		// Close errCh once every worker has exited so the collection loop terminates.
		go func() {
			wg.Wait()
			close(errCh)
		}()

		// Drain every result; this also guarantees no worker stays blocked sending on errCh.
		for err := range errCh {
			if err != nil {
				result = multierror.Append(result, err)
			}
		}
		return result
	}
}
// Parallel runs all given TaskFns concurrently and combines their errors in a multierror.
// It is ParallelN with the worker count left at its default: a non-positive n makes
// ParallelN use len(fns) workers, i.e. one worker per TaskFn.
func Parallel(fns ...TaskFn) TaskFn {
	const defaultWorkers = 0
	return ParallelN(defaultWorkers, fns...)
}
// ParallelExitOnError runs all given TaskFns concurrently and returns as soon as one of
// them fails, returning that error. Returning cancels the context handed to the tasks, so
// still-running tasks are asked to stop, but they are not waited for. If every TaskFn
// succeeds, nil is returned after all of them have finished.
func ParallelExitOnError(fns ...TaskFn) TaskFn {
	return func(ctx context.Context) error {
		subCtx, cancel := context.WithCancel(ctx)
		// Signal the remaining tasks on any return path; we don't wait for them to finish.
		defer cancel()

		// Capacity len(fns): every task can deposit its result even after we stopped
		// reading, so no goroutine is left blocked (leaked) once we return early.
		results := make(chan error, len(fns))

		var wg sync.WaitGroup
		wg.Add(len(fns))
		for _, task := range fns {
			go func(task TaskFn) {
				defer wg.Done()
				results <- task(subCtx)
			}(task)
		}

		// Close results after all tasks have reported so the range below terminates.
		go func() {
			wg.Wait()
			close(results)
		}()

		for err := range results {
			if err != nil {
				return err
			}
		}
		return nil
	}
}