/
pool.go
301 lines (261 loc) · 8.69 KB
/
pool.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
// Package pool implements simplified, single-stage flow.
//
// By default it runs with non-buffered channels and random distribution, i.e. each incoming record is sent to one of
// the workers randomly. User creates pool.Worker, activates it by calling Go, submits input data via Submit. Go method
// returns Cursor allowing retrieval of results one-by-one (with cursor.Next) or reading them all with cursor.All method.
// Both cursor operations can block, as they read from the internal channel.
//
// After all inputs submitted user should call Close to indicate the completion.
//
// User may define ChunkFn returning the key portion of the record, and in this case the record will be sent to workers based on this key.
// Identical keys are guaranteed to be sent to the same worker. Such a mode is needed for stateful flows where each set of input
// records has to be processed sequentially and some state should be kept. Each worker gets an independent WorkerStore to keep
// some worker-local data.
//
// The actual worker function WorkerFn provided by user and will be executed by pool's goroutines.
// The worker will get an input record dispatched by the pool and could publish the result via SenderFn.
//
// Batch option sets the size of the internal buffer to minimize channel sends. Batch collects incoming records per worker and sends
// them as a single slice.
//
// Error handling by default terminates the pool on the first error, unless ContinueOnError requested.
//
// Metrics can be retrieved and updated by user to keep some counters associated with any names.
//
// Workers pool should not be reused and can be activated only once. Thread safe, no additional locking needed.
package pool
import (
"context"
"errors"
"fmt"
"hash/crc32"
"math/rand"
"time"
"github.com/go-pkgz/flow"
"golang.org/x/sync/errgroup"
)
// Workers is a simple case of flow with a single stage only.
type Workers struct {
	poolSize        int                      // number of worker goroutines
	batchSize       int                      // records accumulated per worker before a send; 1 disables batching
	chunkFn         func(interface{}) string // optional key extractor for sticky (chunked) distribution
	resChanSize     int                      // buffer size of the responses channel read by the Cursor
	workerChanSize  int                      // buffer size of each worker's input channel
	workerFn        WorkerFn                 // user-provided worker function, runs on pool goroutines
	completeFn      CompleteFn               // optional per-worker completion callback, runs after input is closed
	continueOnError bool                     // don't terminate on the first error
	store           []WorkerStore            // worker-local storage, indexed by worker ID
	buf             [][]interface{}          // per-worker batch buffers, filled by Submit
	workersCh       []chan []interface{}     // per-worker input channels, carry batches of records
	ctx             context.Context          // set by Go; non-nil means the pool was activated
	eg              *errgroup.Group          // tracks worker goroutines and collects their errors
}
// response wraps a single result value with an optional error; it is the unit
// of data delivered through the responses channel to the Cursor.
type response struct {
	value interface{} // the actual data
	err   error       // optional error
}
// WorkerStore defines the interface for per-worker storage, allowing stateful
// workers to keep local data between records. Each worker gets its own
// independent instance (see New), so implementations are not required to be
// safe for concurrent use across workers.
type WorkerStore interface {
	Set(key string, val interface{})    // store val under key
	Get(key string) (interface{}, bool) // retrieve val; second result is false if key is absent
	GetInt(key string) int              // typed getter; NOTE(review): presumably zero value on missing/mismatched key — confirm in NewLocalStore
	GetFloat(key string) float64        // typed getter
	GetString(key string) string        // typed getter
	GetBool(key string) bool            // typed getter
	Keys() []string                     // all stored keys
	Delete(key string)                  // remove key
}
// contextKey is a private type for context keys, preventing collisions with
// keys defined by other packages.
type contextKey string

// widContextKey carries the numeric worker ID in the context passed to WorkerFn;
// retrieved via WorkerID.
const widContextKey contextKey = "worker-id"
// WorkerFn processes the input record inpRec and optionally sends results via the sender func.
// It is executed on one of the pool's goroutines; store is that worker's local storage.
type WorkerFn func(ctx context.Context, inpRec interface{}, sender SenderFn, store WorkerStore) error

// SenderFn is called by worker code to publish results. It may block, and
// returns the context error if the pool's context is canceled.
type SenderFn func(val interface{}) error

