/
task.go
179 lines (145 loc) · 3.93 KB
/
task.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
package utils
import (
"fmt"
"sync"
"sync/atomic"
"time"
"github.com/mevansam/goutils/logger"
)
// TaskDispatcher runs submitted tasks on worker goroutines fed from a
// queue, optionally re-dispatching tasks whose attempts fail.
type TaskDispatcher struct {
	// dispatchQueue carries tasks from dispatch to the workers started by
	// Start; Stop closes it once every dispatched task has finished.
	dispatchQueue chan *task
	// queued counts tasks that have been dispatched but not yet finished,
	// including re-dispatched retries; Stop waits on it.
	queued sync.WaitGroup
	// putTimeout bounds how long dispatch waits for room in the queue. It
	// is multiplied by time.Millisecond when used, so it is effectively a
	// number of milliseconds.
	putTimeout time.Duration
	// stop is set to 1 (atomically) by Stop to tell workers not to retry.
	stop int32
}
// task is a unit of work created by TaskDispatcher.RunTask, configured with
// WithData/OnSuccess/OnError and submitted with Once or WithRetries.
type task struct {
	// td is the dispatcher the task is dispatched (and re-dispatched) to.
	td *TaskDispatcher
	// name identifies the task in log messages and dispatch errors.
	name string

	// taskFn does the work. When an attempt fails, its outData becomes the
	// taskData of the next attempt; a nil outData ends retrying.
	taskFn func(inData interface{}) (outData interface{}, err error)
	// taskData is the input handed to taskFn on the next attempt.
	taskData interface{}
	// numTries is the number of retries left; a worker decrements it each
	// time it re-dispatches the task and gives up when it is zero.
	numTries int

	// successFn, if set, receives taskFn's output after a successful attempt.
	successFn func(outData interface{})
	// errorFn, if set, receives the last error and the unprocessed output
	// when the task gives up; without it the failure is logged.
	errorFn func(err error, outData interface{})
}
// NewTaskDispatcher returns a dispatcher whose queue buffers up to
// bufferSize tasks. No workers run until Start is called.
//
// putTimeout bounds how long submitting a task waits for room in a full
// queue. It is multiplied by time.Millisecond when used, so it is
// effectively a number of milliseconds (pass 500, not 500*time.Millisecond).
func NewTaskDispatcher(bufferSize int, putTimeout time.Duration) *TaskDispatcher {
	queue := make(chan *task, bufferSize)
	// the stop flag starts at its zero value: not stopped
	return &TaskDispatcher{
		dispatchQueue: queue,
		putTimeout:    putTimeout,
	}
}
// Start launches numWorkers goroutines (none when numWorkers <= 0) that
// consume tasks from the dispatch queue until Stop closes it.
//
// For each task a worker calls taskFn with the task's current taskData.
// On success the output goes to successFn, or is trace-logged when none is
// set. On failure the task is re-dispatched with the returned output as
// its new input, as long as retries remain, the output is non-nil and Stop
// has not been signalled. Otherwise the error and the unprocessed output
// are handed to errorFn, or logged when none is set.
func (td *TaskDispatcher) Start(numWorkers int) {
	for i := 0; i < numWorkers; i++ {
		go td.workerLoop(i)
	}
}

// workerLoop processes tasks from the dispatch queue until it is closed.
func (td *TaskDispatcher) workerLoop(workerIndex int) {
	logger.TraceMessage("TaskDispatcher.Worker(): Running task worker '%d'.", workerIndex)
	for t := range td.dispatchQueue {
		td.executeTask(t)
		// Done runs only after any retry has been re-dispatched (which
		// calls queued.Add), so Stop's Wait cannot observe a zero count
		// while a retry is still pending.
		td.queued.Done()
	}
	logger.TraceMessage("TaskDispatcher.Worker(): Task worker '%d' done.", workerIndex)
}

// executeTask runs one attempt of t and routes the result to the success
// handling or to retryOrFail.
func (td *TaskDispatcher) executeTask(t *task) {
	logger.TraceMessage(
		"TaskDispatcher.Worker(): Running task '%s' with data: %# v",
		t.name, t.taskData,
	)
	outData, err := t.taskFn(t.taskData)
	if err != nil {
		td.retryOrFail(t, outData, err)
		return
	}
	if t.successFn != nil {
		t.successFn(outData)
		return
	}
	logger.TraceMessage(
		"TaskDispatcher.Worker(): Task '%s' execution succeeded with output data %# v",
		t.name, outData,
	)
}

// retryOrFail re-dispatches a failed task for another attempt or, when it
// cannot be retried, reports the failure to the task owner.
func (td *TaskDispatcher) retryOrFail(t *task, outData interface{}, err error) {
	// give up when Stop has been signalled, no retries remain, or the
	// attempt left no unprocessed data to retry with
	if atomic.LoadInt32(&td.stop) == 1 || t.numTries == 0 || outData == nil {
		t.reportFailure(err, outData)
		return
	}
	logger.TraceMessage(
		"TaskDispatcher.Worker(): Task '%s' execution failed with output data: %# v",
		t.name, outData,
	)
	logger.ErrorMessage(
		"TaskDispatcher.Worker(): Task '%s' execution failed with error: %s; Will retry '%d' more times.",
		t.name, err.Error(), t.numTries,
	)
	// the next attempt consumes the unprocessed output of this one
	t.numTries--
	t.taskData = outData
	if dispatchErr := t.td.dispatch(t); dispatchErr != nil {
		logger.ErrorMessage("TaskDispatcher.Worker(): %s.", dispatchErr.Error())
		// The retry never got queued, so this failure is final: report the
		// task's own error too, not just the dispatch error.
		t.reportFailure(err, outData)
	}
}

