-
Notifications
You must be signed in to change notification settings - Fork 0
/
store.go
480 lines (422 loc) · 12.3 KB
/
store.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
package taskqueue
import (
"bytes"
"context"
"crypto/rand"
"database/sql"
_ "embed"
"encoding/hex"
"errors"
"fmt"
"html/template"
"net/http"
"strings"
"time"
// SQLite driver.
_ "github.com/mattn/go-sqlite3"
)
// OpenTaskQueue opens (or creates) the SQLite database at dbpath, applies the
// embedded schema migrations and returns a Store backed by it.
//
// If the migrations fail, the database handle is closed before returning so
// that a failed call does not leak the connection pool.
func OpenTaskQueue(dbpath string) (*Store, error) {
	db, err := sql.Open("sqlite3", dbpath)
	if err != nil {
		return nil, fmt.Errorf("open db: %w", err)
	}
	// Because SQLite: funnel all access through a single connection.
	db.SetMaxOpenConns(1)
	db.SetMaxIdleConns(1)
	// NOTE(review): connections are recycled after 1s idle / 3s lifetime.
	// Each new SQLite connection to a ":memory:" DSN opens a fresh, empty
	// database, so these settings are only safe for file-backed databases —
	// confirm no caller passes ":memory:".
	db.SetConnMaxIdleTime(time.Second)
	db.SetConnMaxLifetime(time.Second * 3)
	if err := migrate(db); err != nil {
		// Release the pool; the migration error is the one worth reporting.
		_ = db.Close()
		return nil, fmt.Errorf("migration: %w", err)
	}
	return &Store{db: db}, nil
}
// migrations holds the schema migration script embedded from migrations.sql.
// Statement groups are separated by a line containing only "---"; migrate
// executes them in file order.
//go:embed migrations.sql
var migrations string
// migrate runs each "\n---\n"-separated group of the embedded migrations
// script against db, in file order. It stops at the first group that fails
// and returns that error with the offending SQL text appended.
func migrate(db *sql.DB) error {
	groups := strings.Split(migrations, "\n---\n")
	for i := 0; i < len(groups); i++ {
		stmt := groups[i]
		if _, execErr := db.Exec(stmt); execErr != nil {
			return fmt.Errorf("%w: %s", execErr, stmt)
		}
	}
	return nil
}
// Store is a persistent task queue kept in a SQLite database. Obtain one
// with OpenTaskQueue.
type Store struct {
	db *sql.DB
}

// Close releases the underlying database handle and all of its resources.
// Operations attempted on the Store afterwards fail with a database/sql error.
func (s *Store) Close() error {
	closeErr := s.db.Close()
	return closeErr
}
// Push enqueues tasks atomically: either every task is inserted or none is.
// It returns the generated task IDs in the same order as tasks. An empty
// tasks slice is a no-op that returns (nil, nil).
//
// Each task becomes due ExecuteIn after the current time. Timeout is stored
// in whole seconds, so any sub-second remainder is truncated. A nil Payload
// is stored as an empty payload.
func (s *Store) Push(ctx context.Context, tasks []TaskReq) ([]string, error) {
	if len(tasks) == 0 {
		return nil, nil
	}

	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("start transaction: %w", err)
	}
	// Rolling back after a successful Commit is a harmless no-op.
	defer tx.Rollback()

	const insertTask = `
INSERT INTO tasks (task_id, name, payload, retry, timeout, execute_at, created_at)
VALUES (@task_id, @name, @payload, @retry, @timeout, @execute_at, @created_at)
`
	now := currentTime()
	ids := make([]string, len(tasks))
	for i := range tasks {
		req := &tasks[i]
		id := generateID()
		ids[i] = id

		body := emptyPayload
		if req.Payload != nil {
			body = req.Payload
		}

		if _, execErr := tx.ExecContext(ctx, insertTask,
			sql.Named("task_id", id),
			sql.Named("name", req.Name),
			sql.Named("payload", body),
			sql.Named("retry", req.Retry),
			sql.Named("timeout", req.Timeout/time.Second),
			sql.Named("execute_at", now.Add(req.ExecuteIn).Unix()),
			sql.Named("created_at", now.Unix()),
		); execErr != nil {
			return nil, fmt.Errorf("insert: %w", execErr)
		}
	}

	if commitErr := tx.Commit(); commitErr != nil {
		return nil, fmt.Errorf("commit transaction: %w", commitErr)
	}
	return ids, nil
}
// TaskReq represents a task creation request. If successfully processed by
// Push, it results in a task being queued.
type TaskReq struct {
	// Name is stored with the task and returned as Task.Name by Pull.
	Name string
	// Payload is stored with the task; Push substitutes an empty slice for a
	// nil Payload.
	Payload []byte
	// Retry is compared with the task's failure count by Nack: once
	// failures >= Retry the task moves to the dead queue. A value of 0 or 1
	// therefore allows a single attempt, and N allows N attempts in total.
	Retry uint
	// ExecuteIn delays the task: Pull will not return it until ExecuteIn
	// has elapsed since Push.
	ExecuteIn time.Duration
	// Timeout is persisted in whole seconds (sub-second parts are truncated)
	// and returned as Task.Timeout by Pull.
	Timeout time.Duration
}

