Skip to content

Commit

Permalink
Merge branch 'dev'
Browse files Browse the repository at this point in the history
  • Loading branch information
lxzan committed Dec 15, 2023
2 parents f529662 + dcfbb25 commit c5ec189
Show file tree
Hide file tree
Showing 3 changed files with 58 additions and 41 deletions.
73 changes: 43 additions & 30 deletions groups/group.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
"github.com/lxzan/concurrency/internal"
"sync"
"sync/atomic"
"time"
)

Expand All @@ -13,48 +14,51 @@ const (
defaultWaitTimeout = 60 * time.Second // 默认线程同步等待超时
)

var defaultCaller Caller = func(args any, f func(any) error) error { return f(args) }

type (
Caller func(args any, f func(any) error) error

Group[T any] struct {
options *options
mu *sync.Mutex // 锁
errs []error // 错误
done chan bool // 信号
q []T // 任务队列
taskDone int64 // 已完成任务数量
taskTotal int64 // 总任务数量
OnMessage func(args T) error // 任务处理
OnError func(err error) // 错误处理
options *options // 配置
mu sync.Mutex // 锁
ctx context.Context // 上下文
cancelFunc context.CancelFunc // 取消函数
canceled atomic.Uint32 // 是否已取消
errs []error // 错误
done chan bool // 完成信号
q []T // 任务队列
taskDone int64 // 已完成任务数量
taskTotal int64 // 总任务数量
OnMessage func(args T) error // 任务处理
OnError func(err error) // 错误处理
}
)

// New 新建一个任务集
func New[T any](opts ...Option) *Group[T] {
o := &options{
timeout: defaultWaitTimeout,
concurrency: defaultConcurrency,
caller: func(args any, f func(any) error) error { return f(args) },
}
o := new(options)
opts = append(opts, withInitialize())
for _, f := range opts {
f(o)
}

c := &Group[T]{
options: o,
mu: &sync.Mutex{},
q: make([]T, 0),
taskDone: 0,
done: make(chan bool),
}
c.ctx, c.cancelFunc = context.WithTimeout(context.Background(), o.timeout)
c.OnMessage = func(args T) error {
return nil
}
c.OnError = func(err error) {}

return c
}

func (c *Group[T]) clear() {
func (c *Group[T]) clearJob() {
c.mu.Lock()
c.q = c.q[:0]
c.mu.Unlock()
Expand Down Expand Up @@ -82,19 +86,21 @@ func (c *Group[T]) incrAndIsDone() bool {
return ok
}

func (c *Group[T]) hasError() bool {
func (c *Group[T]) getError() error {
c.mu.Lock()
defer c.mu.Unlock()
return len(c.errs) > 0
return errors.Join(c.errs...)
}

func (c *Group[T]) jobFunc(v any) error {
if c.canceled.Load() == 1 {
return nil
}
return c.OnMessage(v.(T))
}

func (c *Group[T]) do(args T) {
if err := c.options.caller(args, func(v any) error {
if c.options.cancel && c.hasError() {
return nil
}
return c.OnMessage(v.(T))
}); err != nil {
if err := c.options.caller(args, c.jobFunc); err != nil {
c.mu.Lock()
c.errs = append(c.errs, err)
c.mu.Unlock()
Expand All @@ -119,6 +125,13 @@ func (c *Group[T]) Len() int {
return x
}

// Cancel 取消队列中剩余任务的执行
func (c *Group[T]) Cancel() {
if c.canceled.CompareAndSwap(0, 1) {
c.cancelFunc()
}
}

// Push 往任务队列中追加任务
func (c *Group[T]) Push(eles ...T) {
c.mu.Lock()
Expand Down Expand Up @@ -148,13 +161,13 @@ func (c *Group[T]) Start() error {
}
}

ctx, cancel := context.WithTimeout(context.Background(), c.options.timeout)
defer cancel()
defer c.cancelFunc()

select {
case <-c.done:
return errors.Join(c.errs...)
case <-ctx.Done():
c.clear()
return ctx.Err()
return c.getError()
case <-c.ctx.Done():
c.clearJob()
return c.ctx.Err()
}
}
5 changes: 4 additions & 1 deletion groups/group_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -82,7 +82,7 @@ func TestNewTaskGroup(t *testing.T) {
})

t.Run("cancel", func(t *testing.T) {
ctl := New[int](WithCancel(), WithConcurrency(1))
ctl := New[int](WithConcurrency(1))
ctl.Push(1, 3, 5)
arr := make([]int, 0)
ctl.OnMessage = func(args int) error {
Expand All @@ -96,6 +96,9 @@ func TestNewTaskGroup(t *testing.T) {
return nil
}
}
ctl.OnError = func(err error) {
ctl.Cancel()
}
err := ctl.Start()
as.Error(err)
as.ElementsMatch(arr, []int{1, 3})
Expand Down
21 changes: 11 additions & 10 deletions groups/options.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package groups

import (
"github.com/lxzan/concurrency/internal"
"github.com/pkg/errors"
"runtime"
"time"
Expand All @@ -11,7 +12,6 @@ type options struct {
timeout time.Duration
concurrency int64
caller Caller
cancel bool
}

type Option func(o *options)
Expand All @@ -24,16 +24,9 @@ func WithTimeout(t time.Duration) Option {
}

// WithConcurrency 设置最大并发
func WithConcurrency(n int64) Option {
func WithConcurrency(n uint32) Option {
return func(o *options) {
o.concurrency = n
}
}

// WithCancel 设置遇到错误放弃执行剩余任务
func WithCancel() Option {
return func(o *options) {
o.cancel = true
o.concurrency = int64(n)
}
}

Expand All @@ -55,3 +48,11 @@ func WithRecovery() Option {
}
}
}

func withInitialize() Option {
return func(o *options) {
o.timeout = internal.SelectValue(o.timeout <= 0, defaultWaitTimeout, o.timeout)
o.concurrency = internal.SelectValue(o.concurrency <= 0, defaultConcurrency, o.concurrency)
o.caller = internal.SelectValue(o.caller == nil, defaultCaller, o.caller)
}
}

0 comments on commit c5ec189

Please sign in to comment.