// reportFailure hands a final failure to the task's errorFn, or logs it
// when no errorFn was registered.
func (t *task) reportFailure(err error, outData interface{}) {
	if t.errorFn != nil {
		t.errorFn(err, outData)
		return
	}
	logger.TraceMessage(
		"TaskDispatcher.Worker(): Task '%s' execution failed with output data: %# v",
		t.name, outData,
	)
	logger.ErrorMessage(
		"TaskDispatcher.Worker(): Task '%s' execution failed with error: %s",
		t.name, err.Error(),
	)
}
// Stop shuts the dispatcher down in three steps:
//   - It signals workers to stop retrying. A task that fails from then on
//     is handed to its errorFn (or logged) instead of being re-dispatched.
//   - It waits until every dispatched task has completed or errored out.
//   - It closes the dispatch queue so the workers started by Start exit.
//
// Only the first call does this work. Later calls return immediately,
// without waiting, instead of closing the queue a second time (which would
// panic).
//
// NOTE(review): dispatching new tasks concurrently with Stop is not safe.
// sync.WaitGroup forbids an Add racing with Wait when the count is zero,
// and a send racing with the close panics. Finish submitting work first.
func (td *TaskDispatcher) Stop() {
	// signal to stop retrying; the compare-and-swap lets exactly one
	// caller go on to close the queue
	if !atomic.CompareAndSwapInt32(&td.stop, 0, 1) {
		return
	}
	// wait for all tasks to complete or error out
	td.queued.Wait()
	close(td.dispatchQueue)
}
// RunTask starts building a task called name that will run taskFn on one of
// the dispatcher's workers. Nothing is queued yet: configure the task with
// WithData, OnSuccess and OnError, then submit it with Once or WithRetries.
func (td *TaskDispatcher) RunTask(
	name string,
	taskFn func(inData interface{}) (outData interface{}, err error),
) *task {
	t := new(task)
	t.td = td
	t.name = name
	t.taskFn = taskFn
	return t
}
// dispatch enqueues t for a worker and registers it with the dispatcher's
// wait group, so that Stop waits for it.
//
// putTimeout is multiplied by time.Millisecond here, so it is effectively a
// number of milliseconds: NewTaskDispatcher(n, 500) waits up to 500ms,
// while passing 5*time.Second would wait vastly longer than intended.
//
// It returns an error, and undoes the wait-group registration, when the
// dispatcher has been stopped or the queue stays full for the whole
// timeout. A worker whose retry is rejected this way reports the task's
// failure instead.
//
// NOTE(review): the stop check is not atomic with Stop's close, so
// dispatching concurrently with Stop remains unsafe.
func (td *TaskDispatcher) dispatch(t *task) error {
	// Once Stop has been signalled the queue is (or is about to be)
	// closed, and sending on a closed channel panics.
	if atomic.LoadInt32(&td.stop) == 1 {
		return fmt.Errorf("task dispatcher is stopped; cannot dispatch task %s", t.name)
	}
	td.queued.Add(1)
	// An explicit timer is released as soon as the send succeeds, instead
	// of lingering until the timeout elapses as time.After would.
	timer := time.NewTimer(td.putTimeout * time.Millisecond)
	defer timer.Stop()
	select {
	case td.dispatchQueue <- t:
		return nil
	case <-timer.C:
		td.queued.Done()
		return fmt.Errorf("timed out attempting to dispatch task %s", t.name)
	}
}
// WithData sets the value passed to the task function on its first attempt
// (a retry is fed the failed attempt's output instead). It returns the same
// task so configuration calls can be chained.
func (t *task) WithData(taskData interface{}) (configured *task) {
	configured = t
	configured.taskData = taskData
	return configured
}
// OnSuccess registers a callback that a worker calls with the task
// function's output after a successful attempt; without one, success is
// only trace-logged. It returns the same task for chaining.
func (t *task) OnSuccess(successFn func(outData interface{})) (configured *task) {
	configured = t
	configured.successFn = successFn
	return configured
}
// OnError registers a callback that a worker calls when the task gives up.
// The callback receives the last error and the output that attempt left
// unprocessed; without one, the failure is logged. It returns the same
// task for chaining.
func (t *task) OnError(errorFn func(err error, unprocessedData interface{})) (configured *task) {
	configured = t
	configured.errorFn = errorFn
	return configured
}
// Once submits the task to its dispatcher without setting a retry count.
// A task built with RunTask has none, so a failed attempt is reported to
// errorFn (or logged) rather than retried. The returned error is non-nil
// when the task could not be queued.
func (t *task) Once() error {
	dispatcher := t.td
	return dispatcher.dispatch(t)
}
// WithRetries submits the task, allowing up to numRetries further attempts
// after a failed one. Each retry is fed the failed attempt's output as its
// input. Retrying ends early when that output is nil, when Stop has been
// signalled, or when the re-dispatch itself fails. The returned error is
// non-nil when the task could not be queued.
//
// NOTE(review): numRetries is stored as-is. The worker stops retrying only
// when the count is exactly zero, so a negative value keeps retrying until
// Stop is called or the task returns nil output. Confirm that is intended.
func (t *task) WithRetries(numRetries int) error {
	t.numTries = numRetries
	dispatcher := t.td
	return dispatcher.dispatch(t)
}