// emptyPayload is used by Push in place of a nil TaskReq.Payload.
// NOTE(review): presumably this avoids binding NULL to the payload column —
// confirm against migrations.sql.
var emptyPayload = make([]byte, 0)
// Delete removes the task with the given ID from the queue.
//
// A task that is currently acquired (returned by Pull and not yet passed to
// Ack or Nack) cannot be deleted: the returned error wraps ErrLocked. If no
// task has the given ID, ErrNotFound is returned. The lock check and the
// deletion run in the same transaction.
func (s *Store) Delete(ctx context.Context, taskID string) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("start transaction: %w", err)
	}
	// No-op once Commit has succeeded.
	defer tx.Rollback()

	// Only the presence of a row matters, so any successful scan means the
	// task is locked. Not gating on the scanned value guarantees the default
	// branch only ever wraps a real, non-nil error.
	var found int
	switch err := tx.QueryRowContext(ctx, `SELECT 1 FROM acquired WHERE task_id = ? LIMIT 1`, taskID).Scan(&found); {
	case err == nil:
		return fmt.Errorf("task is being processed: %w", ErrLocked)
	case errors.Is(err, sql.ErrNoRows):
		// Not acquired; safe to delete.
	default:
		return fmt.Errorf("check if task is acquired: %w", err)
	}

	res, err := tx.ExecContext(ctx, `DELETE FROM tasks WHERE task_id = ?`, taskID)
	if err != nil {
		return fmt.Errorf("delete task: %w", err)
	}
	n, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("rows affected: %w", err)
	}
	if n != 1 {
		return ErrNotFound
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}
	return nil
}
// Pull acquires the next due task. Among tasks whose execute_at is not later
// than the current time and that are not already acquired, it picks the one
// with the earliest execute_at (ties are resolved in unspecified order). In
// the same transaction it records the task in acquired, so Pull will not
// return it again until Ack or Nack removes that record.
//
// Pull returns ErrEmpty when no task is currently due.
func (s *Store) Pull(ctx context.Context) (*Task, error) {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return nil, fmt.Errorf("start transaction: %w", err)
	}
	defer tx.Rollback()

	nowUnix := currentTime().Unix()

	var (
		picked     Task
		timeoutSec int64
	)
	err = tx.QueryRowContext(ctx, `
SELECT task_id, name, payload, timeout
FROM tasks
WHERE execute_at <= ?
AND task_id NOT IN (SELECT task_id FROM acquired)
ORDER BY execute_at ASC
LIMIT 1
`, nowUnix).Scan(&picked.TaskID, &picked.Name, &picked.Payload, &timeoutSec)
	switch {
	case errors.Is(err, sql.ErrNoRows):
		return nil, ErrEmpty
	case err != nil:
		return nil, fmt.Errorf("scan task: %w", err)
	}
	// The timeout column holds whole seconds.
	picked.Timeout = time.Second * time.Duration(timeoutSec)

	if _, execErr := tx.ExecContext(ctx, `
INSERT INTO acquired (task_id, created_at)
VALUES (?, ?)
`, picked.TaskID, nowUnix); execErr != nil {
		return nil, fmt.Errorf("insert acquire task: %w", execErr)
	}

	if commitErr := tx.Commit(); commitErr != nil {
		return nil, fmt.Errorf("commit transaction: %w", commitErr)
	}
	return &picked, nil
}
// Task represents a single task (job) acquired from the queue by Pull.
type Task struct {
	// TaskID identifies the task; pass it to Ack or Nack once processing ends.
	TaskID string
	// Name is the value given as TaskReq.Name.
	Name string
	// Payload is the payload stored by Push.
	Payload []byte
	// Timeout is TaskReq.Timeout truncated to whole seconds.
	Timeout time.Duration
}
// Ack marks an acquired task as successfully processed. In a single
// transaction it removes the task's acquisition record and then the task
// itself. It returns an error if the task is not currently acquired.
func (s *Store) Ack(ctx context.Context, taskID string) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("start transaction: %w", err)
	}
	// Harmless after a successful Commit.
	defer tx.Rollback()

	res, err := tx.ExecContext(ctx, `
DELETE FROM acquired WHERE task_id = ?
`, taskID)
	if err != nil {
		return fmt.Errorf("delete acquired lock: %w", err)
	}
	released, err := res.RowsAffected()
	if err != nil {
		return fmt.Errorf("rows affected: %w", err)
	}
	if released != 1 {
		return fmt.Errorf("task %q is not acquired", taskID)
	}

	_, err = tx.ExecContext(ctx, `DELETE FROM tasks WHERE task_id = ?`, taskID)
	if err != nil {
		return fmt.Errorf("delete from tasks list: %w", err)
	}
	if err = tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}
	return nil
}
// Nack reports that processing of an acquired task failed.
//
// Within one transaction it:
//  1. records reason in the failures table,
//  2. releases the acquisition lock, returning an error if the task is not
//     currently acquired,
//  3. compares the task's retry value with its total number of recorded
//     failures (including the one just inserted). It either moves the task to
//     the dead queue (retry <= failures) or postpones it by failures² minutes.
//
// Under this comparison a retry value of 0 or 1 both allow a single attempt,
// and retry N allows N attempts in total.
// NOTE(review): confirm this is the intended meaning of TaskReq.Retry.
// Failure rows are not deleted when a task is moved to the dead queue.
func (s *Store) Nack(ctx context.Context, taskID string, reason string) error {
	tx, err := s.db.BeginTx(ctx, nil)
	if err != nil {
		return fmt.Errorf("start transaction: %w", err)
	}
	// No-op after a successful Commit; otherwise discards every change below.
	defer tx.Rollback()
	now := currentTime()
	// Record the failure first so that the count further down includes it.
	_, err = tx.ExecContext(ctx, `
INSERT INTO failures (task_id, created_at, description)
VALUES (?, ?, ?)
`, taskID, now.Unix(), reason)
	if err != nil {
		return fmt.Errorf("insert failure reason: %w", err)
	}
	// Release the lock; exactly one row must be removed, otherwise the task
	// was not acquired and the whole transaction is rolled back.
	res, err := tx.ExecContext(ctx, `
DELETE FROM acquired WHERE task_id = ?
`, taskID)
	if err != nil {
		return fmt.Errorf("delete acquired lock: %w", err)
	}
	if n, err := res.RowsAffected(); err != nil {
		return fmt.Errorf("rows affected: %w", err)
	} else if n != 1 {
		return fmt.Errorf("task %q is not acquired", taskID)
	}
	var retry uint
	if err := tx.QueryRowContext(ctx, `SELECT retry FROM tasks WHERE task_id = ? LIMIT 1`, taskID).Scan(&retry); err != nil {
		return fmt.Errorf("scan task retry: %w", err)
	}
	var failures uint
	if err := tx.QueryRowContext(ctx, `SELECT count(*) FROM failures WHERE task_id = ?`, taskID).Scan(&failures); err != nil {
		return fmt.Errorf("scan task failures count: %w", err)
	}
	if retry <= failures {
		// Out of attempts: copy name and payload into the dead queue, then
		// drop the task from the live queue.
		_, err := tx.ExecContext(ctx, `
INSERT INTO deadqueue (task_id, name, payload, created_at)
SELECT @task_id, name, payload, @created_at
FROM tasks WHERE task_id = @task_id
`,
			sql.Named("task_id", taskID),
			sql.Named("created_at", now.Unix()),
		)
		if err != nil {
			return fmt.Errorf("move task to deadqueue: %w", err)
		}
		if _, err := tx.ExecContext(ctx, `DELETE FROM tasks WHERE task_id = ?`, taskID); err != nil {
			return fmt.Errorf("delete from tasks list: %w", err)
		}
	} else {
		// Delay execution of this task, so that it is not picked up
		// again instantly. Quadratic backoff: 1m, 4m, 9m, ... after the
		// 1st, 2nd, 3rd, ... failure.
		backoff := time.Duration(failures*failures) * time.Minute
		_, err := tx.ExecContext(ctx, `
UPDATE tasks SET execute_at = ? WHERE task_id = ?
`, now.Add(backoff).Unix(), taskID)
		if err != nil {
			return fmt.Errorf("update task execution time: %w", err)
		}
	}
	if err := tx.Commit(); err != nil {
		return fmt.Errorf("commit transaction: %w", err)
	}
	return nil
}
// stats returns the row counts of the tasks, acquired, failures and deadqueue
// tables, read with a single query. It panics if the query or scan fails.
func (s *Store) stats() (tasks, acquired, failures, deadqueue uint) {
	const countQuery = `
SELECT
(SELECT COUNT(*) FROM tasks) AS tasks,
(SELECT COUNT(*) FROM acquired) AS acquired,
(SELECT COUNT(*) FROM deadqueue) AS deadqueue,
(SELECT COUNT(*) FROM failures) AS failures
`
	// Note the column order: deadqueue is selected before failures.
	row := s.db.QueryRow(countQuery)
	if scanErr := row.Scan(&tasks, &acquired, &deadqueue, &failures); scanErr != nil {
		panic(scanErr)
	}
	return tasks, acquired, failures, deadqueue
}
// ServeHTTP renders a single-page HTML overview of the queue state using the
// embedded store_info.html template. The page shows:
//   - row counts of the tasks, acquired, deadqueue and failures tables,
//   - the 20 most recent failures,
//   - the 10 waiting tasks with the latest execute_at,
//   - up to 50 acquired tasks (most recently acquired first) and how long
//     each has been held.
//
// All reads run inside one read-only transaction bound to the request
// context. Any database or template error yields a 500 response. The page is
// rendered into a buffer first, so a template failure never sends a partial
// page.
func (s *Store) ServeHTTP(w http.ResponseWriter, r *http.Request) {
	ctx := r.Context()
	tx, err := s.db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
	if err != nil {
		http.Error(w, "Cannot start transaction.", http.StatusInternalServerError)
		return
	}
	defer tx.Rollback()

	type failure struct {
		TaskID      string
		CreatedAt   time.Time
		Description string
	}
	type waitingtask struct {
		TaskID    string
		Name      string
		Payload   string
		Retry     int
		Timeout   time.Duration
		ExecuteAt time.Time
		CreatedAt time.Time
	}
	type processing struct {
		TaskID string
		Since  time.Duration
	}
	var info struct {
		WaitingCount   uint
		AcquiredCount  uint
		DeadqueueCount uint
		FailuresCount  uint
		Waiting        []waitingtask
		Acquired       []processing
		Failures       []failure
	}

	// Use the request context like every other query here, so a cancelled
	// request also aborts the counter query.
	if err := tx.QueryRowContext(ctx, `
SELECT
(SELECT COUNT(*) FROM tasks) AS tasks,
(SELECT COUNT(*) FROM acquired) AS acquired,
(SELECT COUNT(*) FROM deadqueue) AS deadqueue,
(SELECT COUNT(*) FROM failures) AS failures
`).Scan(&info.WaitingCount, &info.AcquiredCount, &info.DeadqueueCount, &info.FailuresCount); err != nil {
		http.Error(w, "Cannot select counters.", http.StatusInternalServerError)
		return
	}

	failures, err := tx.QueryContext(ctx, `
SELECT task_id, created_at, description FROM failures ORDER BY created_at DESC LIMIT 20
`)
	if err != nil {
		http.Error(w, "Cannot query failed tasks.", http.StatusInternalServerError)
		return
	}
	defer failures.Close()
	for failures.Next() {
		var f failure
		var createdAt int64
		if err := failures.Scan(&f.TaskID, &createdAt, &f.Description); err != nil {
			http.Error(w, "Cannot scan failed task.", http.StatusInternalServerError)
			return
		}
		f.CreatedAt = time.Unix(createdAt, 0)
		info.Failures = append(info.Failures, f)
	}
	if err := failures.Err(); err != nil {
		http.Error(w, "Cannot finish failure scanning.", http.StatusInternalServerError)
		return
	}

	tasks, err := tx.QueryContext(ctx, `
SELECT task_id, name, payload, retry, timeout, execute_at, created_at
FROM tasks
ORDER BY execute_at DESC
LIMIT 10
`)
	if err != nil {
		http.Error(w, "Query waiting tasks.", http.StatusInternalServerError)
		return
	}
	defer tasks.Close()
	for tasks.Next() {
		var t waitingtask
		var timeout, executeAt, createdAt int64
		if err := tasks.Scan(&t.TaskID, &t.Name, &t.Payload, &t.Retry, &timeout, &executeAt, &createdAt); err != nil {
			http.Error(w, "Scan waiting task.", http.StatusInternalServerError)
			return
		}
		// Timeouts are stored in whole seconds.
		t.Timeout = time.Duration(timeout) * time.Second
		t.ExecuteAt = time.Unix(executeAt, 0)
		t.CreatedAt = time.Unix(createdAt, 0)
		info.Waiting = append(info.Waiting, t)
	}
	if err := tasks.Err(); err != nil {
		http.Error(w, "Waiting tasks rows.", http.StatusInternalServerError)
		return
	}

	acquired, err := tx.QueryContext(ctx, `SELECT task_id, created_at FROM acquired ORDER BY created_at DESC LIMIT 50`)
	if err != nil {
		http.Error(w, "Query acquired tasks.", http.StatusInternalServerError)
		return
	}
	defer acquired.Close()
	now := time.Now()
	for acquired.Next() {
		var p processing
		var createdAt int64
		if err := acquired.Scan(&p.TaskID, &createdAt); err != nil {
			http.Error(w, "Scan acquired task ID.", http.StatusInternalServerError)
			return
		}
		p.Since = now.Sub(time.Unix(createdAt, 0))
		info.Acquired = append(info.Acquired, p)
	}
	if err := acquired.Err(); err != nil {
		http.Error(w, "Acquired tasks rows.", http.StatusInternalServerError)
		return
	}

	// Provide a nice, single page view on the queue state. Render into a
	// buffer first so a template error can still produce a clean 500.
	var b bytes.Buffer
	if err := tmpl.Execute(&b, info); err != nil {
		http.Error(w, "Cannot render response.", http.StatusInternalServerError)
		return
	}
	// Set (not Add) so an earlier Content-Type is replaced rather than
	// duplicated; html/template output is UTF-8.
	w.Header().Set("Content-Type", "text/html; charset=utf-8")
	w.WriteHeader(http.StatusOK)
	// The status is already sent; a write error (e.g. client gone) cannot be
	// reported to the client anymore.
	_, _ = b.WriteTo(w)
}
var (
	// tmplString is the HTML template used by Store.ServeHTTP, embedded
	// from store_info.html.
	//go:embed store_info.html
	tmplString string
	// tmpl is parsed once at package initialization; template.Must panics
	// during init if store_info.html does not parse.
	tmpl = template.Must(template.New("").Parse(tmplString))
)
// currentTime returns the current time in UTC, truncated to whole seconds,
// matching the Unix-second timestamps the store writes. It is a variable so
// that tests can replace it with a fixed clock.
var currentTime = func() time.Time {
	utcNow := time.Now().UTC()
	return utcNow.Truncate(time.Second)
}
// generateID returns a random 128-bit identifier from crypto/rand, encoded as
// 32 lowercase hexadecimal characters. It panics if the system random source
// fails. It is a variable so that tests can substitute deterministic IDs.
var generateID = func() string {
	var raw [16]byte
	if _, readErr := rand.Read(raw[:]); readErr != nil {
		panic(readErr)
	}
	return hex.EncodeToString(raw[:])
}
var (
	// ErrEmpty is returned by Pull when no task is currently due.
	ErrEmpty = errors.New("empty")
	// ErrNotFound is returned by Delete when no task has the given ID.
	ErrNotFound = errors.New("task not found")
	// ErrLocked is wrapped in the error Delete returns when the task is
	// currently acquired via Pull; test for it with errors.Is.
	ErrLocked = errors.New("task is locked")
)