/
workerpool.go
84 lines (73 loc) · 1.52 KB
/
workerpool.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
package utils
import (
"context"
"errors"
"sync"
"golang.org/x/sync/semaphore"
)
var (
	// ErrTaskAborted is delivered on a task's error channel instead of running
	// the task, once the worker pool has been marked as failed.
	ErrTaskAborted = errors.New("Aborted")
)
// Task is a unit of work handed from Exec to a WorkerPool worker goroutine.
type Task struct {
	// fn is the work to perform; its returned error is forwarded on errC.
	fn func() error
	// errC receives the task's outcome: fn's return value, or ErrTaskAborted
	// when the task is skipped because the pool has already failed.
	errC chan error
}
// WorkerPool is a pool of workers able to perform multiple tasks in parallel.
//
// As soon as any task returns a non-nil error the pool is marked as failed:
// every task a worker receives afterwards, and every Exec call made
// afterwards, reports ErrTaskAborted without running. Tasks already running
// on other workers finish normally. The failed state is never reset.
type WorkerPool struct {
	workers int
	// taskQ is unbuffered; workers range over it until Close closes it.
	taskQ chan Task
	// wg tracks the live worker goroutines so Close can wait for them.
	wg sync.WaitGroup
	// sem is acquired in Exec and released by a worker once the task is
	// finished (or skipped), bounding the number of in-flight tasks.
	sem *semaphore.Weighted

	// mu guards failed, which is written by worker goroutines and read
	// concurrently by other workers and by Exec callers.
	mu     sync.Mutex
	failed bool
}

// NewWorkerPool instantiates a new worker pool and starts its worker
// goroutines. A workers value below 1 is treated as 1: with zero workers no
// Exec call could ever obtain a semaphore slot (and nobody would read the
// task queue), and a negative count would panic in WaitGroup.Add.
func NewWorkerPool(workers int) *WorkerPool {
	if workers < 1 {
		workers = 1
	}
	wp := &WorkerPool{
		workers: workers,
		taskQ:   make(chan Task),
		sem:     semaphore.NewWeighted(int64(workers)),
	}
	wp.wg.Add(workers)
	for i := 0; i < workers; i++ {
		go wp.work()
	}
	return wp
}

// work is the body of one worker goroutine: it runs tasks from taskQ until
// the queue is closed, reporting each outcome on the task's error channel.
func (wp *WorkerPool) work() {
	defer wp.wg.Done()
	for task := range wp.taskQ {
		// If the worker pool has failed, skip the task instead of running it.
		if wp.isFailed() {
			task.errC <- ErrTaskAborted
			wp.sem.Release(1)
			continue
		}
		err := task.fn()
		wp.sem.Release(1)
		if err != nil {
			wp.markFailed()
		}
		task.errC <- err
	}
}

// isFailed reports whether any task has returned an error.
func (wp *WorkerPool) isFailed() bool {
	wp.mu.Lock()
	defer wp.mu.Unlock()
	return wp.failed
}

// markFailed records that a task has returned an error.
func (wp *WorkerPool) markFailed() {
	wp.mu.Lock()
	wp.failed = true
	wp.mu.Unlock()
}

// Exec enqueues one task into the worker pool and returns a channel (buffered,
// capacity 1) that receives exactly one value: f's returned error, or
// ErrTaskAborted if the pool has failed before f could run.
//
// Exec blocks until a worker accepts the task. It must not be called after,
// or concurrently with, Close: sending on the closed task queue panics.
func (wp *WorkerPool) Exec(f func() error) chan error {
	errC := make(chan error, 1)
	if wp.isFailed() {
		errC <- ErrTaskAborted
		return errC
	}
	// Acquire only returns an error when its context is done, and
	// context.Background() is never done, so the error is always nil here.
	_ = wp.sem.Acquire(context.Background(), 1)
	wp.taskQ <- Task{fn: f, errC: errC}
	return errC
}

// Close shuts down the worker pool: it closes the task queue and waits for
// every worker goroutine to exit. Close must be called at most once; closing
// an already-closed channel panics.
func (wp *WorkerPool) Close() {
	close(wp.taskQ)
	wp.wg.Wait()
}