// CompleteFn is invoked once per worker after its input channel is closed and the
// remaining batch is flushed; store is that worker's local storage.
type CompleteFn func(ctx context.Context, store WorkerStore) error
// New creates a worker pool with poolSize goroutines; the pool can be activated only once.
// workerFn is executed for every input record; options customize batching, chunked
// distribution, channel buffer sizes, the completion callback and error handling.
func New(poolSize int, workerFn WorkerFn, options ...Option) *Workers {
	if poolSize < 1 {
		poolSize = 1 // guard against zero or negative pool sizes
	}

	res := Workers{
		poolSize:       poolSize,
		workersCh:      make([]chan []interface{}, poolSize),
		buf:            make([][]interface{}, poolSize),
		store:          make([]WorkerStore, poolSize),
		workerFn:       workerFn,
		completeFn:     nil,
		chunkFn:        nil,
		batchSize:      1, // no batching by default
		resChanSize:    0, // unbuffered responses channel by default
		workerChanSize: 0, // unbuffered worker channels by default
	}

	// apply options
	for _, opt := range options {
		opt(&res)
	}

	// initialize workers channels and batch buffers
	for id := 0; id < poolSize; id++ {
		res.workersCh[id] = make(chan []interface{}, res.workerChanSize)
		if res.batchSize > 1 {
			// pre-size each buffer for one full batch. the original used poolSize
			// as the capacity, which is unrelated to how many records a batch holds
			res.buf[id] = make([]interface{}, 0, res.batchSize)
		}
		res.store[id] = NewLocalStore()
	}

	rand.Seed(time.Now().UnixNano()) // seed random distribution; deprecated no-op since go1.20 but harmless
	return &res
}
// Submit sends a record to the pool and may block. With the default random
// distribution the record goes to an arbitrary worker; with chunkFn set it is
// routed by the hash of the chunk key, so identical keys always reach the same
// worker. With batching enabled, records accumulate in a per-worker buffer and
// are sent once a full batch is collected.
func (p *Workers) Submit(v interface{}) {

	// randomize distribution by default
	id := rand.Intn(p.poolSize) //nolint gosec

	if p.chunkFn != nil {
		// chunked distribution: hash the key and map it onto workers.
		// do the modulo in uint32 space to avoid a negative index on 32-bit
		// platforms (int(uint32) can be negative there), and use the package's
		// precomputed IEEE table instead of rebuilding it on every call.
		id = int(crc32.ChecksumIEEE([]byte(p.chunkFn(v))) % uint32(p.poolSize))
	}

	if p.batchSize <= 1 {
		// skip all buffering if batch size is 1 or less
		p.workersCh[id] <- append([]interface{}{}, v)
		return
	}

	p.buf[id] = append(p.buf[id], v) // add to batch buffer
	if len(p.buf[id]) >= p.batchSize {
		// commit a copy of the full batch to the worker; the buffer is reused,
		// so the worker must not share backing storage with it
		cp := make([]interface{}, len(p.buf[id]))
		copy(cp, p.buf[id])
		p.workersCh[id] <- cp
		p.buf[id] = p.buf[id][:0] // reset size, keep capacity
	}
}
// Go activates the worker pool and returns a Cursor for reading results.
// It can be called only once; repeated activation returns an error. The result
// channel is closed after all workers complete, which the Cursor treats as the
// end of data.
func (p *Workers) Go(ctx context.Context) (Cursor, error) {
	if p.ctx != nil {
		return Cursor{}, errors.New("workers pool already activated")
	}

	respCh := make(chan response, p.resChanSize)
	p.ctx = context.WithValue(ctx, flow.MetricsContextKey, flow.NewMetrics()) // make metrics available to workers

	var egCtx context.Context
	p.eg, egCtx = errgroup.WithContext(ctx)

	// worker builds the goroutine body for worker id, consuming batches from inCh
	worker := func(id int, inCh chan []interface{}) func() error {
		return func() error {
			wCtx := context.WithValue(p.ctx, widContextKey, id) // expose worker ID via WorkerID
			for {
				select {
				case vv, ok := <-inCh:
					if !ok { // input channel closed, i.e. Close was called
						e := p.flush(wCtx, id, respCh) // process leftover batch and run completeFn
						if !p.continueOnError {
							return e
						}
						return nil
					}
					// process every record of the received batch
					for _, v := range vv {
						if err := p.workerFn(wCtx, v, p.sendResponseFn(wCtx, respCh), p.store[id]); err != nil {
							e := fmt.Errorf("worker %d failed: %w", id, err)
							if !p.continueOnError {
								return e
							}
							respCh <- response{err: e} // report the error but keep processing
						}
					}
				case <-ctx.Done(): // parent context, passed by caller
					respCh <- response{err: ctx.Err()}
					return ctx.Err()
				case <-egCtx.Done(): // worker context, canceled by errgroup on the first failure
					if !p.continueOnError {
						return egCtx.Err()
					}
					return nil
				}
			}
		}
	}

	// start all goroutines
	for i := 0; i < p.poolSize; i++ {
		p.eg.Go(worker(i, p.workersCh[i]))
	}

	go func() {
		// wait for completion, propagate the final error, and close the response
		// channel to signal the Cursor that no more results will arrive
		if err := p.eg.Wait(); err != nil {
			respCh <- response{err: err}
		}
		close(respCh)
	}()

	return Cursor{ch: respCh}, nil
}
// Metrics returns all user-defined counters from the pool's context.
// Safe to call before Go is invoked: in that case a fresh, empty metrics set
// is returned instead of panicking on a nil context.
func (p *Workers) Metrics() *flow.Metrics {
	if p.ctx == nil {
		// pool not activated yet; Metrics(nil) would panic on ctx.Value
		return flow.NewMetrics()
	}
	return Metrics(p.ctx)
}
// flush processes all records left in the batch buffer of worker id, then runs
// the optional completion callback. Called exactly once per worker, when its
// input channel is closed. Errors are either returned immediately, or — with
// ContinueOnError — reported to ch while flushing continues.
func (p *Workers) flush(ctx context.Context, id int, ch chan response) (err error) {
	for _, v := range p.buf[id] {
		if e := p.workerFn(ctx, v, p.sendResponseFn(ctx, ch), p.store[id]); e != nil {
			err = fmt.Errorf("worker %d failed in flush: %w", id, e)
			if !p.continueOnError {
				return err // NOTE(review): completeFn is skipped on this path — confirm intended
			}
			ch <- response{err: err} // report and keep flushing the remaining records
		}
	}
	p.buf[id] = p.buf[id][:0] // reset size to 0, keep capacity

	// call completeFn for given worker id; its error replaces any earlier
	// (already reported) flush error as the returned value
	if p.completeFn != nil {
		if e := p.completeFn(ctx, p.store[id]); e != nil {
			err = fmt.Errorf("complete func for %d failed: %w", id, e)
		}
	}
	return err
}
// Close pool. Has to be called by the consumer as the indication of "all records
// submitted". After this call the pool can't be reused.
func (p *Workers) Close() {
	// closing the input channels makes every worker flush its batch and terminate
	for i := range p.workersCh {
		close(p.workersCh[i])
	}
}
// Wait blocks until all workers complete and the result channel is closed. It can be
// used instead of the cursor when results can be ignored and only completion matters.
// Returns the first worker error, or the context error if ctx is done first.
func (p *Workers) Wait(ctx context.Context) (err error) {
	// buffered so the goroutine below can always deliver and never leaks when
	// ctx wins the race (the original unbuffered channel leaked it forever)
	doneCh := make(chan error, 1)
	go func() {
		doneCh <- p.eg.Wait()
	}()

	// each case returns, so no loop is needed around the select
	select {
	case err := <-doneCh:
		return err
	case <-ctx.Done():
		return ctx.Err()
	}
}
// sendResponseFn builds the SenderFn handed to worker code: it publishes a value
// to the response channel, or gives up with the context error on cancellation.
func (p *Workers) sendResponseFn(ctx context.Context, respCh chan response) func(val interface{}) error {
	return func(val interface{}) error {
		resp := response{value: val}
		select {
		case <-ctx.Done():
			return ctx.Err()
		case respCh <- resp:
		}
		return nil
	}
}
// Metrics returns the set of metrics stored in the context, or a fresh empty
// set when the context carries none.
func Metrics(ctx context.Context) *flow.Metrics {
	if m, ok := ctx.Value(flow.MetricsContextKey).(*flow.Metrics); ok {
		return m
	}
	return flow.NewMetrics()
}
// WorkerID returns the worker ID from the context, or 0 when the context
// carries none (i.e. outside a pool worker).
func WorkerID(ctx context.Context) int {
	if id, ok := ctx.Value(widContextKey).(int); ok {
		return id
	}
	return 0 // non-parallel callers won't have a worker ID set
}