From 0b47349abb26c03bd83354a1f0e6fc323a5109c0 Mon Sep 17 00:00:00 2001 From: longyue0521 Date: Thu, 25 Aug 2022 13:59:52 +0800 Subject: [PATCH 01/15] =?UTF-8?q?pool:=20TaskPool=20=E5=AE=9E=E7=8E=B0=20#?= =?UTF-8?q?50?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit TaskPool管理有限状态机,TaskExecutor负责启动任务,优雅关闭,强制退出等操作 Signed-off-by: longyue0521 --- .CHANGELOG.md | 1 + pool/task_pool.go | 387 ++++++++++++++++++++- pool/task_pool_test.go | 769 +++++++++++++++++++++++++++++++++++++++++ 3 files changed, 1143 insertions(+), 14 deletions(-) create mode 100644 pool/task_pool_test.go diff --git a/.CHANGELOG.md b/.CHANGELOG.md index f8460724..adf78e2c 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -10,3 +10,4 @@ [ekit: 实现了 LinkedList Delete](https://github.com/gotomicro/ekit/pull/38) [ekit: 修复 Pool TestPool](https://github.com/gotomicro/ekit/pull/40) [ekit: 实现了 LinkedList Range](https://github.com/gotomicro/ekit/pull/46) +[ekit: 实现了 TaskPool](https://github.com/gotomicro/ekit/issues/50) \ No newline at end of file diff --git a/pool/task_pool.go b/pool/task_pool.go index 39d2a782..c8b07687 100644 --- a/pool/task_pool.go +++ b/pool/task_pool.go @@ -14,7 +14,30 @@ package pool -import "context" +import ( + "context" + "errors" + "fmt" + "reflect" + "runtime" + "sync/atomic" + "time" +) + +var ( + StateCreated int32 = 1 + StateRunning int32 = 2 + StateClosing int32 = 3 + StateStopped int32 = 4 + + ErrTaskPoolIsNotRunning = errors.New("pool: TaskPool未运行") + ErrTaskPoolIsClosing = errors.New("pool:TaskPool关闭中") + ErrTaskPoolIsStopped = errors.New("pool: TaskPool已停止") + ErrTaskIsInvalid = errors.New("pool: Task非法") + + _ TaskPool = &BlockQueueTaskPool{} + panicBuffLen = 2048 +) // TaskPool 任务池 type TaskPool interface { @@ -52,35 +75,151 @@ type TaskFunc func(ctx context.Context) error // Run 执行任务 // 超时控制取决于衍生出 TaskFunc 的方法 -func (t TaskFunc) Run(ctx context.Context) error { - return t(ctx) -} +func (t TaskFunc) Run(ctx context.Context) error { return t(ctx) } + +// FastTask 快任务,耗时较短,每个任务一个Goroutine +type FastTask struct{ task Task } + +func (f *FastTask) Run(ctx context.Context) error { return f.task.Run(ctx) } + +// SlowTask 慢任务,耗时较长,运行在固定个数Goroutine上 +type SlowTask struct{ task Task } + +func (s *SlowTask) Run(ctx context.Context) error { return s.task.Run(ctx) } // BlockQueueTaskPool 并发阻塞的任务池 type BlockQueueTaskPool struct { + state atomic.Int32 + numGo int + lenQu int + + taskExecutor *TaskExecutor + fastTaskQueue chan<- Task + slowTaskQueue chan<- Task + + // 缓存taskExecutor结果 + done <-chan struct{} + tasks []Task + + locked int32 } // NewBlockQueueTaskPool 创建一个新的 BlockQueueTaskPool // concurrency 是并发数,即最多允许多少个 goroutine 执行任务 // queueSize 是队列大小,即最多有多少个任务在等待调度 func NewBlockQueueTaskPool(concurrency int, queueSize int) (*BlockQueueTaskPool, error) { - return &BlockQueueTaskPool{}, nil + taskExecutor := NewTaskExecutor(concurrency, queueSize) + b := &BlockQueueTaskPool{ + numGo: concurrency, + lenQu: queueSize, + done: make(chan struct{}), + taskExecutor: taskExecutor, + fastTaskQueue: taskExecutor.FastQueue(), + slowTaskQueue: taskExecutor.SlowQueue(), + locked: int32(101), + } + b.state.Store(StateCreated) + return b, nil +} + +func (b *BlockQueueTaskPool) State() int32 { + for { + state := b.state.Load() + if state != b.locked { + return state + } + } } // Submit 提交一个任务 // 如果此时队列已满,那么将会阻塞调用者。 // 如果因为 ctx 的原因返回,那么将会返回 ctx.Err() // 在调用 Start 前后都可以调用 Submit -func (b *BlockQueueTaskPool) Submit(ctx context.Context, task func()) error { - // TODO implement me - panic("implement me") +func (b *BlockQueueTaskPool) Submit(ctx context.Context, task Task) error { + if task == nil || reflect.ValueOf(task).IsNil() { + return fmt.Errorf("%w", ErrTaskIsInvalid) + } + // todo: 用户未设置超时,可以考虑内部给个超时提交 + for { + + if b.state.Load() == StateClosing { + return fmt.Errorf("%w", ErrTaskPoolIsClosing) + } + + if b.state.Load() == StateStopped { + return fmt.Errorf("%w", ErrTaskPoolIsStopped) + } + + if b.state.CompareAndSwap(StateCreated, b.locked) { + ok, err := b.submit(ctx, task, func() chan<- Task { return b.chanByTask(task) }) + if ok || err != nil { + b.state.Swap(StateCreated) + return err + } + b.state.Swap(StateCreated) + } + + if b.state.CompareAndSwap(StateRunning, b.locked) { + ok, err := b.submit(ctx, task, func() chan<- Task { return b.chanByTask(task) }) + if ok || err != nil { + b.state.Swap(StateRunning) + return err + } + b.state.Swap(StateRunning) + } + } +} + +func (b *BlockQueueTaskPool) chanByTask(task Task) chan<- Task { + switch task.(type) { + case *SlowTask: + return b.slowTaskQueue + default: + // FastTask, TaskFunc, 用户自定义类型实现Task接口 + return b.fastTaskQueue + } +} + +func (b *BlockQueueTaskPool) submit(ctx context.Context, task Task, channel func() chan<- Task) (ok bool, err error) { + // 此处channel() <- task不会出现panic——因为channel被关闭而panic + // 代码执行到submit时TaskPool处于lock状态 + // 要关闭channel需要TaskPool处于RUNNING状态,Shutdown/ShutdownNow才能成功 + select { + case <-ctx.Done(): + return false, fmt.Errorf("%w", ctx.Err()) + case channel() <- task: + return true, nil + default: + } + return false, nil } // Start 开始调度任务执行 // Start 之后,调用者可以继续使用 Submit 提交任务 func (b *BlockQueueTaskPool) Start() error { - // TODO implement me - panic("implement me") + + for { + + if b.state.Load() == StateClosing { + return fmt.Errorf("%w", ErrTaskPoolIsClosing) + } + + if b.state.Load() == StateStopped { + return fmt.Errorf("%w", ErrTaskPoolIsStopped) + } + + if b.state.Load() == StateRunning { + // 重复调用,返回缓存结果 + return nil + } + + if b.state.CompareAndSwap(StateCreated, StateRunning) { + // todo: 启动task调度器,开始执行task + b.taskExecutor.Start() + return nil + } + } + //return nil } // Shutdown 将会拒绝提交新的任务,但是会继续执行已提交任务 @@ -88,12 +227,232 @@ func (b *BlockQueueTaskPool) Start() error { // Shutdown 会负责关闭返回的 chan // Shutdown 无法中断正在执行的任务 func (b *BlockQueueTaskPool) Shutdown() (<-chan struct{}, error) { - // TODO implement me - panic("implement me") + + for { + + if b.state.Load() == StateCreated { + return nil, fmt.Errorf("%w", ErrTaskPoolIsNotRunning) + } + + if b.state.Load() == StateStopped { + // 重复调用时,恰好前一个Shutdown调用将状态迁移为StateStopped + // 这种情况与先调用ShutdownNow状态迁移为StateStopped再调用Shutdown效果一样 + return nil, fmt.Errorf("%w", ErrTaskPoolIsStopped) + } + + if b.state.Load() == StateClosing { + // 重复调用,返回缓存结果 + return b.done, nil + } + + if b.state.CompareAndSwap(StateRunning, StateClosing) { + // todo: 等待task完成,关闭b.done + // 监听done信号,然后完成状态迁移StateClosing -> StateStopped + b.done = b.taskExecutor.Close() // 需要注释掉close(b.done) + //close(b.done) + go func() { + select { + case <-b.done: + b.state.CompareAndSwap(StateClosing, StateStopped) + //ok := b.state.CompareAndSwap(StateClosing, StateStopped) + //for !ok { + // ok = b.state.CompareAndSwap(StateClosing, StateStopped) + //} + //return + } + }() + return b.done, nil + } + + } + //return nil, nil } // ShutdownNow 立刻关闭任务池,并且返回所有剩余未执行的任务(不包含正在执行的任务) func (b *BlockQueueTaskPool) ShutdownNow() ([]Task, error) { - // TODO implement me - panic("implement me") + + for { + + if b.state.Load() == StateCreated { + return nil, fmt.Errorf("%w", ErrTaskPoolIsNotRunning) + } + + if b.state.Load() == StateClosing { + return nil, fmt.Errorf("%w", ErrTaskPoolIsClosing) + } + + if b.state.Load() == StateStopped { + // 重复调用,返回缓存结果 + return b.tasks, nil + } + if b.state.CompareAndSwap(StateRunning, StateStopped) { + b.tasks = b.taskExecutor.Stop() + return b.tasks, nil + } + } +} + +type TaskExecutor struct { + slowTasks chan Task + fastTasks chan Task + + maxGo int32 + // + done chan struct{} + // 利用ctx充当内部信号 + ctx context.Context + cancelFunc context.CancelFunc + // 统计 + numSlow atomic.Int32 + numFast atomic.Int32 +} + +func NewTaskExecutor(maxGo int, queueSize int) *TaskExecutor { + t := &TaskExecutor{maxGo: int32(maxGo), done: make(chan struct{})} + t.ctx, t.cancelFunc = context.WithCancel(context.Background()) + t.slowTasks = make(chan Task, queueSize) + t.fastTasks = make(chan Task, queueSize) + return t +} + +func (t *TaskExecutor) Start() { + go t.startSlowTasks() + go t.startFastTasks() +} + +func (t *TaskExecutor) startFastTasks() { + for { + select { + case <-t.ctx.Done(): + return + case task := <-t.fastTasks: + // handle close(t.fastTasks) + if task == nil { + return + } + go func() { + t.numFast.Add(1) + //log.Println("fast N", t.numFast.Add(1)) + defer func() { + // 恢复统计 + t.numFast.Add(-1) + + // handle panic + if r := recover(); r != nil { + buf := make([]byte, panicBuffLen) + buf = buf[:runtime.Stack(buf, false)] + //fmt.Printf("[PANIC]:\t%+v\n%s\n", r, buf) + } + }() + // todo: handle err + task.Run(t.ctx) + }() + } + } +} + +func (t *TaskExecutor) startSlowTasks() { + + for { + n := atomic.AddInt32(&t.maxGo, -1) + if n < 0 { + atomic.AddInt32(&t.maxGo, 1) + continue + } + //log.Println("maxGo=", n) + select { + case <-t.ctx.Done(): + return + case task := <-t.slowTasks: + // handle close(t.slowTasks) + if task == nil { + return + } + go func() { + t.numSlow.Add(1) + //log.Println("slow N=", t.numSlow.Add(1)) + defer func() { + // 恢复 + atomic.AddInt32(&t.maxGo, 1) + t.numSlow.Add(-1) + + // handle panic + if r := recover(); r != nil { + buf := make([]byte, panicBuffLen) + buf = buf[:runtime.Stack(buf, false)] + //fmt.Printf("[PANIC]:\t%+v\n%s\n", r, buf) + } + }() + // todo: handle err + task.Run(t.ctx) + }() + } + } +} + +func (t *TaskExecutor) FastQueue() chan<- Task { + return t.fastTasks +} + +func (t *TaskExecutor) SlowQueue() chan<- Task { + return t.slowTasks +} + +func (t *TaskExecutor) NumRunningSlow() int32 { + return t.numSlow.Load() +} + +func (t *TaskExecutor) NumRunningFast() int32 { + return t.numFast.Load() +} + +// Close 优雅关闭 +// 目标:不但希望正在运行中的任务自然退出,还希望队列中等待的任务也能启动执行并自然退出 +// 策略:先将所有队列中的任务启动并执行(清空队列),再等待全部运行中的任务自然退出。 +func (t *TaskExecutor) Close() <-chan struct{} { + + // 先关闭等待队列不再允许提交 + // 同时任务启动循环能够通过Task==nil来终止循环 + close(t.slowTasks) + close(t.fastTasks) + + go func() { + + // 检查三次是因为可能出现: + // 两队列中有任务且正在创建启动任务尚未执行计数,恰巧此时正在运行中的任务为0 + for i := 0; i < 3; i++ { + + // 确保所有运行中任务也自然退出 + for t.numFast.Load() != 0 || t.numSlow.Load() != 0 { + time.Sleep(time.Second) + } + } + + // 通知外部调用者 + close(t.done) + }() + + return t.done +} + +// Stop 强制关闭 +// 目标:立刻关闭并且返回所有剩下未执行的任务 +// 策略:关闭等待队列不再接受新任务,中断任务启动循环,清空等待队列并保存返回 +func (t *TaskExecutor) Stop() []Task { + + close(t.fastTasks) + close(t.slowTasks) + + // 发送中断信号,中断任务启动循环 + t.cancelFunc() + + // 清空队列并保存 + var tasks []Task + for task := range t.fastTasks { + tasks = append(tasks, task) + } + for task := range t.slowTasks { + tasks = append(tasks, task) + } + return tasks } diff --git a/pool/task_pool_test.go b/pool/task_pool_test.go new file mode 100644 index 00000000..bfcfffbf --- /dev/null +++ b/pool/task_pool_test.go @@ -0,0 +1,769 @@ +// Copyright 2021 gotomicro +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package pool + +import ( + "context" + "fmt" + "github.com/stretchr/testify/assert" + "sync" + "testing" + "time" +) + +/* +TaskPool有限状态机 + Start/Submit/ShutdownNow() Error + \ / + Shutdown() --> CLOSING ---等待所有任务结束 + Submit()nil--执行中状态迁移--Submit() / \----------/ \----------/ + \ / \ / / +New() --> CREATED -- Start() ---> RUNNING -- -- + \ / \ / \ Start/Submit/Shutdown() Error + Shutdown/ShutdownNow()Error Start() \ \ / + ShutdownNow() ---> STOPPED -- ShutdownNow() --> STOPPED +*/ + +func TestTaskPool_Start(t *testing.T) { + /* + todo: Start + 1. happy - [x] change state from CREATED to RUNNING - done + - 非阻塞 task 调度器开始工作,开始执行工作 + - [x] Start多次,保证只运行一次,或者报错——TaskPool已经启动 + 2. sad - + CLOSING state -> start error,多次运行结果一致 + STOPPED state -> start error多次运行结果一致 + */ + + t.Parallel() + + pool, _ := NewBlockQueueTaskPool(1, 1) + assert.Equal(t, StateCreated, pool.State()) + errChan := make(chan error) + go func() { + // 多次运行结果一直 + var wg sync.WaitGroup + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + errChan <- pool.Start() + wg.Done() + }() + } + wg.Wait() + close(errChan) + }() + for err := range errChan { + assert.NoError(t, err) + } + assert.Equal(t, StateRunning, pool.State()) +} + +func TestTaskPool_Submit(t *testing.T) { + + //todo: Submit + // TaskPoolRunning Shutdown/ShutdownNow前 + // 在TaskPool所有状态中都可以提交,有的成功/阻塞,有的立即失败。 + // [x]ctx时间内提交成功 + // [x]ctx超时,提交失败给出错误信息 + // [x] 监听状态变化,从running->closing/stopped + // [x] Shutdown后,状态变迁,需要检查出并报错 ErrTaskPoolIsClosing + // [x] ShutdownNow后状态变迁,需要检查出并报错,ErrTaskPoolIsStopped + + t.Run("提交Task阻塞", func(t *testing.T) { + + t.Run("TaskPool状态由Created变为Running", func(t *testing.T) { + t.Parallel() + // 为了准确模拟,内部自定义一个pool + pool, _ := NewBlockQueueTaskPool(1, 1) + assert.Equal(t, StateCreated, pool.State()) + errChan := make(chan error) + go func() { + errChan <- pool.Start() + }() + testSubmitBlockingAndTimeout(t, pool) + assert.NoError(t, <-errChan) + assert.Equal(t, StateRunning, pool.State()) + }) + + t.Run("TaskPool状态由Running变为Closing", func(t *testing.T) { + t.Parallel() + // 为了准确模拟,内部自定义一个pool + pool, _ := NewBlockQueueTaskPool(1, 2) + assert.Equal(t, StateCreated, pool.State()) + err := pool.Start() + assert.NoError(t, err) + assert.Equal(t, StateRunning, pool.State()) + + // 模拟阻塞提交 + n := 10 + submitErrChan := make(chan error, 1) + for i := 0; i < n; i++ { + go func() { + err := pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + <-time.After(10 * time.Millisecond) + return nil + })}) + if err != nil { + submitErrChan <- err + } + }() + } + + // 调用Shutdown使TaskPool状态发生迁移 + type Result struct { + done <-chan struct{} + err error + } + resultChan := make(chan Result) + go func() { + <-time.After(time.Millisecond) + done, err := pool.Shutdown() + resultChan <- Result{done: done, err: err} + }() + + r := <-resultChan + + // 阻塞中的任务报错,证明处于TaskPool处于StateClosing状态 + assert.ErrorIs(t, <-submitErrChan, ErrTaskPoolIsClosing) + + // Shutdown调用成功 + assert.NoError(t, r.err) + select { + case <-r.done: + // 等待状态迁移完成,并最终进入StateStopped状态 + <-time.After(10 * time.Millisecond) + assert.Equal(t, StateStopped, pool.State()) + return + } + }) + + t.Run("TaskPool状态由Running变为Stopped", func(t *testing.T) { + t.Parallel() + // 为了准确模拟,内部自定义一个pool + pool, _ := NewBlockQueueTaskPool(1, 2) + assert.Equal(t, StateCreated, pool.State()) + err := pool.Start() + assert.NoError(t, err) + assert.Equal(t, StateRunning, pool.State()) + + // 模拟阻塞提交 + n := 5 + submitErrChan := make(chan error, 1) + for i := 0; i < n; i++ { + go func() { + err := pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + <-time.After(10 * time.Millisecond) + return nil + })}) + if err != nil { + submitErrChan <- err + } + }() + } + + // 并发调用ShutdownNow + type Result struct { + tasks []Task + err error + } + result := make(chan Result, 1) + go func() { + <-time.After(time.Millisecond) + tasks, err := pool.ShutdownNow() + result <- Result{tasks: tasks, err: err} + }() + + r := <-result + assert.NoError(t, r.err) + assert.NotNil(t, r.tasks) + + assert.ErrorIs(t, <-submitErrChan, ErrTaskPoolIsStopped) + assert.Equal(t, StateStopped, pool.State()) + }) + + }) + +} + +func TestTaskPool_Shutdown(t *testing.T) { + t.Parallel() + + pool, _ := NewBlockQueueTaskPool(1, 1) + assert.Equal(t, StateCreated, pool.State()) + err := pool.Start() + assert.Equal(t, StateRunning, pool.State()) + assert.NoError(t, err) + + // 第一次调用 + done, err := pool.Shutdown() + assert.NoError(t, err) + + // 第二次调用 + select { + case <-done: + break + default: + done2, err2 := pool.Shutdown() + assert.Equal(t, done2, done) + assert.Equal(t, err2, err) + } + + <-time.After(5 * time.Millisecond) + assert.Equal(t, StateStopped, pool.State()) + + // 第一个Shutdown将状态迁移至StateStopped + // 第三次调用 + done, err = pool.Shutdown() + assert.Nil(t, done) + assert.ErrorIs(t, err, ErrTaskPoolIsStopped) +} + +func TestTestPool_ShutdownNow(t *testing.T) { + + t.Parallel() + + pool, _ := NewBlockQueueTaskPool(1, 1) + assert.Equal(t, StateCreated, pool.State()) + err := pool.Start() + assert.Equal(t, StateRunning, pool.State()) + assert.NoError(t, err) + + type result struct { + tasks []Task + err error + } + n := 3 + c := make(chan result, n) + + for i := 0; i < n; i++ { + go func() { + + tasks, err := pool.ShutdownNow() + c <- result{tasks: tasks, err: err} + }() + } + + for i := 0; i < n; i++ { + r := <-c + assert.Nil(t, r.tasks) + assert.NoError(t, r.err) + assert.Equal(t, StateStopped, pool.State()) + } +} + +func TestTaskPool__Created_(t *testing.T) { + t.Parallel() + + n, q := 1, 1 + pool, err := NewBlockQueueTaskPool(n, q) + assert.NoError(t, err) + assert.NotNil(t, pool) + assert.Equal(t, StateCreated, pool.State()) + + t.Run("Submit", func(t *testing.T) { + t.Parallel() + + t.Run("提交非法Task", func(t *testing.T) { + t.Parallel() + testSubmitInvalidTask(t, pool) + assert.Equal(t, StateCreated, pool.State()) + }) + + t.Run("正常提交Task", func(t *testing.T) { + t.Parallel() + testSubmitValidTask(t, pool) + assert.Equal(t, StateCreated, pool.State()) + }) + + t.Run("阻塞提交并导致超时", func(t *testing.T) { + t.Parallel() + // 为了准确模拟,内部自定义一个pool + pool, _ := NewBlockQueueTaskPool(1, 1) + assert.Equal(t, StateCreated, pool.State()) + testSubmitBlockingAndTimeout(t, pool) + assert.Equal(t, StateCreated, pool.State()) + }) + }) + + t.Run("Shutdown", func(t *testing.T) { + t.Parallel() + + done, err := pool.Shutdown() + assert.Nil(t, done) + assert.ErrorIs(t, err, ErrTaskPoolIsNotRunning) + assert.Equal(t, StateCreated, pool.State()) + }) + + t.Run("ShutdownNow", func(t *testing.T) { + t.Parallel() + + tasks, err := pool.ShutdownNow() + assert.Nil(t, tasks) + assert.ErrorIs(t, err, ErrTaskPoolIsNotRunning) + assert.Equal(t, StateCreated, pool.State()) + }) + +} + +func TestTaskPool__Running_(t *testing.T) { + t.Parallel() + + pool, _ := NewBlockQueueTaskPool(1, 1) + assert.Equal(t, StateCreated, pool.State()) + err := pool.Start() + assert.Equal(t, StateRunning, pool.State()) + assert.NoError(t, err) + + t.Run("Start", func(t *testing.T) { + t.Parallel() + err = pool.Start() + // todo: 调度器只启动一次 + assert.Equal(t, StateRunning, pool.State()) + assert.NoError(t, err) + }) + + t.Run("Submit", func(t *testing.T) { + t.Parallel() + + t.Run("提交非法Task", func(t *testing.T) { + t.Parallel() + testSubmitInvalidTask(t, pool) + assert.Equal(t, StateRunning, pool.State()) + }) + + t.Run("正常提交Task", func(t *testing.T) { + t.Parallel() + testSubmitValidTask(t, pool) + assert.Equal(t, StateRunning, pool.State()) + }) + + t.Run("阻塞提交并导致超时", func(t *testing.T) { + t.Parallel() + // 为了准确模拟,内部自定义一个pool + pool, _ := NewBlockQueueTaskPool(1, 1) + assert.Equal(t, StateCreated, pool.State()) + err := pool.Start() + assert.Equal(t, StateRunning, pool.State()) + assert.NoError(t, err) + + testSubmitBlockingAndTimeout(t, pool) + + assert.Equal(t, StateRunning, pool.State()) + }) + }) +} + +func TestTaskPool__Closing_(t *testing.T) { + + t.Parallel() + + pool, _ := NewBlockQueueTaskPool(1, 0) + assert.Equal(t, StateCreated, pool.State()) + err := pool.Start() + assert.Equal(t, StateRunning, pool.State()) + assert.NoError(t, err) + + // 提交任务 + num := 10 + for i := 0; i < num; i++ { + go func() { + pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + <-time.After(10 * time.Millisecond) + return nil + })}) + }() + } + + t.Run("Start", func(t *testing.T) { + t.Parallel() + + pool, _ := NewBlockQueueTaskPool(1, 10) + assert.Equal(t, StateCreated, pool.State()) + err := pool.Start() + assert.Equal(t, StateRunning, pool.State()) + assert.NoError(t, err) + + // 提交任务 + num := 10 + for i := 0; i < num; i++ { + go func() { + pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + <-time.After(10 * time.Millisecond) + return nil + })}) + }() + } + + done, err := pool.Shutdown() + assert.NoError(t, err) + assert.ErrorIs(t, pool.Start(), ErrTaskPoolIsClosing) + select { + case <-done: + <-time.After(5 * time.Millisecond) + assert.Equal(t, StateStopped, pool.State()) + return + } + }) + + t.Run("ShutdownNow", func(t *testing.T) { + t.Parallel() + + pool, _ := NewBlockQueueTaskPool(1, 0) + assert.Equal(t, StateCreated, pool.State()) + err := pool.Start() + assert.Equal(t, StateRunning, pool.State()) + assert.NoError(t, err) + + // 提交任务 + num := 10 + for i := 0; i < num; i++ { + go func() { + pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + <-time.After(10 * time.Millisecond) + return nil + })}) + }() + } + + done, err := pool.Shutdown() + assert.NoError(t, err) + + tasks, err := pool.ShutdownNow() + assert.ErrorIs(t, err, ErrTaskPoolIsClosing) + assert.Nil(t, tasks) + + select { + case <-done: + <-time.After(50 * time.Millisecond) + assert.Equal(t, StateStopped, pool.State()) + return + } + }) + +} + +func TestTestPool__Stopped_(t *testing.T) { + + n := 2 + pool, _ := NewBlockQueueTaskPool(1, n) + assert.Equal(t, StateCreated, pool.State()) + err := pool.Start() + assert.Equal(t, StateRunning, pool.State()) + assert.NoError(t, err) + + // 模拟阻塞提交 + for i := 0; i < 3*n; i++ { + go func() { + pool.Submit(context.Background(), &FastTask{task: TaskFunc(func(ctx context.Context) error { + <-time.After(2 * time.Millisecond) + return nil + })}) + pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + <-time.After(10 * time.Millisecond) + return nil + })}) + }() + } + + tasks, err := pool.ShutdownNow() + //assert.NotNil(t, tasks) + assert.NoError(t, err) + assert.Equal(t, StateStopped, pool.State()) + + t.Run("Start", func(t *testing.T) { + t.Parallel() + assert.ErrorIs(t, pool.Start(), ErrTaskPoolIsStopped) + assert.Equal(t, StateStopped, pool.State()) + }) + + t.Run("Submit", func(t *testing.T) { + t.Parallel() + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { return nil })) + assert.ErrorIs(t, err, ErrTaskPoolIsStopped) + assert.Equal(t, StateStopped, pool.State()) + }) + + t.Run("Shutdown", func(t *testing.T) { + t.Parallel() + + done, err := pool.Shutdown() + assert.Nil(t, done) + assert.ErrorIs(t, err, ErrTaskPoolIsStopped) + assert.Equal(t, StateStopped, pool.State()) + }) + + t.Run("ShutdownNow", func(t *testing.T) { + t.Parallel() + // 多次调用返回相同结果 + tasks2, err := pool.ShutdownNow() + assert.NoError(t, err) + assert.Equal(t, tasks2, tasks) + }) +} + +func testSubmitBlockingAndTimeout(t *testing.T, pool *BlockQueueTaskPool) { + err := pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + <-time.After(2 * time.Millisecond) + return nil + })}) + assert.NoError(t, err) + + n := 2 + errChan := make(chan error, n) + for i := 0; i < n; i++ { + go func() { + ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) + defer cancel() + err := pool.Submit(ctx, &SlowTask{task: TaskFunc(func(ctx context.Context) error { + <-time.After(2 * time.Millisecond) + return nil + })}) + if err != nil { + errChan <- err + } + }() + } + assert.ErrorIs(t, <-errChan, context.DeadlineExceeded) +} + +func testSubmitValidTask(t *testing.T, pool *BlockQueueTaskPool) { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { return nil })) + assert.NoError(t, err) +} + +func testSubmitInvalidTask(t *testing.T, pool *BlockQueueTaskPool) { + invalidTasks := map[string]Task{"*SlowTask": (*SlowTask)(nil), "*FastTask": (*FastTask)(nil), "nil": nil, "TaskFunc(nil)": TaskFunc(nil)} + + for name, task := range invalidTasks { + t.Run(name, func(t *testing.T) { + err := pool.Submit(context.Background(), task) + assert.ErrorIs(t, err, ErrTaskIsInvalid) + }) + } +} + +func TestTaskExecutor(t *testing.T) { + t.Parallel() + + /* todo: + [x]快慢任务分离 + [x]快任务没有goroutine限制,提供方法检查个数 + [x]慢任务占用固定个数goroutine,提供方法检查个数 + [x]任务的panic处理 + []任务的error处理 + todo: [x]Closing-优雅关闭 + [x]返回一个chan, + 供调用者监听,调用者从chan拿到消息即表明所有任务结束 + [x]等待任务自然结束 + [x]关闭chan——自动终止启动循环 + [x]将队列中等待的任务启动执行 + [x]等待未完成任务; + [x]Stop-强制关闭 + [x] 关闭chan + [x] 终止任务启动循环 + [x] 将队列中清空并未完成的任务返回 + */ + + t.Run("优雅关闭", func(t *testing.T) { + + t.Parallel() + + maxGo, numTasks := 2, 5 + n := 4 * numTasks + + ex := NewTaskExecutor(maxGo, n) + ex.Start() + + // 注意:添加Task后需要调整否者会阻塞 + resultChan := make(chan struct{}, n) + + go func() { + // chan may be closed + defer func() { + if r := recover(); r != nil { + // 发送失败,也算执行了 + resultChan <- struct{}{} + t.Log(fmt.Errorf("%w:%#v", ErrTaskPoolIsStopped, r)) + } + }() + for i := 0; i < numTasks; i++ { + ex.SlowQueue() <- &SlowTask{task: TaskFunc(func(ctx context.Context) error { + resultChan <- struct{}{} + <-time.After(5 * time.Millisecond) + return nil + })} + // panic slow task + ex.SlowQueue() <- &SlowTask{task: TaskFunc(func(ctx context.Context) error { + resultChan <- struct{}{} + panic("slow task ") + return nil + })} + } + }() + go func() { + // chan() may be closed + defer func() { + if r := recover(); r != nil { + // 发送失败,也算执行了 + resultChan <- struct{}{} + t.Log(fmt.Errorf("%w:%#v", ErrTaskPoolIsStopped, r)) + } + }() + for i := 0; i < numTasks; i++ { + ex.FastQueue() <- &FastTask{task: TaskFunc(func(ctx context.Context) error { + resultChan <- struct{}{} + <-time.After(2 * time.Millisecond) + return nil + })} + + // panic fast task + ex.FastQueue() <- &FastTask{task: TaskFunc(func(ctx context.Context) error { + resultChan <- struct{}{} + panic("fast task") + return nil + })} + } + }() + + // 等待任务开始执行 + <-time.After(100 * time.Millisecond) + + select { + case <-ex.Close(): + assert.Equal(t, n, 4*numTasks) + assert.Equal(t, int32(0), ex.NumRunningSlow()) + assert.Equal(t, int32(0), ex.NumRunningFast()) + close(resultChan) + num := 0 + for r := range resultChan { + if r == struct{}{} { + num++ + } + } + assert.Equal(t, n, num) + return + } + }) + + t.Run("强制关闭", func(t *testing.T) { + t.Parallel() + + maxGo, numTasks := 2, 5 + ex := NewTaskExecutor(maxGo, numTasks) + ex.Start() + + // 注意:确保n = len(slowTasks) + len(fastTasks) + n := 8 + resultChan := make(chan struct{}, n) + + slowTasks := []Task{ + &SlowTask{task: TaskFunc(func(ctx context.Context) error { + resultChan <- struct{}{} + <-time.After(5 * time.Millisecond) + return nil + })}, + // panic slow task + &SlowTask{task: TaskFunc(func(ctx context.Context) error { + resultChan <- struct{}{} + panic("slow task ") + return nil + })}, + &SlowTask{task: TaskFunc(func(ctx context.Context) error { + resultChan <- struct{}{} + <-time.After(5 * time.Millisecond) + return nil + })}, + // panic slow task + &SlowTask{task: TaskFunc(func(ctx context.Context) error { + resultChan <- struct{}{} + panic("slow task ") + return nil + })}, + } + + fastTasks := []Task{ + &FastTask{task: TaskFunc(func(ctx context.Context) error { + resultChan <- struct{}{} + <-time.After(2 * time.Millisecond) + return nil + })}, + &FastTask{task: TaskFunc(func(ctx context.Context) error { + resultChan <- struct{}{} + panic("fast task") + return nil + })}, + &FastTask{task: TaskFunc(func(ctx context.Context) error { + resultChan <- struct{}{} + <-time.After(2 * time.Millisecond) + return nil + })}, + &FastTask{task: TaskFunc(func(ctx context.Context) error { + resultChan <- struct{}{} + panic("fast task") + return nil + })}, + } + go func() { + // chan may be closed + defer func() { + if r := recover(); r != nil { + // 发送任务时,chan被关闭,也算作执行中 + resultChan <- struct{}{} + } + }() + for _, task := range slowTasks { + ex.SlowQueue() <- task + } + }() + go func() { + // chan may be closed + defer func() { + if r := recover(); r != nil { + // 发送任务时,chan被关闭,也算作执行中 + resultChan <- struct{}{} + //t.Log(fmt.Errorf("%w:%#v", ErrTaskPoolIsStopped, r)) + } + }() + for _, task := range fastTasks { + ex.FastQueue() <- task + } + }() + + // 等待任务开始执行 + <-time.After(100 * time.Millisecond) + + tasks := ex.Stop() + + // 等待任务执行并回传信号 + <-time.After(100 * time.Millisecond) + + // 统计执行的任务数 + for ex.NumRunningFast() != 0 || ex.NumRunningSlow() != 0 { + time.Sleep(time.Millisecond) + } + close(resultChan) + num := 0 + for r := range resultChan { + if r == struct{}{} { + num++ + } + } + + assert.Equal(t, n, len(slowTasks)+len(fastTasks)) + assert.Equal(t, len(slowTasks)+len(fastTasks), num+len(tasks)) + }) + +} From 668754884770a36c2ed24c34ec4b00a3b7039660 Mon Sep 17 00:00:00 2001 From: longyue0521 Date: Sat, 27 Aug 2022 14:58:06 +0800 Subject: [PATCH 02/15] =?UTF-8?q?=E4=BF=AE=E5=A4=8DShutdownNow=20data=20ra?= =?UTF-8?q?ce=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: longyue0521 --- pool/task_pool.go | 41 +++++++++-------- pool/task_pool_test.go | 102 +++++++++++++++++++---------------------- 2 files changed, 69 insertions(+), 74 deletions(-) diff --git a/pool/task_pool.go b/pool/task_pool.go index c8b07687..e927fca7 100644 --- a/pool/task_pool.go +++ b/pool/task_pool.go @@ -20,6 +20,7 @@ import ( "fmt" "reflect" "runtime" + "sync" "sync/atomic" "time" ) @@ -98,9 +99,9 @@ type BlockQueueTaskPool struct { slowTaskQueue chan<- Task // 缓存taskExecutor结果 - done <-chan struct{} - tasks []Task - + done <-chan struct{} + tasks []Task + mux sync.RWMutex locked int32 } @@ -251,15 +252,13 @@ func (b *BlockQueueTaskPool) Shutdown() (<-chan struct{}, error) { b.done = b.taskExecutor.Close() // 需要注释掉close(b.done) //close(b.done) go func() { - select { - case <-b.done: - b.state.CompareAndSwap(StateClosing, StateStopped) - //ok := b.state.CompareAndSwap(StateClosing, StateStopped) - //for !ok { - // ok = b.state.CompareAndSwap(StateClosing, StateStopped) - //} - //return - } + <-b.done + b.state.CompareAndSwap(StateClosing, StateStopped) + //ok := b.state.CompareAndSwap(StateClosing, StateStopped) + //for !ok { + // ok = b.state.CompareAndSwap(StateClosing, StateStopped) + //} + //return }() return b.done, nil } @@ -283,11 +282,17 @@ func (b *BlockQueueTaskPool) ShutdownNow() ([]Task, error) { if b.state.Load() == StateStopped { // 重复调用,返回缓存结果 - return b.tasks, nil + b.mux.RLock() + tasks := append([]Task(nil), b.tasks...) + b.mux.RUnlock() + return tasks, nil } if b.state.CompareAndSwap(StateRunning, StateStopped) { + b.mux.Lock() b.tasks = b.taskExecutor.Stop() - return b.tasks, nil + tasks := append([]Task(nil), b.tasks...) + b.mux.Unlock() + return tasks, nil } } } @@ -341,11 +346,11 @@ func (t *TaskExecutor) startFastTasks() { if r := recover(); r != nil { buf := make([]byte, panicBuffLen) buf = buf[:runtime.Stack(buf, false)] - //fmt.Printf("[PANIC]:\t%+v\n%s\n", r, buf) + fmt.Printf("[PANIC]:\t%+v\n%s\n", r, buf) } }() // todo: handle err - task.Run(t.ctx) + fmt.Println(task.Run(t.ctx)) }() } } @@ -380,11 +385,11 @@ func (t *TaskExecutor) startSlowTasks() { if r := recover(); r != nil { buf := make([]byte, panicBuffLen) buf = buf[:runtime.Stack(buf, false)] - //fmt.Printf("[PANIC]:\t%+v\n%s\n", r, buf) + fmt.Printf("[PANIC]:\t%+v\n%s\n", r, buf) } }() // todo: handle err - task.Run(t.ctx) + fmt.Println(task.Run(t.ctx)) }() } } diff --git a/pool/task_pool_test.go b/pool/task_pool_test.go index bfcfffbf..127678cf 100644 --- a/pool/task_pool_test.go +++ b/pool/task_pool_test.go @@ -17,10 +17,11 @@ package pool import ( "context" "fmt" - "github.com/stretchr/testify/assert" "sync" "testing" "time" + + "github.com/stretchr/testify/assert" ) /* @@ -72,7 +73,7 @@ func TestTaskPool_Start(t *testing.T) { } func TestTaskPool_Submit(t *testing.T) { - + t.Parallel() //todo: Submit // TaskPoolRunning Shutdown/ShutdownNow前 // 在TaskPool所有状态中都可以提交,有的成功/阻塞,有的立即失败。 @@ -83,6 +84,7 @@ func TestTaskPool_Submit(t *testing.T) { // [x] ShutdownNow后状态变迁,需要检查出并报错,ErrTaskPoolIsStopped t.Run("提交Task阻塞", func(t *testing.T) { + t.Parallel() t.Run("TaskPool状态由Created变为Running", func(t *testing.T) { t.Parallel() @@ -141,13 +143,11 @@ func TestTaskPool_Submit(t *testing.T) { // Shutdown调用成功 assert.NoError(t, r.err) - select { - case <-r.done: - // 等待状态迁移完成,并最终进入StateStopped状态 - <-time.After(10 * time.Millisecond) - assert.Equal(t, StateStopped, pool.State()) - return - } + + <-r.done + // 等待状态迁移完成,并最终进入StateStopped状态 + <-time.After(100 * time.Millisecond) + assert.Equal(t, StateStopped, pool.State()) }) t.Run("TaskPool状态由Running变为Stopped", func(t *testing.T) { @@ -251,8 +251,8 @@ func TestTestPool_ShutdownNow(t *testing.T) { for i := 0; i < n; i++ { go func() { - tasks, err := pool.ShutdownNow() - c <- result{tasks: tasks, err: err} + tasks, er := pool.ShutdownNow() + c <- result{tasks: tasks, err: er} }() } @@ -377,15 +377,16 @@ func TestTaskPool__Closing_(t *testing.T) { assert.NoError(t, err) // 提交任务 - num := 10 - for i := 0; i < num; i++ { - go func() { - pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { - <-time.After(10 * time.Millisecond) - return nil - })}) - }() - } + //num := 10 + //for i := 0; i < num; i++ { + // go func() { + // err := pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + // <-time.After(10 * time.Millisecond) + // return nil + // })}) + // t.Log(err) + // }() + //} t.Run("Start", func(t *testing.T) { t.Parallel() @@ -400,22 +401,20 @@ func TestTaskPool__Closing_(t *testing.T) { num := 10 for i := 0; i < num; i++ { go func() { - pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + err := pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { <-time.After(10 * time.Millisecond) return nil })}) + t.Log(err) }() } done, err := pool.Shutdown() assert.NoError(t, err) assert.ErrorIs(t, pool.Start(), ErrTaskPoolIsClosing) - select { - case <-done: - <-time.After(5 * time.Millisecond) - assert.Equal(t, StateStopped, pool.State()) - return - } + <-done + <-time.After(10 * time.Millisecond) + assert.Equal(t, StateStopped, pool.State()) }) t.Run("ShutdownNow", func(t *testing.T) { @@ -431,10 +430,11 @@ func TestTaskPool__Closing_(t *testing.T) { num := 10 for i := 0; i < num; i++ { go func() { - pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + err := pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { <-time.After(10 * time.Millisecond) return nil })}) + t.Log(err) }() } @@ -445,18 +445,15 @@ func TestTaskPool__Closing_(t *testing.T) { assert.ErrorIs(t, err, ErrTaskPoolIsClosing) assert.Nil(t, tasks) - select { - case <-done: - <-time.After(50 * time.Millisecond) - assert.Equal(t, StateStopped, pool.State()) - return - } + <-done + <-time.After(50 * time.Millisecond) + assert.Equal(t, StateStopped, pool.State()) }) } func TestTestPool__Stopped_(t *testing.T) { - + t.Parallel() n := 2 pool, _ := NewBlockQueueTaskPool(1, n) assert.Equal(t, StateCreated, pool.State()) @@ -467,14 +464,16 @@ func TestTestPool__Stopped_(t *testing.T) { // 模拟阻塞提交 for i := 0; i < 3*n; i++ { go func() { - pool.Submit(context.Background(), &FastTask{task: TaskFunc(func(ctx context.Context) error { + err := pool.Submit(context.Background(), &FastTask{task: TaskFunc(func(ctx context.Context) error { <-time.After(2 * time.Millisecond) return nil })}) - pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + t.Log(err) + err = pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { <-time.After(10 * time.Millisecond) return nil })}) + t.Log(err) }() } @@ -609,7 +608,6 @@ func TestTaskExecutor(t *testing.T) { ex.SlowQueue() <- &SlowTask{task: TaskFunc(func(ctx context.Context) error { resultChan <- struct{}{} panic("slow task ") - return nil })} } }() @@ -633,7 +631,6 @@ func TestTaskExecutor(t *testing.T) { ex.FastQueue() <- &FastTask{task: TaskFunc(func(ctx context.Context) error { resultChan <- struct{}{} panic("fast task") - return nil })} } }() @@ -641,21 +638,18 @@ func TestTaskExecutor(t *testing.T) { // 等待任务开始执行 <-time.After(100 * time.Millisecond) - select { - case <-ex.Close(): - assert.Equal(t, n, 4*numTasks) - assert.Equal(t, int32(0), ex.NumRunningSlow()) - assert.Equal(t, int32(0), ex.NumRunningFast()) - close(resultChan) - num := 0 - for r := range resultChan { - if r == struct{}{} { - num++ - } + <-ex.Close() + assert.Equal(t, n, 4*numTasks) + assert.Equal(t, int32(0), ex.NumRunningSlow()) + assert.Equal(t, int32(0), ex.NumRunningFast()) + close(resultChan) + num := 0 + for r := range resultChan { + if r == struct{}{} { + num++ } - assert.Equal(t, n, num) - return } + assert.Equal(t, n, num) }) t.Run("强制关闭", func(t *testing.T) { @@ -679,7 +673,6 @@ func TestTaskExecutor(t *testing.T) { &SlowTask{task: TaskFunc(func(ctx context.Context) error { resultChan <- struct{}{} panic("slow task ") - return nil })}, &SlowTask{task: TaskFunc(func(ctx context.Context) error { resultChan <- struct{}{} @@ -690,7 +683,6 @@ func TestTaskExecutor(t *testing.T) { &SlowTask{task: TaskFunc(func(ctx context.Context) error { resultChan <- struct{}{} panic("slow task ") - return nil })}, } @@ -703,7 +695,6 @@ func TestTaskExecutor(t *testing.T) { &FastTask{task: TaskFunc(func(ctx context.Context) error { resultChan <- struct{}{} panic("fast task") - return nil })}, &FastTask{task: TaskFunc(func(ctx context.Context) error { resultChan <- struct{}{} @@ -713,7 +704,6 @@ func TestTaskExecutor(t *testing.T) { &FastTask{task: TaskFunc(func(ctx context.Context) error { resultChan <- struct{}{} panic("fast task") - return nil })}, } go func() { From 340cb26d8353d2b74fbd07f7d2758cff54d80af3 Mon Sep 17 00:00:00 2001 From: longyue0521 Date: Sun, 28 Aug 2022 00:17:33 +0800 Subject: [PATCH 03/15] =?UTF-8?q?=E5=A4=84=E7=90=86=E6=B3=A8=E9=87=8A?= =?UTF-8?q?=E7=AD=89=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: longyue0521 --- pool/task_pool.go | 22 +++++++--------------- pool/task_pool_test.go | 15 +-------------- 2 files changed, 8 insertions(+), 29 deletions(-) diff --git a/pool/task_pool.go b/pool/task_pool.go index e927fca7..d521cd9c 100644 --- a/pool/task_pool.go +++ b/pool/task_pool.go @@ -152,7 +152,7 @@ func (b *BlockQueueTaskPool) Submit(ctx context.Context, task Task) error { } if b.state.CompareAndSwap(StateCreated, b.locked) { - ok, err := b.submit(ctx, task, func() chan<- Task { return b.chanByTask(task) }) + ok, err := b.submitTask(ctx, task, func() chan<- Task { return b.chanByTask(task) }) if ok || err != nil { b.state.Swap(StateCreated) return err @@ -161,7 +161,7 @@ func (b *BlockQueueTaskPool) Submit(ctx context.Context, task Task) error { } if b.state.CompareAndSwap(StateRunning, b.locked) { - ok, err := b.submit(ctx, task, func() chan<- Task { return b.chanByTask(task) }) + ok, err := b.submitTask(ctx, task, func() chan<- Task { return b.chanByTask(task) }) if ok || err != nil { b.state.Swap(StateRunning) return err @@ -181,7 +181,7 @@ func (b *BlockQueueTaskPool) chanByTask(task Task) chan<- Task { } } -func (b *BlockQueueTaskPool) submit(ctx context.Context, task Task, channel func() chan<- Task) (ok bool, err error) { +func (_ *BlockQueueTaskPool) submitTask(ctx context.Context, task Task, channel func() chan<- Task) (ok bool, err error) { // 此处channel() <- task不会出现panic——因为channel被关闭而panic // 代码执行到submit时TaskPool处于lock状态 // 要关闭channel需要TaskPool处于RUNNING状态,Shutdown/ShutdownNow才能成功 @@ -220,7 +220,6 @@ func (b *BlockQueueTaskPool) Start() error { return nil } } - //return nil } // Shutdown 将会拒绝提交新的任务,但是会继续执行已提交任务 @@ -249,22 +248,15 @@ func (b *BlockQueueTaskPool) Shutdown() (<-chan struct{}, error) { if b.state.CompareAndSwap(StateRunning, StateClosing) { // todo: 等待task完成,关闭b.done // 监听done信号,然后完成状态迁移StateClosing -> StateStopped - b.done = b.taskExecutor.Close() // 需要注释掉close(b.done) - //close(b.done) + b.done = b.taskExecutor.Close() go func() { <-b.done b.state.CompareAndSwap(StateClosing, StateStopped) - //ok := b.state.CompareAndSwap(StateClosing, StateStopped) - //for !ok { - // ok = b.state.CompareAndSwap(StateClosing, StateStopped) - //} - //return }() return b.done, nil } } - //return nil, nil } // ShutdownNow 立刻关闭任务池,并且返回所有剩余未执行的任务(不包含正在执行的任务) @@ -337,7 +329,7 @@ func (t *TaskExecutor) startFastTasks() { } go func() { t.numFast.Add(1) - //log.Println("fast N", t.numFast.Add(1)) + // log.Println("fast N", t.numFast.Add(1)) defer func() { // 恢复统计 t.numFast.Add(-1) @@ -364,7 +356,7 @@ func (t *TaskExecutor) startSlowTasks() { atomic.AddInt32(&t.maxGo, 1) continue } - //log.Println("maxGo=", n) + // log.Println("maxGo=", n) select { case <-t.ctx.Done(): return @@ -375,7 +367,7 @@ func (t *TaskExecutor) startSlowTasks() { } go func() { t.numSlow.Add(1) - //log.Println("slow N=", t.numSlow.Add(1)) + // log.Println("slow N=", t.numSlow.Add(1)) defer func() { // 恢复 atomic.AddInt32(&t.maxGo, 1) diff --git a/pool/task_pool_test.go b/pool/task_pool_test.go index 127678cf..e8625b9a 100644 --- a/pool/task_pool_test.go +++ b/pool/task_pool_test.go @@ -376,18 +376,6 @@ func TestTaskPool__Closing_(t *testing.T) { assert.Equal(t, StateRunning, pool.State()) assert.NoError(t, err) - // 提交任务 - //num := 10 - //for i := 0; i < num; i++ { - // go func() { - // err := pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { - // <-time.After(10 * time.Millisecond) - // return nil - // })}) - // t.Log(err) - // }() - //} - t.Run("Start", func(t *testing.T) { t.Parallel() @@ -478,7 +466,6 @@ func TestTestPool__Stopped_(t *testing.T) { } tasks, err := pool.ShutdownNow() - //assert.NotNil(t, tasks) assert.NoError(t, err) assert.Equal(t, StateStopped, pool.State()) @@ -724,7 +711,7 @@ func TestTaskExecutor(t *testing.T) { if r := recover(); r != nil { // 发送任务时,chan被关闭,也算作执行中 resultChan <- struct{}{} - //t.Log(fmt.Errorf("%w:%#v", ErrTaskPoolIsStopped, r)) + // t.Log(fmt.Errorf("%w:%#v", ErrTaskPoolIsStopped, r)) } }() for _, task := range fastTasks { From 1edcb7984ca8726e688097be638b17ac6892ada0 Mon Sep 17 00:00:00 2001 From: longyue0521 Date: Sun, 28 Aug 2022 00:23:07 +0800 Subject: [PATCH 04/15] =?UTF-8?q?=E7=9C=81=E7=95=A5submitTask=E6=8E=A5?= =?UTF-8?q?=E6=94=B6=E8=80=85=E5=90=8D=E7=A7=B0?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pool/task_pool.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pool/task_pool.go b/pool/task_pool.go index d521cd9c..224df333 100644 --- a/pool/task_pool.go +++ b/pool/task_pool.go @@ -181,7 +181,7 @@ func (b *BlockQueueTaskPool) chanByTask(task Task) chan<- Task { } } -func (_ *BlockQueueTaskPool) submitTask(ctx context.Context, task Task, channel func() chan<- Task) (ok bool, err error) { +func (*BlockQueueTaskPool) submitTask(ctx context.Context, task Task, channel func() chan<- Task) (ok bool, err error) { // 此处channel() <- task不会出现panic——因为channel被关闭而panic // 代码执行到submit时TaskPool处于lock状态 // 要关闭channel需要TaskPool处于RUNNING状态,Shutdown/ShutdownNow才能成功 From 9c67a2563a207975f1f99e193a082d187748e44d Mon Sep 17 00:00:00 2001 From: longyue0521 Date: Sun, 28 Aug 2022 00:38:21 +0800 Subject: [PATCH 05/15] =?UTF-8?q?=E4=BF=AE=E6=94=B9CHANGELOG.md?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- .CHANGELOG.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.CHANGELOG.md b/.CHANGELOG.md index adf78e2c..3198a674 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -10,4 +10,4 @@ [ekit: 实现了 LinkedList Delete](https://github.com/gotomicro/ekit/pull/38) [ekit: 修复 Pool TestPool](https://github.com/gotomicro/ekit/pull/40) [ekit: 实现了 LinkedList Range](https://github.com/gotomicro/ekit/pull/46) -[ekit: 实现了 TaskPool](https://github.com/gotomicro/ekit/issues/50) \ No newline at end of file +[ekit: 实现了 TaskPool](https://github.com/gotomicro/ekit/pull/57) \ No newline at end of file From eb2e93a2853edac1551222d707bfee2d9104e2e7 Mon Sep 17 00:00:00 2001 From: longyue0521 Date: Sun, 28 Aug 2022 16:16:18 +0800 Subject: [PATCH 06/15] =?UTF-8?q?=E4=BF=AE=E5=A4=8D=E5=9B=A0=E4=BD=BF?= =?UTF-8?q?=E7=94=A8go1.19=E5=AF=BC=E8=87=B4=E7=9A=84atomic.Int32=E7=B1=BB?= =?UTF-8?q?=E5=9E=8B=E4=B8=BA=E5=AE=9A=E4=B9=89=E9=97=AE=E9=A2=98=20?= =?UTF-8?q?=E5=A4=84=E7=90=86Task.Run=E8=BF=94=E5=9B=9E=E7=9A=84error?= =?UTF-8?q?=EF=BC=8Cfmt.Println=E5=8F=AF=E8=83=BD=E5=AF=BC=E8=87=B4?= =?UTF-8?q?=E5=B9=B6=E5=8F=91=E9=97=AE=E9=A2=98?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: longyue0521 --- pool/task_pool.go | 78 +++++++++++++++++++++++++---------------------- 1 file changed, 42 insertions(+), 36 deletions(-) diff --git a/pool/task_pool.go b/pool/task_pool.go index 224df333..e9a06073 100644 --- a/pool/task_pool.go +++ b/pool/task_pool.go @@ -90,7 +90,7 @@ func (s *SlowTask) Run(ctx context.Context) error { return s.task.Run(ctx) } // BlockQueueTaskPool 并发阻塞的任务池 type BlockQueueTaskPool struct { - state atomic.Int32 + state int32 numGo int lenQu int @@ -119,13 +119,13 @@ func NewBlockQueueTaskPool(concurrency int, queueSize int) (*BlockQueueTaskPool, slowTaskQueue: taskExecutor.SlowQueue(), locked: int32(101), } - b.state.Store(StateCreated) + atomic.StoreInt32(&b.state, StateCreated) return b, nil } func (b *BlockQueueTaskPool) State() int32 { for { - state := b.state.Load() + state := atomic.LoadInt32(&b.state) if state != b.locked { return state } @@ -143,30 +143,30 @@ func (b *BlockQueueTaskPool) Submit(ctx context.Context, task Task) error { // todo: 用户未设置超时,可以考虑内部给个超时提交 for { - if b.state.Load() == StateClosing { + if atomic.LoadInt32(&b.state) == StateClosing { return fmt.Errorf("%w", ErrTaskPoolIsClosing) } - if b.state.Load() == StateStopped { + if atomic.LoadInt32(&b.state) == StateStopped { return fmt.Errorf("%w", ErrTaskPoolIsStopped) } - if b.state.CompareAndSwap(StateCreated, b.locked) { + if atomic.CompareAndSwapInt32(&b.state, StateCreated, b.locked) { ok, err := b.submitTask(ctx, task, func() chan<- Task { return b.chanByTask(task) }) if ok || err != nil { - b.state.Swap(StateCreated) + atomic.SwapInt32(&b.state, StateCreated) return err } - b.state.Swap(StateCreated) + atomic.SwapInt32(&b.state, StateCreated) } - if b.state.CompareAndSwap(StateRunning, b.locked) { + if atomic.CompareAndSwapInt32(&b.state, StateRunning, b.locked) { ok, err := b.submitTask(ctx, task, func() chan<- Task { return b.chanByTask(task) }) if ok || err != nil { - b.state.Swap(StateRunning) + atomic.SwapInt32(&b.state, StateRunning) return err } - b.state.Swap(StateRunning) + atomic.SwapInt32(&b.state, StateRunning) } } } @@ -201,20 +201,20 @@ func (b *BlockQueueTaskPool) Start() error { for { - if b.state.Load() == StateClosing { + if atomic.LoadInt32(&b.state) == StateClosing { return fmt.Errorf("%w", ErrTaskPoolIsClosing) } - if b.state.Load() == StateStopped { + if atomic.LoadInt32(&b.state) == StateStopped { return fmt.Errorf("%w", ErrTaskPoolIsStopped) } - if b.state.Load() == StateRunning { + if atomic.LoadInt32(&b.state) == StateRunning { // 重复调用,返回缓存结果 return nil } - if b.state.CompareAndSwap(StateCreated, StateRunning) { + if atomic.CompareAndSwapInt32(&b.state, StateCreated, StateRunning) { // todo: 启动task调度器,开始执行task b.taskExecutor.Start() return nil @@ -230,28 +230,28 @@ func (b *BlockQueueTaskPool) Shutdown() (<-chan struct{}, error) { for { - if b.state.Load() == StateCreated { + if atomic.LoadInt32(&b.state) == StateCreated { return nil, fmt.Errorf("%w", ErrTaskPoolIsNotRunning) } - if b.state.Load() == StateStopped { + if atomic.LoadInt32(&b.state) == StateStopped { // 重复调用时,恰好前一个Shutdown调用将状态迁移为StateStopped // 这种情况与先调用ShutdownNow状态迁移为StateStopped再调用Shutdown效果一样 return nil, fmt.Errorf("%w", ErrTaskPoolIsStopped) } - if b.state.Load() == StateClosing { + if atomic.LoadInt32(&b.state) == StateClosing { // 重复调用,返回缓存结果 return b.done, nil } - if b.state.CompareAndSwap(StateRunning, StateClosing) { + if atomic.CompareAndSwapInt32(&b.state, StateRunning, StateClosing) { // todo: 等待task完成,关闭b.done // 监听done信号,然后完成状态迁移StateClosing -> StateStopped b.done = b.taskExecutor.Close() go func() { <-b.done - b.state.CompareAndSwap(StateClosing, StateStopped) + atomic.CompareAndSwapInt32(&b.state, StateClosing, StateStopped) }() return b.done, nil } @@ -264,22 +264,22 @@ func (b *BlockQueueTaskPool) ShutdownNow() ([]Task, error) { for { - if b.state.Load() == StateCreated { + if atomic.LoadInt32(&b.state) == StateCreated { return nil, fmt.Errorf("%w", ErrTaskPoolIsNotRunning) } - if b.state.Load() == StateClosing { + if atomic.LoadInt32(&b.state) == StateClosing { return nil, fmt.Errorf("%w", ErrTaskPoolIsClosing) } - if b.state.Load() == StateStopped { + if atomic.LoadInt32(&b.state) == StateStopped { // 重复调用,返回缓存结果 b.mux.RLock() tasks := append([]Task(nil), b.tasks...) b.mux.RUnlock() return tasks, nil } - if b.state.CompareAndSwap(StateRunning, StateStopped) { + if atomic.CompareAndSwapInt32(&b.state, StateRunning, StateStopped) { b.mux.Lock() b.tasks = b.taskExecutor.Stop() tasks := append([]Task(nil), b.tasks...) @@ -300,8 +300,8 @@ type TaskExecutor struct { ctx context.Context cancelFunc context.CancelFunc // 统计 - numSlow atomic.Int32 - numFast atomic.Int32 + numSlow int32 + numFast int32 } func NewTaskExecutor(maxGo int, queueSize int) *TaskExecutor { @@ -328,11 +328,11 @@ func (t *TaskExecutor) startFastTasks() { return } go func() { - t.numFast.Add(1) - // log.Println("fast N", t.numFast.Add(1)) + atomic.AddInt32(&t.numFast, 1) + // log.Println("fast N", atomic.AddInt32(&t.numFast, 1)) defer func() { // 恢复统计 - t.numFast.Add(-1) + atomic.AddInt32(&t.numFast, -1) // handle panic if r := recover(); r != nil { @@ -342,7 +342,10 @@ func (t *TaskExecutor) startFastTasks() { } }() // todo: handle err - fmt.Println(task.Run(t.ctx)) + err := task.Run(t.ctx) + if err != nil { + return + } }() } } @@ -366,12 +369,12 @@ func (t *TaskExecutor) startSlowTasks() { return } go func() { - t.numSlow.Add(1) + atomic.AddInt32(&t.numSlow, 1) // log.Println("slow N=", t.numSlow.Add(1)) defer func() { // 恢复 atomic.AddInt32(&t.maxGo, 1) - t.numSlow.Add(-1) + atomic.AddInt32(&t.numSlow, -1) // handle panic if r := recover(); r != nil { @@ -381,7 +384,10 @@ func (t *TaskExecutor) startSlowTasks() { } }() // todo: handle err - fmt.Println(task.Run(t.ctx)) + err := task.Run(t.ctx) + if err != nil { + return + } }() } } @@ -396,11 +402,11 @@ func (t *TaskExecutor) SlowQueue() chan<- Task { } func (t *TaskExecutor) NumRunningSlow() int32 { - return t.numSlow.Load() + return atomic.LoadInt32(&t.numSlow) } func (t *TaskExecutor) NumRunningFast() int32 { - return t.numFast.Load() + return atomic.LoadInt32(&t.numFast) } // Close 优雅关闭 @@ -420,7 +426,7 @@ func (t *TaskExecutor) Close() <-chan struct{} { for i := 0; i < 3; i++ { // 确保所有运行中任务也自然退出 - for t.numFast.Load() != 0 || t.numSlow.Load() != 0 { + for atomic.LoadInt32(&t.numFast) != 0 || atomic.LoadInt32(&t.numSlow) != 0 { time.Sleep(time.Second) } } From 29fa35825e0b7b6f23ac0d0a51edaf282a80d7d2 Mon Sep 17 00:00:00 2001 From: longyue0521 Date: Sun, 28 Aug 2022 19:12:19 +0800 Subject: [PATCH 07/15] =?UTF-8?q?=E9=87=8D=E6=9E=84TaskPool=E5=B9=B6?= =?UTF-8?q?=E6=95=B4=E7=90=86=E6=B5=8B=E8=AF=95?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: longyue0521 --- pool/task_pool.go | 448 ++++++++++++-------------------- pool/task_pool_test.go | 572 ++++++++++++++--------------------------- 2 files changed, 371 insertions(+), 649 deletions(-) diff --git a/pool/task_pool.go b/pool/task_pool.go index e9a06073..f8634309 100644 --- a/pool/task_pool.go +++ b/pool/task_pool.go @@ -26,15 +26,18 @@ import ( ) var ( - StateCreated int32 = 1 - StateRunning int32 = 2 - StateClosing int32 = 3 - StateStopped int32 = 4 + stateCreated int32 = 1 + stateRunning int32 = 2 + stateClosing int32 = 3 + stateStopped int32 = 4 + stateLocked int32 = 5 - ErrTaskPoolIsNotRunning = errors.New("pool: TaskPool未运行") - ErrTaskPoolIsClosing = errors.New("pool:TaskPool关闭中") - ErrTaskPoolIsStopped = errors.New("pool: TaskPool已停止") - ErrTaskIsInvalid = errors.New("pool: Task非法") + errTaskPoolIsNotRunning = errors.New("ekit: TaskPool未运行") + errTaskPoolIsClosing = errors.New("ekit:TaskPool关闭中") + errTaskPoolIsStopped = errors.New("ekit: TaskPool已停止") + errTaskIsInvalid = errors.New("ekit: Task非法") + + errInvalidArgument = errors.New("ekit: 参数非法") _ TaskPool = &BlockQueueTaskPool{} panicBuffLen = 2048 @@ -78,119 +81,112 @@ type TaskFunc func(ctx context.Context) error // 超时控制取决于衍生出 TaskFunc 的方法 func (t TaskFunc) Run(ctx context.Context) error { return t(ctx) } -// FastTask 快任务,耗时较短,每个任务一个Goroutine -type FastTask struct{ task Task } - -func (f *FastTask) Run(ctx context.Context) error { return f.task.Run(ctx) } - -// SlowTask 慢任务,耗时较长,运行在固定个数Goroutine上 -type SlowTask struct{ task Task } +// taskWrapper 是Task的装饰器 +type taskWrapper struct { + t Task +} -func (s *SlowTask) Run(ctx context.Context) error { return s.task.Run(ctx) } +func (tw *taskWrapper) Run(ctx context.Context) error { + defer func() { + // 处理 panic + if r := recover(); r != nil { + buf := make([]byte, panicBuffLen) + buf = buf[:runtime.Stack(buf, false)] + fmt.Printf("[PANIC]:\t%+v\n%s\n", r, buf) + } + }() + return tw.t.Run(ctx) +} // BlockQueueTaskPool 并发阻塞的任务池 type BlockQueueTaskPool struct { + // TaskPool内部状态 state int32 - numGo int - lenQu int - - taskExecutor *TaskExecutor - fastTaskQueue chan<- Task - slowTaskQueue chan<- Task - - // 缓存taskExecutor结果 - done <-chan struct{} - tasks []Task - mux sync.RWMutex - locked int32 + + queue chan Task + token chan struct{} + num int32 + + // 外部信号 + done chan struct{} + // 内部中断信号 + ctx context.Context + cancelFunc context.CancelFunc + // 缓存 + mux sync.RWMutex + submittedTasks []Task } // NewBlockQueueTaskPool 创建一个新的 BlockQueueTaskPool // concurrency 是并发数,即最多允许多少个 goroutine 执行任务 // queueSize 是队列大小,即最多有多少个任务在等待调度 func NewBlockQueueTaskPool(concurrency int, queueSize int) (*BlockQueueTaskPool, error) { - taskExecutor := NewTaskExecutor(concurrency, queueSize) + if concurrency < 1 { + return nil, fmt.Errorf("%w:concurrency应该大于0", errInvalidArgument) + } + if queueSize < 0 { + return nil, fmt.Errorf("%w:queueSize应该大于等于0", errInvalidArgument) + } b := &BlockQueueTaskPool{ - numGo: concurrency, - lenQu: queueSize, - done: make(chan struct{}), - taskExecutor: taskExecutor, - fastTaskQueue: taskExecutor.FastQueue(), - slowTaskQueue: taskExecutor.SlowQueue(), - locked: int32(101), + queue: make(chan Task, queueSize), + token: make(chan struct{}, concurrency), + done: make(chan struct{}), } - atomic.StoreInt32(&b.state, StateCreated) + b.ctx, b.cancelFunc = context.WithCancel(context.Background()) + atomic.StoreInt32(&b.state, stateCreated) return b, nil } -func (b *BlockQueueTaskPool) State() int32 { - for { - state := atomic.LoadInt32(&b.state) - if state != b.locked { - return state - } - } -} - // Submit 提交一个任务 // 如果此时队列已满,那么将会阻塞调用者。 // 如果因为 ctx 的原因返回,那么将会返回 ctx.Err() // 在调用 Start 前后都可以调用 Submit func (b *BlockQueueTaskPool) Submit(ctx context.Context, task Task) error { if task == nil || reflect.ValueOf(task).IsNil() { - return fmt.Errorf("%w", ErrTaskIsInvalid) + return fmt.Errorf("%w", errTaskIsInvalid) } // todo: 用户未设置超时,可以考虑内部给个超时提交 for { - if atomic.LoadInt32(&b.state) == StateClosing { - return fmt.Errorf("%w", ErrTaskPoolIsClosing) + if atomic.LoadInt32(&b.state) == stateClosing { + return fmt.Errorf("%w", errTaskPoolIsClosing) } - if atomic.LoadInt32(&b.state) == StateStopped { - return fmt.Errorf("%w", ErrTaskPoolIsStopped) + if atomic.LoadInt32(&b.state) == stateStopped { + return fmt.Errorf("%w", errTaskPoolIsStopped) } - if atomic.CompareAndSwapInt32(&b.state, StateCreated, b.locked) { - ok, err := b.submitTask(ctx, task, func() chan<- Task { return b.chanByTask(task) }) - if ok || err != nil { - atomic.SwapInt32(&b.state, StateCreated) - return err - } - atomic.SwapInt32(&b.state, StateCreated) + task = &taskWrapper{t: task} + + ok, err := b.trySubmit(ctx, task, stateCreated) + if ok || err != nil { + return err } - if atomic.CompareAndSwapInt32(&b.state, StateRunning, b.locked) { - ok, err := b.submitTask(ctx, task, func() chan<- Task { return b.chanByTask(task) }) - if ok || err != nil { - atomic.SwapInt32(&b.state, StateRunning) - return err - } - atomic.SwapInt32(&b.state, StateRunning) + ok, err = b.trySubmit(ctx, task, stateRunning) + if ok || err != nil { + return err } } } -func (b *BlockQueueTaskPool) chanByTask(task Task) chan<- Task { - switch task.(type) { - case *SlowTask: - return b.slowTaskQueue - default: - // FastTask, TaskFunc, 用户自定义类型实现Task接口 - return b.fastTaskQueue - } -} +func (b *BlockQueueTaskPool) trySubmit(ctx context.Context, task Task, state int32) (bool, error) { + // 进入临界区 + if atomic.CompareAndSwapInt32(&b.state, state, stateLocked) { + defer atomic.CompareAndSwapInt32(&b.state, stateLocked, state) -func (*BlockQueueTaskPool) submitTask(ctx context.Context, task Task, channel func() chan<- Task) (ok bool, err error) { - // 此处channel() <- task不会出现panic——因为channel被关闭而panic - // 代码执行到submit时TaskPool处于lock状态 - // 要关闭channel需要TaskPool处于RUNNING状态,Shutdown/ShutdownNow才能成功 - select { - case <-ctx.Done(): - return false, fmt.Errorf("%w", ctx.Err()) - case channel() <- task: - return true, nil - default: + // 此处b.queue <- task不会因为b.queue被关闭而panic + // 代码执行到trySubmit时TaskPool处于lock状态 + // 要关闭b.queue需要TaskPool处于RUNNING状态,Shutdown/ShutdownNow才能成功 + select { + case <-ctx.Done(): + return false, fmt.Errorf("%w", ctx.Err()) + case b.queue <- task: + return true, nil + default: + // 不能阻塞在临界区 + } + return false, nil } return false, nil } @@ -201,27 +197,59 @@ func (b *BlockQueueTaskPool) Start() error { for { - if atomic.LoadInt32(&b.state) == StateClosing { - return fmt.Errorf("%w", ErrTaskPoolIsClosing) + if atomic.LoadInt32(&b.state) == stateClosing { + return fmt.Errorf("%w", errTaskPoolIsClosing) } - if atomic.LoadInt32(&b.state) == StateStopped { - return fmt.Errorf("%w", ErrTaskPoolIsStopped) + if atomic.LoadInt32(&b.state) == stateStopped { + return fmt.Errorf("%w", errTaskPoolIsStopped) } - if atomic.LoadInt32(&b.state) == StateRunning { - // 重复调用,返回缓存结果 + if atomic.LoadInt32(&b.state) == stateRunning { + // 重复调用,不予处理 return nil } - if atomic.CompareAndSwapInt32(&b.state, StateCreated, StateRunning) { - // todo: 启动task调度器,开始执行task - b.taskExecutor.Start() + if atomic.CompareAndSwapInt32(&b.state, stateCreated, stateRunning) { + go b.startTasks() return nil } } } +func (b *BlockQueueTaskPool) startTasks() { + defer close(b.token) + + for { + select { + case <-b.ctx.Done(): + return + case b.token <- struct{}{}: + + task := <-b.queue + // handle close(b.queue) + if task == nil { + return + } + + go func() { + + atomic.AddInt32(&b.num, 1) + defer func() { + atomic.AddInt32(&b.num, -1) + <-b.token + }() + + // todo: handle err + err := task.Run(b.ctx) + if err != nil { + return + } + }() + } + } +} + // Shutdown 将会拒绝提交新的任务,但是会继续执行已提交任务 // 当执行完毕后,会往返回的 chan 中丢入信号 // Shutdown 会负责关闭返回的 chan @@ -230,28 +258,38 @@ func (b *BlockQueueTaskPool) Shutdown() (<-chan struct{}, error) { for { - if atomic.LoadInt32(&b.state) == StateCreated { - return nil, fmt.Errorf("%w", ErrTaskPoolIsNotRunning) + if atomic.LoadInt32(&b.state) == stateCreated { + return nil, fmt.Errorf("%w", errTaskPoolIsNotRunning) } - if atomic.LoadInt32(&b.state) == StateStopped { + if atomic.LoadInt32(&b.state) == stateStopped { // 重复调用时,恰好前一个Shutdown调用将状态迁移为StateStopped // 这种情况与先调用ShutdownNow状态迁移为StateStopped再调用Shutdown效果一样 - return nil, fmt.Errorf("%w", ErrTaskPoolIsStopped) + return nil, fmt.Errorf("%w", errTaskPoolIsStopped) } - if atomic.LoadInt32(&b.state) == StateClosing { - // 重复调用,返回缓存结果 + if atomic.LoadInt32(&b.state) == stateClosing { + // 重复调用 return b.done, nil } - if atomic.CompareAndSwapInt32(&b.state, StateRunning, StateClosing) { - // todo: 等待task完成,关闭b.done - // 监听done信号,然后完成状态迁移StateClosing -> StateStopped - b.done = b.taskExecutor.Close() + if atomic.CompareAndSwapInt32(&b.state, stateRunning, stateClosing) { + // 目标:不但希望正在运行中的任务自然退出,还希望队列中等待的任务也能启动执行并自然退出 + // 策略:先将队列中的任务启动并执行(清空队列),再等待全部运行中的任务自然退出 + + // 先关闭等待队列不再允许提交 + // 同时任务启动循环能够通过Task==nil来终止循环 + close(b.queue) + go func() { - <-b.done - atomic.CompareAndSwapInt32(&b.state, StateClosing, StateStopped) + // 等待运行中的Task自然结束 + for atomic.LoadInt32(&b.num) != 0 { + time.Sleep(time.Second) + } + // 通知外部调用者 + close(b.done) + // 完成最终的状态迁移 + atomic.CompareAndSwapInt32(&b.state, stateClosing, stateStopped) }() return b.done, nil } @@ -264,198 +302,54 @@ func (b *BlockQueueTaskPool) ShutdownNow() ([]Task, error) { for { - if atomic.LoadInt32(&b.state) == StateCreated { - return nil, fmt.Errorf("%w", ErrTaskPoolIsNotRunning) + if atomic.LoadInt32(&b.state) == stateCreated { + return nil, fmt.Errorf("%w", errTaskPoolIsNotRunning) } - if atomic.LoadInt32(&b.state) == StateClosing { - return nil, fmt.Errorf("%w", ErrTaskPoolIsClosing) + if atomic.LoadInt32(&b.state) == stateClosing { + return nil, fmt.Errorf("%w", errTaskPoolIsClosing) } - if atomic.LoadInt32(&b.state) == StateStopped { + if atomic.LoadInt32(&b.state) == stateStopped { // 重复调用,返回缓存结果 b.mux.RLock() - tasks := append([]Task(nil), b.tasks...) + tasks := append([]Task(nil), b.submittedTasks...) b.mux.RUnlock() return tasks, nil } - if atomic.CompareAndSwapInt32(&b.state, StateRunning, StateStopped) { - b.mux.Lock() - b.tasks = b.taskExecutor.Stop() - tasks := append([]Task(nil), b.tasks...) - b.mux.Unlock() - return tasks, nil - } - } -} - -type TaskExecutor struct { - slowTasks chan Task - fastTasks chan Task - - maxGo int32 - // - done chan struct{} - // 利用ctx充当内部信号 - ctx context.Context - cancelFunc context.CancelFunc - // 统计 - numSlow int32 - numFast int32 -} + if atomic.CompareAndSwapInt32(&b.state, stateRunning, stateStopped) { + // 目标:立刻关闭并且返回所有剩下未执行的任务 + // 策略:关闭等待队列不再接受新任务,中断任务启动循环,清空等待队列并保存返回 -func NewTaskExecutor(maxGo int, queueSize int) *TaskExecutor { - t := &TaskExecutor{maxGo: int32(maxGo), done: make(chan struct{})} - t.ctx, t.cancelFunc = context.WithCancel(context.Background()) - t.slowTasks = make(chan Task, queueSize) - t.fastTasks = make(chan Task, queueSize) - return t -} + close(b.queue) -func (t *TaskExecutor) Start() { - go t.startSlowTasks() - go t.startFastTasks() -} + // 发送中断信号,中断任务启动循环 + b.cancelFunc() -func (t *TaskExecutor) startFastTasks() { - for { - select { - case <-t.ctx.Done(): - return - case task := <-t.fastTasks: - // handle close(t.fastTasks) - if task == nil { - return + b.mux.Lock() + // 清空队列并保存 + var tasks []Task + for task := range b.queue { + b.submittedTasks = append(b.submittedTasks, task) + tasks = append(tasks, task) } - go func() { - atomic.AddInt32(&t.numFast, 1) - // log.Println("fast N", atomic.AddInt32(&t.numFast, 1)) - defer func() { - // 恢复统计 - atomic.AddInt32(&t.numFast, -1) - - // handle panic - if r := recover(); r != nil { - buf := make([]byte, panicBuffLen) - buf = buf[:runtime.Stack(buf, false)] - fmt.Printf("[PANIC]:\t%+v\n%s\n", r, buf) - } - }() - // todo: handle err - err := task.Run(t.ctx) - if err != nil { - return - } - }() + b.mux.Unlock() + + return tasks, nil } } } -func (t *TaskExecutor) startSlowTasks() { - +// internalState 用于查看TaskPool状态 +func (b *BlockQueueTaskPool) internalState() int32 { for { - n := atomic.AddInt32(&t.maxGo, -1) - if n < 0 { - atomic.AddInt32(&t.maxGo, 1) - continue - } - // log.Println("maxGo=", n) - select { - case <-t.ctx.Done(): - return - case task := <-t.slowTasks: - // handle close(t.slowTasks) - if task == nil { - return - } - go func() { - atomic.AddInt32(&t.numSlow, 1) - // log.Println("slow N=", t.numSlow.Add(1)) - defer func() { - // 恢复 - atomic.AddInt32(&t.maxGo, 1) - atomic.AddInt32(&t.numSlow, -1) - - // handle panic - if r := recover(); r != nil { - buf := make([]byte, panicBuffLen) - buf = buf[:runtime.Stack(buf, false)] - fmt.Printf("[PANIC]:\t%+v\n%s\n", r, buf) - } - }() - // todo: handle err - err := task.Run(t.ctx) - if err != nil { - return - } - }() + state := atomic.LoadInt32(&b.state) + if state != stateLocked { + return state } } } -func (t *TaskExecutor) FastQueue() chan<- Task { - return t.fastTasks -} - -func (t *TaskExecutor) SlowQueue() chan<- Task { - return t.slowTasks -} - -func (t *TaskExecutor) NumRunningSlow() int32 { - return atomic.LoadInt32(&t.numSlow) -} - -func (t *TaskExecutor) NumRunningFast() int32 { - return atomic.LoadInt32(&t.numFast) -} - -// Close 优雅关闭 -// 目标:不但希望正在运行中的任务自然退出,还希望队列中等待的任务也能启动执行并自然退出 -// 策略:先将所有队列中的任务启动并执行(清空队列),再等待全部运行中的任务自然退出。 -func (t *TaskExecutor) Close() <-chan struct{} { - - // 先关闭等待队列不再允许提交 - // 同时任务启动循环能够通过Task==nil来终止循环 - close(t.slowTasks) - close(t.fastTasks) - - go func() { - - // 检查三次是因为可能出现: - // 两队列中有任务且正在创建启动任务尚未执行计数,恰巧此时正在运行中的任务为0 - for i := 0; i < 3; i++ { - - // 确保所有运行中任务也自然退出 - for atomic.LoadInt32(&t.numFast) != 0 || atomic.LoadInt32(&t.numSlow) != 0 { - time.Sleep(time.Second) - } - } - - // 通知外部调用者 - close(t.done) - }() - - return t.done -} - -// Stop 强制关闭 -// 目标:立刻关闭并且返回所有剩下未执行的任务 -// 策略:关闭等待队列不再接受新任务,中断任务启动循环,清空等待队列并保存返回 -func (t *TaskExecutor) Stop() []Task { - - close(t.fastTasks) - close(t.slowTasks) - - // 发送中断信号,中断任务启动循环 - t.cancelFunc() - - // 清空队列并保存 - var tasks []Task - for task := range t.fastTasks { - tasks = append(tasks, task) - } - for task := range t.slowTasks { - tasks = append(tasks, task) - } - return tasks +func (b *BlockQueueTaskPool) NumGo() int32 { + return atomic.LoadInt32(&b.num) } diff --git a/pool/task_pool_test.go b/pool/task_pool_test.go index e8625b9a..582b9874 100644 --- a/pool/task_pool_test.go +++ b/pool/task_pool_test.go @@ -16,7 +16,7 @@ package pool import ( "context" - "fmt" + "errors" "sync" "testing" "time" @@ -37,6 +37,35 @@ New() --> CREATED -- Start() ---> RUNNING -- -- ShutdownNow() ---> STOPPED -- ShutdownNow() --> STOPPED */ +func TestTaskPool_New(t *testing.T) { + t.Parallel() + + pool, err := NewBlockQueueTaskPool(1, -1) + assert.ErrorIs(t, err, errInvalidArgument) + assert.Nil(t, pool) + + pool, err = NewBlockQueueTaskPool(1, 0) + assert.NoError(t, err) + assert.NotNil(t, pool) + + pool, err = NewBlockQueueTaskPool(1, 1) + assert.NoError(t, err) + assert.NotNil(t, pool) + + pool, err = NewBlockQueueTaskPool(-1, 1) + assert.ErrorIs(t, err, errInvalidArgument) + assert.Nil(t, pool) + + pool, err = NewBlockQueueTaskPool(0, 1) + assert.ErrorIs(t, err, errInvalidArgument) + assert.Nil(t, pool) + + pool, err = NewBlockQueueTaskPool(1, 1) + assert.NoError(t, err) + assert.NotNil(t, pool) + +} + func TestTaskPool_Start(t *testing.T) { /* todo: Start @@ -51,208 +80,193 @@ func TestTaskPool_Start(t *testing.T) { t.Parallel() pool, _ := NewBlockQueueTaskPool(1, 1) - assert.Equal(t, StateCreated, pool.State()) - errChan := make(chan error) - go func() { - // 多次运行结果一直 - var wg sync.WaitGroup - for i := 0; i < 5; i++ { - wg.Add(1) - go func() { - errChan <- pool.Start() - wg.Done() - }() - } - wg.Wait() - close(errChan) - }() + assert.Equal(t, stateCreated, pool.internalState()) + + n := 5 + errChan := make(chan error, n) + + // 第一次调用 + assert.NoError(t, pool.Start()) + assert.Equal(t, stateRunning, pool.internalState()) + + // 多次调用 + var wg sync.WaitGroup + for i := 0; i < n; i++ { + wg.Add(1) + go func() { + errChan <- pool.Start() + wg.Done() + }() + } + wg.Wait() + close(errChan) for err := range errChan { assert.NoError(t, err) } - assert.Equal(t, StateRunning, pool.State()) + assert.Equal(t, stateRunning, pool.internalState()) } func TestTaskPool_Submit(t *testing.T) { t.Parallel() - //todo: Submit - // TaskPoolRunning Shutdown/ShutdownNow前 + // todo: Submit // 在TaskPool所有状态中都可以提交,有的成功/阻塞,有的立即失败。 - // [x]ctx时间内提交成功 + // [x]ctx时间内提交成功 // [x]ctx超时,提交失败给出错误信息 // [x] 监听状态变化,从running->closing/stopped - // [x] Shutdown后,状态变迁,需要检查出并报错 ErrTaskPoolIsClosing - // [x] ShutdownNow后状态变迁,需要检查出并报错,ErrTaskPoolIsStopped + // [x] Shutdown后,状态变迁,需要检查出并报错 errTaskPoolIsClosing + // [x] ShutdownNow后状态变迁,需要检查出并报错,errTaskPoolIsStopped t.Run("提交Task阻塞", func(t *testing.T) { t.Parallel() t.Run("TaskPool状态由Created变为Running", func(t *testing.T) { t.Parallel() - // 为了准确模拟,内部自定义一个pool + pool, _ := NewBlockQueueTaskPool(1, 1) - assert.Equal(t, StateCreated, pool.State()) + + // 与下方 testSubmitBlockingAndTimeout 并发执行 errChan := make(chan error) go func() { + <-time.After(1 * time.Millisecond) errChan <- pool.Start() }() + + assert.Equal(t, stateCreated, pool.internalState()) + testSubmitBlockingAndTimeout(t, pool) + assert.NoError(t, <-errChan) - assert.Equal(t, StateRunning, pool.State()) + assert.Equal(t, stateRunning, pool.internalState()) }) t.Run("TaskPool状态由Running变为Closing", func(t *testing.T) { t.Parallel() - // 为了准确模拟,内部自定义一个pool - pool, _ := NewBlockQueueTaskPool(1, 2) - assert.Equal(t, StateCreated, pool.State()) - err := pool.Start() - assert.NoError(t, err) - assert.Equal(t, StateRunning, pool.State()) + + pool := testNewRunningStateTaskPool(t, 1, 2) // 模拟阻塞提交 n := 10 - submitErrChan := make(chan error, 1) + firstSubmitErrChan := make(chan error, 1) for i := 0; i < n; i++ { go func() { - err := pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { <-time.After(10 * time.Millisecond) return nil - })}) + })) if err != nil { - submitErrChan <- err + firstSubmitErrChan <- err } }() } // 调用Shutdown使TaskPool状态发生迁移 - type Result struct { + type ShutdownResult struct { done <-chan struct{} err error } - resultChan := make(chan Result) + resultChan := make(chan ShutdownResult) go func() { <-time.After(time.Millisecond) done, err := pool.Shutdown() - resultChan <- Result{done: done, err: err} + resultChan <- ShutdownResult{done: done, err: err} }() r := <-resultChan // 阻塞中的任务报错,证明处于TaskPool处于StateClosing状态 - assert.ErrorIs(t, <-submitErrChan, ErrTaskPoolIsClosing) + assert.ErrorIs(t, <-firstSubmitErrChan, errTaskPoolIsClosing) // Shutdown调用成功 assert.NoError(t, r.err) - <-r.done // 等待状态迁移完成,并最终进入StateStopped状态 <-time.After(100 * time.Millisecond) - assert.Equal(t, StateStopped, pool.State()) + assert.Equal(t, stateStopped, pool.internalState()) }) t.Run("TaskPool状态由Running变为Stopped", func(t *testing.T) { t.Parallel() - // 为了准确模拟,内部自定义一个pool - pool, _ := NewBlockQueueTaskPool(1, 2) - assert.Equal(t, StateCreated, pool.State()) - err := pool.Start() - assert.NoError(t, err) - assert.Equal(t, StateRunning, pool.State()) + + pool := testNewRunningStateTaskPool(t, 1, 2) // 模拟阻塞提交 n := 5 - submitErrChan := make(chan error, 1) + firstSubmitErrChan := make(chan error, 1) for i := 0; i < n; i++ { go func() { - err := pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { <-time.After(10 * time.Millisecond) return nil - })}) + })) if err != nil { - submitErrChan <- err + firstSubmitErrChan <- err } }() } // 并发调用ShutdownNow - type Result struct { - tasks []Task - err error - } - result := make(chan Result, 1) + + result := make(chan ShutdownNowResult, 1) go func() { <-time.After(time.Millisecond) tasks, err := pool.ShutdownNow() - result <- Result{tasks: tasks, err: err} + result <- ShutdownNowResult{tasks: tasks, err: err} }() r := <-result assert.NoError(t, r.err) - assert.NotNil(t, r.tasks) + assert.NotEmpty(t, r.tasks) - assert.ErrorIs(t, <-submitErrChan, ErrTaskPoolIsStopped) - assert.Equal(t, StateStopped, pool.State()) + assert.ErrorIs(t, <-firstSubmitErrChan, errTaskPoolIsStopped) + assert.Equal(t, stateStopped, pool.internalState()) }) }) - } func TestTaskPool_Shutdown(t *testing.T) { t.Parallel() - pool, _ := NewBlockQueueTaskPool(1, 1) - assert.Equal(t, StateCreated, pool.State()) - err := pool.Start() - assert.Equal(t, StateRunning, pool.State()) - assert.NoError(t, err) + pool := testNewRunningStateTaskPool(t, 1, 1) // 第一次调用 done, err := pool.Shutdown() assert.NoError(t, err) - // 第二次调用 select { case <-done: break default: + // 第二次调用 done2, err2 := pool.Shutdown() assert.Equal(t, done2, done) assert.Equal(t, err2, err) + assert.Equal(t, stateClosing, pool.internalState()) } - <-time.After(5 * time.Millisecond) - assert.Equal(t, StateStopped, pool.State()) + <-done + assert.Equal(t, stateStopped, pool.internalState()) // 第一个Shutdown将状态迁移至StateStopped // 第三次调用 done, err = pool.Shutdown() assert.Nil(t, done) - assert.ErrorIs(t, err, ErrTaskPoolIsStopped) + assert.ErrorIs(t, err, errTaskPoolIsStopped) } func TestTestPool_ShutdownNow(t *testing.T) { t.Parallel() - pool, _ := NewBlockQueueTaskPool(1, 1) - assert.Equal(t, StateCreated, pool.State()) - err := pool.Start() - assert.Equal(t, StateRunning, pool.State()) - assert.NoError(t, err) + pool := testNewRunningStateTaskPool(t, 1, 1) - type result struct { - tasks []Task - err error - } n := 3 - c := make(chan result, n) + c := make(chan ShutdownNowResult, n) for i := 0; i < n; i++ { go func() { - tasks, er := pool.ShutdownNow() - c <- result{tasks: tasks, err: er} + c <- ShutdownNowResult{tasks: tasks, err: er} }() } @@ -260,41 +274,46 @@ func TestTestPool_ShutdownNow(t *testing.T) { r := <-c assert.Nil(t, r.tasks) assert.NoError(t, r.err) - assert.Equal(t, StateStopped, pool.State()) + assert.Equal(t, stateStopped, pool.internalState()) } } func TestTaskPool__Created_(t *testing.T) { t.Parallel() - n, q := 1, 1 - pool, err := NewBlockQueueTaskPool(n, q) + pool, err := NewBlockQueueTaskPool(1, 1) assert.NoError(t, err) assert.NotNil(t, pool) - assert.Equal(t, StateCreated, pool.State()) + assert.Equal(t, stateCreated, pool.internalState()) t.Run("Submit", func(t *testing.T) { t.Parallel() t.Run("提交非法Task", func(t *testing.T) { t.Parallel() + + pool, _ := NewBlockQueueTaskPool(1, 1) + assert.Equal(t, stateCreated, pool.internalState()) testSubmitInvalidTask(t, pool) - assert.Equal(t, StateCreated, pool.State()) + assert.Equal(t, stateCreated, pool.internalState()) }) t.Run("正常提交Task", func(t *testing.T) { t.Parallel() + + pool, _ := NewBlockQueueTaskPool(1, 3) + assert.Equal(t, stateCreated, pool.internalState()) testSubmitValidTask(t, pool) - assert.Equal(t, StateCreated, pool.State()) + assert.Equal(t, stateCreated, pool.internalState()) }) t.Run("阻塞提交并导致超时", func(t *testing.T) { t.Parallel() - // 为了准确模拟,内部自定义一个pool + pool, _ := NewBlockQueueTaskPool(1, 1) - assert.Equal(t, StateCreated, pool.State()) + assert.Equal(t, stateCreated, pool.internalState()) testSubmitBlockingAndTimeout(t, pool) - assert.Equal(t, StateCreated, pool.State()) + assert.Equal(t, stateCreated, pool.internalState()) }) }) @@ -303,8 +322,8 @@ func TestTaskPool__Created_(t *testing.T) { done, err := pool.Shutdown() assert.Nil(t, done) - assert.ErrorIs(t, err, ErrTaskPoolIsNotRunning) - assert.Equal(t, StateCreated, pool.State()) + assert.ErrorIs(t, err, errTaskPoolIsNotRunning) + assert.Equal(t, stateCreated, pool.internalState()) }) t.Run("ShutdownNow", func(t *testing.T) { @@ -312,8 +331,8 @@ func TestTaskPool__Created_(t *testing.T) { tasks, err := pool.ShutdownNow() assert.Nil(t, tasks) - assert.ErrorIs(t, err, ErrTaskPoolIsNotRunning) - assert.Equal(t, StateCreated, pool.State()) + assert.ErrorIs(t, err, errTaskPoolIsNotRunning) + assert.Equal(t, stateCreated, pool.internalState()) }) } @@ -321,18 +340,12 @@ func TestTaskPool__Created_(t *testing.T) { func TestTaskPool__Running_(t *testing.T) { t.Parallel() - pool, _ := NewBlockQueueTaskPool(1, 1) - assert.Equal(t, StateCreated, pool.State()) - err := pool.Start() - assert.Equal(t, StateRunning, pool.State()) - assert.NoError(t, err) - t.Run("Start", func(t *testing.T) { t.Parallel() - err = pool.Start() - // todo: 调度器只启动一次 - assert.Equal(t, StateRunning, pool.State()) - assert.NoError(t, err) + + pool := testNewRunningStateTaskPool(t, 1, 3) + assert.NoError(t, pool.Start()) + assert.Equal(t, stateRunning, pool.internalState()) }) t.Run("Submit", func(t *testing.T) { @@ -340,88 +353,74 @@ func TestTaskPool__Running_(t *testing.T) { t.Run("提交非法Task", func(t *testing.T) { t.Parallel() + pool := testNewRunningStateTaskPool(t, 1, 1) testSubmitInvalidTask(t, pool) - assert.Equal(t, StateRunning, pool.State()) + assert.Equal(t, stateRunning, pool.internalState()) }) t.Run("正常提交Task", func(t *testing.T) { t.Parallel() + pool := testNewRunningStateTaskPool(t, 1, 3) testSubmitValidTask(t, pool) - assert.Equal(t, StateRunning, pool.State()) + assert.Equal(t, stateRunning, pool.internalState()) }) t.Run("阻塞提交并导致超时", func(t *testing.T) { t.Parallel() - // 为了准确模拟,内部自定义一个pool - pool, _ := NewBlockQueueTaskPool(1, 1) - assert.Equal(t, StateCreated, pool.State()) - err := pool.Start() - assert.Equal(t, StateRunning, pool.State()) - assert.NoError(t, err) - + pool := testNewRunningStateTaskPool(t, 1, 1) testSubmitBlockingAndTimeout(t, pool) - - assert.Equal(t, StateRunning, pool.State()) + assert.Equal(t, stateRunning, pool.internalState()) }) }) } func TestTaskPool__Closing_(t *testing.T) { - t.Parallel() - pool, _ := NewBlockQueueTaskPool(1, 0) - assert.Equal(t, StateCreated, pool.State()) - err := pool.Start() - assert.Equal(t, StateRunning, pool.State()) - assert.NoError(t, err) - t.Run("Start", func(t *testing.T) { t.Parallel() - pool, _ := NewBlockQueueTaskPool(1, 10) - assert.Equal(t, StateCreated, pool.State()) - err := pool.Start() - assert.Equal(t, StateRunning, pool.State()) - assert.NoError(t, err) + num := 10 + pool := testNewRunningStateTaskPool(t, 2, num) // 提交任务 - num := 10 for i := 0; i < num; i++ { go func() { - err := pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { <-time.After(10 * time.Millisecond) return nil - })}) + })) t.Log(err) }() } done, err := pool.Shutdown() assert.NoError(t, err) - assert.ErrorIs(t, pool.Start(), ErrTaskPoolIsClosing) + + select { + case <-done: + default: + assert.ErrorIs(t, pool.Start(), errTaskPoolIsClosing) + } + <-done <-time.After(10 * time.Millisecond) - assert.Equal(t, StateStopped, pool.State()) + assert.Equal(t, stateStopped, pool.internalState()) }) t.Run("ShutdownNow", func(t *testing.T) { t.Parallel() - pool, _ := NewBlockQueueTaskPool(1, 0) - assert.Equal(t, StateCreated, pool.State()) - err := pool.Start() - assert.Equal(t, StateRunning, pool.State()) - assert.NoError(t, err) + pool := testNewRunningStateTaskPool(t, 1, 0) // 提交任务 num := 10 for i := 0; i < num; i++ { go func() { - err := pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { <-time.After(10 * time.Millisecond) return nil - })}) + })) t.Log(err) }() } @@ -429,57 +428,60 @@ func TestTaskPool__Closing_(t *testing.T) { done, err := pool.Shutdown() assert.NoError(t, err) - tasks, err := pool.ShutdownNow() - assert.ErrorIs(t, err, ErrTaskPoolIsClosing) - assert.Nil(t, tasks) + select { + case <-done: + default: + tasks, err := pool.ShutdownNow() + assert.ErrorIs(t, err, errTaskPoolIsClosing) + assert.Nil(t, tasks) + } <-done <-time.After(50 * time.Millisecond) - assert.Equal(t, StateStopped, pool.State()) + assert.Equal(t, stateStopped, pool.internalState()) }) } func TestTestPool__Stopped_(t *testing.T) { t.Parallel() - n := 2 - pool, _ := NewBlockQueueTaskPool(1, n) - assert.Equal(t, StateCreated, pool.State()) - err := pool.Start() - assert.Equal(t, StateRunning, pool.State()) - assert.NoError(t, err) + + concurrency, n := 2, 20 + pool := testNewRunningStateTaskPool(t, concurrency, n) // 模拟阻塞提交 - for i := 0; i < 3*n; i++ { + for i := 0; i < n; i++ { go func() { - err := pool.Submit(context.Background(), &FastTask{task: TaskFunc(func(ctx context.Context) error { - <-time.After(2 * time.Millisecond) + t.Log(pool.NumGo()) + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + <-time.After(time.Second) return nil - })}) - t.Log(err) - err = pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { - <-time.After(10 * time.Millisecond) - return nil - })}) + })) t.Log(err) }() } + <-time.After(100 * time.Millisecond) + assert.Equal(t, int32(concurrency), pool.NumGo()) + tasks, err := pool.ShutdownNow() assert.NoError(t, err) - assert.Equal(t, StateStopped, pool.State()) + assert.NotEmpty(t, tasks) + assert.Equal(t, stateStopped, pool.internalState()) t.Run("Start", func(t *testing.T) { t.Parallel() - assert.ErrorIs(t, pool.Start(), ErrTaskPoolIsStopped) - assert.Equal(t, StateStopped, pool.State()) + + assert.ErrorIs(t, pool.Start(), errTaskPoolIsStopped) + assert.Equal(t, stateStopped, pool.internalState()) }) t.Run("Submit", func(t *testing.T) { t.Parallel() + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { return nil })) - assert.ErrorIs(t, err, ErrTaskPoolIsStopped) - assert.Equal(t, StateStopped, pool.State()) + assert.ErrorIs(t, err, errTaskPoolIsStopped) + assert.Equal(t, stateStopped, pool.internalState()) }) t.Run("Shutdown", func(t *testing.T) { @@ -487,24 +489,28 @@ func TestTestPool__Stopped_(t *testing.T) { done, err := pool.Shutdown() assert.Nil(t, done) - assert.ErrorIs(t, err, ErrTaskPoolIsStopped) - assert.Equal(t, StateStopped, pool.State()) + assert.ErrorIs(t, err, errTaskPoolIsStopped) + assert.Equal(t, stateStopped, pool.internalState()) }) t.Run("ShutdownNow", func(t *testing.T) { t.Parallel() + // 多次调用返回相同结果 tasks2, err := pool.ShutdownNow() assert.NoError(t, err) + assert.NotEmpty(t, tasks2) assert.Equal(t, tasks2, tasks) + assert.Equal(t, stateStopped, pool.internalState()) }) } func testSubmitBlockingAndTimeout(t *testing.T, pool *BlockQueueTaskPool) { - err := pool.Submit(context.Background(), &SlowTask{task: TaskFunc(func(ctx context.Context) error { + + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { <-time.After(2 * time.Millisecond) return nil - })}) + })) assert.NoError(t, err) n := 2 @@ -513,234 +519,56 @@ func testSubmitBlockingAndTimeout(t *testing.T, pool *BlockQueueTaskPool) { go func() { ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) defer cancel() - err := pool.Submit(ctx, &SlowTask{task: TaskFunc(func(ctx context.Context) error { + err := pool.Submit(ctx, TaskFunc(func(ctx context.Context) error { <-time.After(2 * time.Millisecond) return nil - })}) + })) if err != nil { errChan <- err } }() } + assert.ErrorIs(t, <-errChan, context.DeadlineExceeded) } func testSubmitValidTask(t *testing.T, pool *BlockQueueTaskPool) { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { return nil })) assert.NoError(t, err) + + err = pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { panic("task panic") })) + assert.NoError(t, err) + + err = pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { return errors.New("fake error") })) + assert.NoError(t, err) } func testSubmitInvalidTask(t *testing.T, pool *BlockQueueTaskPool) { - invalidTasks := map[string]Task{"*SlowTask": (*SlowTask)(nil), "*FastTask": (*FastTask)(nil), "nil": nil, "TaskFunc(nil)": TaskFunc(nil)} + + invalidTasks := map[string]Task{"*FakeTask": (*FakeTask)(nil), "nil": nil, "TaskFunc(nil)": TaskFunc(nil)} for name, task := range invalidTasks { t.Run(name, func(t *testing.T) { err := pool.Submit(context.Background(), task) - assert.ErrorIs(t, err, ErrTaskIsInvalid) + assert.ErrorIs(t, err, errTaskIsInvalid) }) } } -func TestTaskExecutor(t *testing.T) { - t.Parallel() - - /* todo: - [x]快慢任务分离 - [x]快任务没有goroutine限制,提供方法检查个数 - [x]慢任务占用固定个数goroutine,提供方法检查个数 - [x]任务的panic处理 - []任务的error处理 - todo: [x]Closing-优雅关闭 - [x]返回一个chan, - 供调用者监听,调用者从chan拿到消息即表明所有任务结束 - [x]等待任务自然结束 - [x]关闭chan——自动终止启动循环 - [x]将队列中等待的任务启动执行 - [x]等待未完成任务; - [x]Stop-强制关闭 - [x] 关闭chan - [x] 终止任务启动循环 - [x] 将队列中清空并未完成的任务返回 - */ - - t.Run("优雅关闭", func(t *testing.T) { - - t.Parallel() - - maxGo, numTasks := 2, 5 - n := 4 * numTasks - - ex := NewTaskExecutor(maxGo, n) - ex.Start() - - // 注意:添加Task后需要调整否者会阻塞 - resultChan := make(chan struct{}, n) - - go func() { - // chan may be closed - defer func() { - if r := recover(); r != nil { - // 发送失败,也算执行了 - resultChan <- struct{}{} - t.Log(fmt.Errorf("%w:%#v", ErrTaskPoolIsStopped, r)) - } - }() - for i := 0; i < numTasks; i++ { - ex.SlowQueue() <- &SlowTask{task: TaskFunc(func(ctx context.Context) error { - resultChan <- struct{}{} - <-time.After(5 * time.Millisecond) - return nil - })} - // panic slow task - ex.SlowQueue() <- &SlowTask{task: TaskFunc(func(ctx context.Context) error { - resultChan <- struct{}{} - panic("slow task ") - })} - } - }() - go func() { - // chan() may be closed - defer func() { - if r := recover(); r != nil { - // 发送失败,也算执行了 - resultChan <- struct{}{} - t.Log(fmt.Errorf("%w:%#v", ErrTaskPoolIsStopped, r)) - } - }() - for i := 0; i < numTasks; i++ { - ex.FastQueue() <- &FastTask{task: TaskFunc(func(ctx context.Context) error { - resultChan <- struct{}{} - <-time.After(2 * time.Millisecond) - return nil - })} - - // panic fast task - ex.FastQueue() <- &FastTask{task: TaskFunc(func(ctx context.Context) error { - resultChan <- struct{}{} - panic("fast task") - })} - } - }() - - // 等待任务开始执行 - <-time.After(100 * time.Millisecond) - - <-ex.Close() - assert.Equal(t, n, 4*numTasks) - assert.Equal(t, int32(0), ex.NumRunningSlow()) - assert.Equal(t, int32(0), ex.NumRunningFast()) - close(resultChan) - num := 0 - for r := range resultChan { - if r == struct{}{} { - num++ - } - } - assert.Equal(t, n, num) - }) - - t.Run("强制关闭", func(t *testing.T) { - t.Parallel() - - maxGo, numTasks := 2, 5 - ex := NewTaskExecutor(maxGo, numTasks) - ex.Start() - - // 注意:确保n = len(slowTasks) + len(fastTasks) - n := 8 - resultChan := make(chan struct{}, n) - - slowTasks := []Task{ - &SlowTask{task: TaskFunc(func(ctx context.Context) error { - resultChan <- struct{}{} - <-time.After(5 * time.Millisecond) - return nil - })}, - // panic slow task - &SlowTask{task: TaskFunc(func(ctx context.Context) error { - resultChan <- struct{}{} - panic("slow task ") - })}, - &SlowTask{task: TaskFunc(func(ctx context.Context) error { - resultChan <- struct{}{} - <-time.After(5 * time.Millisecond) - return nil - })}, - // panic slow task - &SlowTask{task: TaskFunc(func(ctx context.Context) error { - resultChan <- struct{}{} - panic("slow task ") - })}, - } - - fastTasks := []Task{ - &FastTask{task: TaskFunc(func(ctx context.Context) error { - resultChan <- struct{}{} - <-time.After(2 * time.Millisecond) - return nil - })}, - &FastTask{task: TaskFunc(func(ctx context.Context) error { - resultChan <- struct{}{} - panic("fast task") - })}, - &FastTask{task: TaskFunc(func(ctx context.Context) error { - resultChan <- struct{}{} - <-time.After(2 * time.Millisecond) - return nil - })}, - &FastTask{task: TaskFunc(func(ctx context.Context) error { - resultChan <- struct{}{} - panic("fast task") - })}, - } - go func() { - // chan may be closed - defer func() { - if r := recover(); r != nil { - // 发送任务时,chan被关闭,也算作执行中 - resultChan <- struct{}{} - } - }() - for _, task := range slowTasks { - ex.SlowQueue() <- task - } - }() - go func() { - // chan may be closed - defer func() { - if r := recover(); r != nil { - // 发送任务时,chan被关闭,也算作执行中 - resultChan <- struct{}{} - // t.Log(fmt.Errorf("%w:%#v", ErrTaskPoolIsStopped, r)) - } - }() - for _, task := range fastTasks { - ex.FastQueue() <- task - } - }() - - // 等待任务开始执行 - <-time.After(100 * time.Millisecond) - - tasks := ex.Stop() - - // 等待任务执行并回传信号 - <-time.After(100 * time.Millisecond) +type ShutdownNowResult struct { + tasks []Task + err error +} - // 统计执行的任务数 - for ex.NumRunningFast() != 0 || ex.NumRunningSlow() != 0 { - time.Sleep(time.Millisecond) - } - close(resultChan) - num := 0 - for r := range resultChan { - if r == struct{}{} { - num++ - } - } +func testNewRunningStateTaskPool(t *testing.T, concurrency int, queueSize int) *BlockQueueTaskPool { + pool, _ := NewBlockQueueTaskPool(concurrency, queueSize) + assert.Equal(t, stateCreated, pool.internalState()) + assert.NoError(t, pool.Start()) + assert.Equal(t, stateRunning, pool.internalState()) + return pool +} - assert.Equal(t, n, len(slowTasks)+len(fastTasks)) - assert.Equal(t, len(slowTasks)+len(fastTasks), num+len(tasks)) - }) +type FakeTask struct{} -} +func (f *FakeTask) Run(ctx context.Context) error { return nil } From f2649ff89c827d3f7de13eb59640ba7b59826e9a Mon Sep 17 00:00:00 2001 From: longyue0521 Date: Mon, 29 Aug 2022 01:59:42 +0800 Subject: [PATCH 08/15] =?UTF-8?q?=E4=BF=AE=E6=94=B9.CHANGELOG.md?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: longyue0521 --- .CHANGELOG.md | 3 ++- pool/task_pool_test.go | 2 +- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/.CHANGELOG.md b/.CHANGELOG.md index aafe770c..e2c8d905 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -25,4 +25,5 @@ - [Len,Cap,NewLinkedList](https://github.com/gotomicro/ekit/pull/51) - [Range](https://github.com/gotomicro/ekit/pull/46) - [pool: 修复 Pool TestPool 测试不稳定的问题](https://github.com/gotomicro/ekit/pull/40) -- [ekit:引入 golangci-lint 和 goimports](https://github.com/gotomicro/ekit/pull/54) \ No newline at end of file +- [ekit:引入 golangci-lint 和 goimports](https://github.com/gotomicro/ekit/pull/54) +- [ekit: 实现了 TaskPool](https://github.com/gotomicro/ekit/pull/57) \ No newline at end of file diff --git a/pool/task_pool_test.go b/pool/task_pool_test.go index 582b9874..8973cee9 100644 --- a/pool/task_pool_test.go +++ b/pool/task_pool_test.go @@ -571,4 +571,4 @@ func testNewRunningStateTaskPool(t *testing.T, concurrency int, queueSize int) * type FakeTask struct{} -func (f *FakeTask) Run(ctx context.Context) error { return nil } +func (f *FakeTask) Run(_ context.Context) error { return nil } From 79794150ba0d23218208a372f4e1b41bcd9e6ba5 Mon Sep 17 00:00:00 2001 From: longyue0521 Date: Mon, 29 Aug 2022 12:31:21 +0800 Subject: [PATCH 09/15] =?UTF-8?q?1.=E5=B0=86time.After=E6=9B=BF=E6=8D=A2?= =?UTF-8?q?=E4=B8=BAtime.Sleep=202.=E5=B0=86Task=E8=BF=90=E8=A1=8C?= =?UTF-8?q?=E6=97=B6=E6=8A=9B=E5=87=BA=E7=9A=84panic=E6=8E=A5=E4=BD=8F?= =?UTF-8?q?=E5=B9=B6=E5=B0=81=E8=A3=85=E4=B8=BAerror=E8=BF=94=E5=9B=9E?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pool/task_pool.go | 5 +++-- pool/task_pool_test.go | 29 +++++++++++++++-------------- 2 files changed, 18 insertions(+), 16 deletions(-) diff --git a/pool/task_pool.go b/pool/task_pool.go index f8634309..b1874603 100644 --- a/pool/task_pool.go +++ b/pool/task_pool.go @@ -36,6 +36,7 @@ var ( errTaskPoolIsClosing = errors.New("ekit:TaskPool关闭中") errTaskPoolIsStopped = errors.New("ekit: TaskPool已停止") errTaskIsInvalid = errors.New("ekit: Task非法") + errTaskRunningPanic = errors.New("ekit: Task运行时异常") errInvalidArgument = errors.New("ekit: 参数非法") @@ -86,13 +87,13 @@ type taskWrapper struct { t Task } -func (tw *taskWrapper) Run(ctx context.Context) error { +func (tw *taskWrapper) Run(ctx context.Context) (err error) { defer func() { // 处理 panic if r := recover(); r != nil { buf := make([]byte, panicBuffLen) buf = buf[:runtime.Stack(buf, false)] - fmt.Printf("[PANIC]:\t%+v\n%s\n", r, buf) + err = fmt.Errorf("%w:%s", errTaskRunningPanic, fmt.Sprintf("[PANIC]:\t%+v\n%s\n", r, buf)) } }() return tw.t.Run(ctx) diff --git a/pool/task_pool_test.go b/pool/task_pool_test.go index 8973cee9..d66e4857 100644 --- a/pool/task_pool_test.go +++ b/pool/task_pool_test.go @@ -127,7 +127,7 @@ func TestTaskPool_Submit(t *testing.T) { // 与下方 testSubmitBlockingAndTimeout 并发执行 errChan := make(chan error) go func() { - <-time.After(1 * time.Millisecond) + time.Sleep(1 * time.Millisecond) errChan <- pool.Start() }() @@ -150,7 +150,7 @@ func TestTaskPool_Submit(t *testing.T) { for i := 0; i < n; i++ { go func() { err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - <-time.After(10 * time.Millisecond) + time.Sleep(10 * time.Millisecond) return nil })) if err != nil { @@ -166,7 +166,7 @@ func TestTaskPool_Submit(t *testing.T) { } resultChan := make(chan ShutdownResult) go func() { - <-time.After(time.Millisecond) + time.Sleep(time.Millisecond) done, err := pool.Shutdown() resultChan <- ShutdownResult{done: done, err: err} }() @@ -180,7 +180,8 @@ func TestTaskPool_Submit(t *testing.T) { assert.NoError(t, r.err) <-r.done // 等待状态迁移完成,并最终进入StateStopped状态 - <-time.After(100 * time.Millisecond) + time.Sleep(100 * time.Millisecond) + assert.Equal(t, stateStopped, pool.internalState()) }) @@ -195,7 +196,7 @@ func TestTaskPool_Submit(t *testing.T) { for i := 0; i < n; i++ { go func() { err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - <-time.After(10 * time.Millisecond) + time.Sleep(10 * time.Millisecond) return nil })) if err != nil { @@ -208,7 +209,7 @@ func TestTaskPool_Submit(t *testing.T) { result := make(chan ShutdownNowResult, 1) go func() { - <-time.After(time.Millisecond) + time.Sleep(time.Millisecond) tasks, err := pool.ShutdownNow() result <- ShutdownNowResult{tasks: tasks, err: err} }() @@ -387,7 +388,7 @@ func TestTaskPool__Closing_(t *testing.T) { for i := 0; i < num; i++ { go func() { err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - <-time.After(10 * time.Millisecond) + time.Sleep(10 * time.Millisecond) return nil })) t.Log(err) @@ -404,7 +405,7 @@ func TestTaskPool__Closing_(t *testing.T) { } <-done - <-time.After(10 * time.Millisecond) + time.Sleep(10 * time.Millisecond) assert.Equal(t, stateStopped, pool.internalState()) }) @@ -418,7 +419,7 @@ func TestTaskPool__Closing_(t *testing.T) { for i := 0; i < num; i++ { go func() { err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - <-time.After(10 * time.Millisecond) + time.Sleep(10 * time.Millisecond) return nil })) t.Log(err) @@ -437,7 +438,7 @@ func TestTaskPool__Closing_(t *testing.T) { } <-done - <-time.After(50 * time.Millisecond) + time.Sleep(50 * time.Millisecond) assert.Equal(t, stateStopped, pool.internalState()) }) @@ -454,14 +455,14 @@ func TestTestPool__Stopped_(t *testing.T) { go func() { t.Log(pool.NumGo()) err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - <-time.After(time.Second) + time.Sleep(time.Second) return nil })) t.Log(err) }() } - <-time.After(100 * time.Millisecond) + time.Sleep(100 * time.Millisecond) assert.Equal(t, int32(concurrency), pool.NumGo()) tasks, err := pool.ShutdownNow() @@ -508,7 +509,7 @@ func TestTestPool__Stopped_(t *testing.T) { func testSubmitBlockingAndTimeout(t *testing.T, pool *BlockQueueTaskPool) { err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - <-time.After(2 * time.Millisecond) + time.Sleep(2 * time.Millisecond) return nil })) assert.NoError(t, err) @@ -520,7 +521,7 @@ func testSubmitBlockingAndTimeout(t *testing.T, pool *BlockQueueTaskPool) { ctx, cancel := context.WithTimeout(context.Background(), time.Millisecond) defer cancel() err := pool.Submit(ctx, TaskFunc(func(ctx context.Context) error { - <-time.After(2 * time.Millisecond) + time.Sleep(2 * time.Millisecond) return nil })) if err != nil { From 63b42652232c8591c0e7bb8661725d45107db1ca Mon Sep 17 00:00:00 2001 From: longyue0521 Date: Mon, 29 Aug 2022 16:35:49 +0800 Subject: [PATCH 10/15] =?UTF-8?q?1.=E6=98=BE=E7=A4=BA=E6=A3=80=E6=9F=A5que?= =?UTF-8?q?ue=E6=98=AF=E5=90=A6=E8=A2=AB=E5=85=B3=E9=97=AD=202.=E5=A4=84?= =?UTF-8?q?=E7=90=86Start/Shutdown/ShutdownNow()=E9=87=8D=E5=A4=8D?= =?UTF-8?q?=E8=B0=83=E7=94=A8=E9=97=AE=E9=A2=98=203.Submit=E6=96=B9?= =?UTF-8?q?=E6=B3=95=E4=B8=AD=E5=8F=AA=E6=A3=80=E9=AA=8Ctask=3D=3Dnil?= =?UTF-8?q?=E7=9A=84=E6=83=85=E5=86=B5?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pool/task_pool.go | 22 +++++------- pool/task_pool_test.go | 82 ++++++++---------------------------------- 2 files changed, 22 insertions(+), 82 deletions(-) diff --git a/pool/task_pool.go b/pool/task_pool.go index b1874603..39c59c09 100644 --- a/pool/task_pool.go +++ b/pool/task_pool.go @@ -18,7 +18,6 @@ import ( "context" "errors" "fmt" - "reflect" "runtime" "sync" "sync/atomic" @@ -35,6 +34,7 @@ var ( errTaskPoolIsNotRunning = errors.New("ekit: TaskPool未运行") errTaskPoolIsClosing = errors.New("ekit:TaskPool关闭中") errTaskPoolIsStopped = errors.New("ekit: TaskPool已停止") + errTaskPoolIsStarted = errors.New("ekit:TaskPool已运行") errTaskIsInvalid = errors.New("ekit: Task非法") errTaskRunningPanic = errors.New("ekit: Task运行时异常") @@ -143,7 +143,7 @@ func NewBlockQueueTaskPool(concurrency int, queueSize int) (*BlockQueueTaskPool, // 如果因为 ctx 的原因返回,那么将会返回 ctx.Err() // 在调用 Start 前后都可以调用 Submit func (b *BlockQueueTaskPool) Submit(ctx context.Context, task Task) error { - if task == nil || reflect.ValueOf(task).IsNil() { + if task == nil { return fmt.Errorf("%w", errTaskIsInvalid) } // todo: 用户未设置超时,可以考虑内部给个超时提交 @@ -207,8 +207,7 @@ func (b *BlockQueueTaskPool) Start() error { } if atomic.LoadInt32(&b.state) == stateRunning { - // 重复调用,不予处理 - return nil + return fmt.Errorf("%w", errTaskPoolIsStarted) } if atomic.CompareAndSwapInt32(&b.state, stateCreated, stateRunning) { @@ -227,9 +226,8 @@ func (b *BlockQueueTaskPool) startTasks() { return case b.token <- struct{}{}: - task := <-b.queue - // handle close(b.queue) - if task == nil { + task, ok := <-b.queue + if !ok { return } @@ -270,8 +268,7 @@ func (b *BlockQueueTaskPool) Shutdown() (<-chan struct{}, error) { } if atomic.LoadInt32(&b.state) == stateClosing { - // 重复调用 - return b.done, nil + return nil, fmt.Errorf("%w", errTaskPoolIsClosing) } if atomic.CompareAndSwapInt32(&b.state, stateRunning, stateClosing) { @@ -312,12 +309,9 @@ func (b *BlockQueueTaskPool) ShutdownNow() ([]Task, error) { } if atomic.LoadInt32(&b.state) == stateStopped { - // 重复调用,返回缓存结果 - b.mux.RLock() - tasks := append([]Task(nil), b.submittedTasks...) - b.mux.RUnlock() - return tasks, nil + return nil, fmt.Errorf("%w", errTaskPoolIsStopped) } + if atomic.CompareAndSwapInt32(&b.state, stateRunning, stateStopped) { // 目标:立刻关闭并且返回所有剩下未执行的任务 // 策略:关闭等待队列不再接受新任务,中断任务启动循环,清空等待队列并保存返回 diff --git a/pool/task_pool_test.go b/pool/task_pool_test.go index d66e4857..f1271cea 100644 --- a/pool/task_pool_test.go +++ b/pool/task_pool_test.go @@ -17,7 +17,6 @@ package pool import ( "context" "errors" - "sync" "testing" "time" @@ -82,27 +81,12 @@ func TestTaskPool_Start(t *testing.T) { pool, _ := NewBlockQueueTaskPool(1, 1) assert.Equal(t, stateCreated, pool.internalState()) - n := 5 - errChan := make(chan error, n) - // 第一次调用 assert.NoError(t, pool.Start()) assert.Equal(t, stateRunning, pool.internalState()) - // 多次调用 - var wg sync.WaitGroup - for i := 0; i < n; i++ { - wg.Add(1) - go func() { - errChan <- pool.Start() - wg.Done() - }() - } - wg.Wait() - close(errChan) - for err := range errChan { - assert.NoError(t, err) - } + // 重复调用 + assert.ErrorIs(t, pool.Start(), errTaskPoolIsStarted) assert.Equal(t, stateRunning, pool.internalState()) } @@ -240,8 +224,8 @@ func TestTaskPool_Shutdown(t *testing.T) { default: // 第二次调用 done2, err2 := pool.Shutdown() - assert.Equal(t, done2, done) - assert.Equal(t, err2, err) + assert.Nil(t, done2) + assert.ErrorIs(t, err2, errTaskPoolIsClosing) assert.Equal(t, stateClosing, pool.internalState()) } @@ -261,22 +245,15 @@ func TestTestPool_ShutdownNow(t *testing.T) { pool := testNewRunningStateTaskPool(t, 1, 1) - n := 3 - c := make(chan ShutdownNowResult, n) - - for i := 0; i < n; i++ { - go func() { - tasks, er := pool.ShutdownNow() - c <- ShutdownNowResult{tasks: tasks, err: er} - }() - } + tasks, err := pool.ShutdownNow() + assert.Nil(t, tasks) + assert.NoError(t, err) + assert.Equal(t, stateStopped, pool.internalState()) - for i := 0; i < n; i++ { - r := <-c - assert.Nil(t, r.tasks) - assert.NoError(t, r.err) - assert.Equal(t, stateStopped, pool.internalState()) - } + tasks, err = pool.ShutdownNow() + assert.Nil(t, tasks) + assert.ErrorIs(t, err, errTaskPoolIsStopped) + assert.Equal(t, stateStopped, pool.internalState()) } func TestTaskPool__Created_(t *testing.T) { @@ -295,7 +272,7 @@ func TestTaskPool__Created_(t *testing.T) { pool, _ := NewBlockQueueTaskPool(1, 1) assert.Equal(t, stateCreated, pool.internalState()) - testSubmitInvalidTask(t, pool) + assert.ErrorIs(t, pool.Submit(context.Background(), nil), errTaskIsInvalid) assert.Equal(t, stateCreated, pool.internalState()) }) @@ -341,21 +318,13 @@ func TestTaskPool__Created_(t *testing.T) { func TestTaskPool__Running_(t *testing.T) { t.Parallel() - t.Run("Start", func(t *testing.T) { - t.Parallel() - - pool := testNewRunningStateTaskPool(t, 1, 3) - assert.NoError(t, pool.Start()) - assert.Equal(t, stateRunning, pool.internalState()) - }) - t.Run("Submit", func(t *testing.T) { t.Parallel() t.Run("提交非法Task", func(t *testing.T) { t.Parallel() pool := testNewRunningStateTaskPool(t, 1, 1) - testSubmitInvalidTask(t, pool) + assert.ErrorIs(t, pool.Submit(context.Background(), nil), errTaskIsInvalid) assert.Equal(t, stateRunning, pool.internalState()) }) @@ -493,17 +462,6 @@ func TestTestPool__Stopped_(t *testing.T) { assert.ErrorIs(t, err, errTaskPoolIsStopped) assert.Equal(t, stateStopped, pool.internalState()) }) - - t.Run("ShutdownNow", func(t *testing.T) { - t.Parallel() - - // 多次调用返回相同结果 - tasks2, err := pool.ShutdownNow() - assert.NoError(t, err) - assert.NotEmpty(t, tasks2) - assert.Equal(t, tasks2, tasks) - assert.Equal(t, stateStopped, pool.internalState()) - }) } func testSubmitBlockingAndTimeout(t *testing.T, pool *BlockQueueTaskPool) { @@ -545,18 +503,6 @@ func testSubmitValidTask(t *testing.T, pool *BlockQueueTaskPool) { assert.NoError(t, err) } -func testSubmitInvalidTask(t *testing.T, pool *BlockQueueTaskPool) { - - invalidTasks := map[string]Task{"*FakeTask": (*FakeTask)(nil), "nil": nil, "TaskFunc(nil)": TaskFunc(nil)} - - for name, task := range invalidTasks { - t.Run(name, func(t *testing.T) { - err := pool.Submit(context.Background(), task) - assert.ErrorIs(t, err, errTaskIsInvalid) - }) - } -} - type ShutdownNowResult struct { tasks []Task err error From 2f39abea4fd660c8e5ae5f72c25fa927c3ee153b Mon Sep 17 00:00:00 2001 From: longyue0521 Date: Mon, 29 Aug 2022 22:01:59 +0800 Subject: [PATCH 11/15] =?UTF-8?q?1.=20=E5=9B=A0=E4=B8=BA=E8=B0=83=E6=95=B4?= =?UTF-8?q?=E4=B8=BA=E9=87=8D=E5=A4=8D=E8=B0=83=E7=94=A8=E5=90=8E=E7=9B=B4?= =?UTF-8?q?=E6=8E=A5=E6=8A=A5=E9=94=99=EF=BC=8C=E9=87=8D=E6=9E=84ShutdownN?= =?UTF-8?q?ow()=E5=B9=B6=E5=8E=BB=E6=8E=89submittedTasks=E5=92=8CSync.RWMu?= =?UTF-8?q?tex=202.=20=E9=87=8D=E6=9E=84Shutdown=EF=BC=88=EF=BC=89?= =?UTF-8?q?=E4=BD=BF=E7=94=A8sync.WaitGroup=E6=9B=BF=E4=BB=A3=E7=94=A8time?= =?UTF-8?q?.Sleep=E6=9D=A5=E6=A3=80=E6=B5=8Bb.num=3D=3D0=203.=20=E5=B0=86S?= =?UTF-8?q?tartTasks=E9=87=8D=E5=91=BD=E5=90=8D=E4=B8=BAschedulingTasks?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pool/task_pool.go | 42 ++++++++++++++++-------------------------- pool/task_pool_test.go | 5 ++++- 2 files changed, 20 insertions(+), 27 deletions(-) diff --git a/pool/task_pool.go b/pool/task_pool.go index 39c59c09..0c651fb7 100644 --- a/pool/task_pool.go +++ b/pool/task_pool.go @@ -21,7 +21,6 @@ import ( "runtime" "sync" "sync/atomic" - "time" ) var ( @@ -107,15 +106,13 @@ type BlockQueueTaskPool struct { queue chan Task token chan struct{} num int32 + wg sync.WaitGroup // 外部信号 done chan struct{} // 内部中断信号 ctx context.Context cancelFunc context.CancelFunc - // 缓存 - mux sync.RWMutex - submittedTasks []Task } // NewBlockQueueTaskPool 创建一个新的 BlockQueueTaskPool @@ -211,13 +208,14 @@ func (b *BlockQueueTaskPool) Start() error { } if atomic.CompareAndSwapInt32(&b.state, stateCreated, stateRunning) { - go b.startTasks() + go b.schedulingTasks() return nil } } } -func (b *BlockQueueTaskPool) startTasks() { +// Schedule tasks +func (b *BlockQueueTaskPool) schedulingTasks() { defer close(b.token) for { @@ -228,14 +226,23 @@ func (b *BlockQueueTaskPool) startTasks() { task, ok := <-b.queue if !ok { + // 调用Shutdown后,TaskPool处于Closing状态 + if atomic.CompareAndSwapInt32(&b.state, stateClosing, stateStopped) { + // 等待运行中的Task自然结束 + b.wg.Wait() + // 通知外部调用者 + close(b.done) + } return } - go func() { + b.wg.Add(1) + atomic.AddInt32(&b.num, 1) - atomic.AddInt32(&b.num, 1) + go func() { defer func() { atomic.AddInt32(&b.num, -1) + b.wg.Done() <-b.token }() @@ -262,8 +269,6 @@ func (b *BlockQueueTaskPool) Shutdown() (<-chan struct{}, error) { } if atomic.LoadInt32(&b.state) == stateStopped { - // 重复调用时,恰好前一个Shutdown调用将状态迁移为StateStopped - // 这种情况与先调用ShutdownNow状态迁移为StateStopped再调用Shutdown效果一样 return nil, fmt.Errorf("%w", errTaskPoolIsStopped) } @@ -276,19 +281,8 @@ func (b *BlockQueueTaskPool) Shutdown() (<-chan struct{}, error) { // 策略:先将队列中的任务启动并执行(清空队列),再等待全部运行中的任务自然退出 // 先关闭等待队列不再允许提交 - // 同时任务启动循环能够通过Task==nil来终止循环 + // 同时任务调度循环能够通过b.queue是否被关闭来终止循环 close(b.queue) - - go func() { - // 等待运行中的Task自然结束 - for atomic.LoadInt32(&b.num) != 0 { - time.Sleep(time.Second) - } - // 通知外部调用者 - close(b.done) - // 完成最终的状态迁移 - atomic.CompareAndSwapInt32(&b.state, stateClosing, stateStopped) - }() return b.done, nil } @@ -321,15 +315,11 @@ func (b *BlockQueueTaskPool) ShutdownNow() ([]Task, error) { // 发送中断信号,中断任务启动循环 b.cancelFunc() - b.mux.Lock() // 清空队列并保存 var tasks []Task for task := range b.queue { - b.submittedTasks = append(b.submittedTasks, task) tasks = append(tasks, task) } - b.mux.Unlock() - return tasks, nil } } diff --git a/pool/task_pool_test.go b/pool/task_pool_test.go index f1271cea..a4bbd4bf 100644 --- a/pool/task_pool_test.go +++ b/pool/task_pool_test.go @@ -124,6 +124,7 @@ func TestTaskPool_Submit(t *testing.T) { }) t.Run("TaskPool状态由Running变为Closing", func(t *testing.T) { + // t.Skip() t.Parallel() pool := testNewRunningStateTaskPool(t, 1, 2) @@ -154,7 +155,6 @@ func TestTaskPool_Submit(t *testing.T) { done, err := pool.Shutdown() resultChan <- ShutdownResult{done: done, err: err} }() - r := <-resultChan // 阻塞中的任务报错,证明处于TaskPool处于StateClosing状态 @@ -323,6 +323,7 @@ func TestTaskPool__Running_(t *testing.T) { t.Run("提交非法Task", func(t *testing.T) { t.Parallel() + pool := testNewRunningStateTaskPool(t, 1, 1) assert.ErrorIs(t, pool.Submit(context.Background(), nil), errTaskIsInvalid) assert.Equal(t, stateRunning, pool.internalState()) @@ -330,6 +331,7 @@ func TestTaskPool__Running_(t *testing.T) { t.Run("正常提交Task", func(t *testing.T) { t.Parallel() + pool := testNewRunningStateTaskPool(t, 1, 3) testSubmitValidTask(t, pool) assert.Equal(t, stateRunning, pool.internalState()) @@ -337,6 +339,7 @@ func TestTaskPool__Running_(t *testing.T) { t.Run("阻塞提交并导致超时", func(t *testing.T) { t.Parallel() + pool := testNewRunningStateTaskPool(t, 1, 1) testSubmitBlockingAndTimeout(t, pool) assert.Equal(t, stateRunning, pool.internalState()) From 91d37414e7007617dd2bc82836bc53a18a55eba3 Mon Sep 17 00:00:00 2001 From: longyue0521 Date: Mon, 29 Aug 2022 23:52:36 +0800 Subject: [PATCH 12/15] =?UTF-8?q?=E6=B7=BB=E5=8A=A0=E6=B3=A8=E9=87=8A?= =?UTF-8?q?=E4=BF=AE=E6=94=B9=E4=BB=A3=E7=A0=81=E6=8F=90=E9=AB=98=E5=8F=AF?= =?UTF-8?q?=E8=AF=BB=E6=80=A7?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: longyue0521 --- pool/task_pool.go | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pool/task_pool.go b/pool/task_pool.go index 0c651fb7..83380b93 100644 --- a/pool/task_pool.go +++ b/pool/task_pool.go @@ -182,9 +182,9 @@ func (b *BlockQueueTaskPool) trySubmit(ctx context.Context, task Task, state int case b.queue <- task: return true, nil default: - // 不能阻塞在临界区 + // 不能阻塞在临界区,要给Shutdown和ShutdownNow机会 + return false, nil } - return false, nil } return false, nil } From 9a195e80dcb556da231d0058dfad70a5be41295d Mon Sep 17 00:00:00 2001 From: longyue0521 Date: Thu, 25 Aug 2022 13:59:52 +0800 Subject: [PATCH 13/15] =?UTF-8?q?1.task=5Fpool.go=20=E4=BF=AE=E6=94=B9task?= =?UTF-8?q?=E7=9A=84=E5=A3=B0=E6=98=8E=E6=96=B9=E5=BC=8F=EF=BC=8C=E6=98=BE?= =?UTF-8?q?=E7=A4=BA=E6=8C=87=E5=AE=9A=E5=85=B6=E5=AE=B9=E9=87=8F=202.task?= =?UTF-8?q?=5Fpool=5Ftest.go=20=E9=87=8D=E6=96=B0=E7=BB=84=E7=BB=87?= =?UTF-8?q?=E6=B5=8B=E8=AF=95=E4=BD=BF=E5=85=B6=E6=9B=B4=E5=85=B7=E5=8F=AF?= =?UTF-8?q?=E8=AF=BB=E6=80=A7=EF=BC=8C=E6=B8=85=E7=90=86=E9=AD=94=E6=95=B0?= =?UTF-8?q?=E6=98=BE=E7=A4=BA=E8=A1=A8=E6=98=8E=E5=85=B6=E6=84=8F=E5=9B=BE?= =?UTF-8?q?=E7=AD=89?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Signed-off-by: longyue0521 --- .CHANGELOG.md | 2 +- pool/task_pool.go | 2 +- pool/task_pool_test.go | 435 ++++++++++++++++++----------------------- 3 files changed, 189 insertions(+), 250 deletions(-) diff --git a/.CHANGELOG.md b/.CHANGELOG.md index 07bb96ac..1294ff44 100644 --- a/.CHANGELOG.md +++ b/.CHANGELOG.md @@ -28,4 +28,4 @@ - [Range](https://github.com/gotomicro/ekit/pull/46) - [pool: 修复 Pool TestPool 测试不稳定的问题](https://github.com/gotomicro/ekit/pull/40) - [ekit:引入 golangci-lint 和 goimports](https://github.com/gotomicro/ekit/pull/54) -- [ekit: 实现了 TaskPool](https://github.com/gotomicro/ekit/pull/57) +- [ekit: 实现了 TaskPool](https://github.com/gotomicro/ekit/pull/57) \ No newline at end of file diff --git a/pool/task_pool.go b/pool/task_pool.go index 83380b93..4989a2af 100644 --- a/pool/task_pool.go +++ b/pool/task_pool.go @@ -316,7 +316,7 @@ func (b *BlockQueueTaskPool) ShutdownNow() ([]Task, error) { b.cancelFunc() // 清空队列并保存 - var tasks []Task + tasks := make([]Task, 0, len(b.queue)) for task := range b.queue { tasks = append(tasks, task) } diff --git a/pool/task_pool_test.go b/pool/task_pool_test.go index a4bbd4bf..7c988261 100644 --- a/pool/task_pool_test.go +++ b/pool/task_pool_test.go @@ -36,233 +36,38 @@ New() --> CREATED -- Start() ---> RUNNING -- -- ShutdownNow() ---> STOPPED -- ShutdownNow() --> STOPPED */ -func TestTaskPool_New(t *testing.T) { +func TestTaskPool_In_Created_State(t *testing.T) { t.Parallel() - pool, err := NewBlockQueueTaskPool(1, -1) - assert.ErrorIs(t, err, errInvalidArgument) - assert.Nil(t, pool) - - pool, err = NewBlockQueueTaskPool(1, 0) - assert.NoError(t, err) - assert.NotNil(t, pool) - - pool, err = NewBlockQueueTaskPool(1, 1) - assert.NoError(t, err) - assert.NotNil(t, pool) - - pool, err = NewBlockQueueTaskPool(-1, 1) - assert.ErrorIs(t, err, errInvalidArgument) - assert.Nil(t, pool) - - pool, err = NewBlockQueueTaskPool(0, 1) - assert.ErrorIs(t, err, errInvalidArgument) - assert.Nil(t, pool) - - pool, err = NewBlockQueueTaskPool(1, 1) - assert.NoError(t, err) - assert.NotNil(t, pool) - -} - -func TestTaskPool_Start(t *testing.T) { - /* - todo: Start - 1. happy - [x] change state from CREATED to RUNNING - done - - 非阻塞 task 调度器开始工作,开始执行工作 - - [x] Start多次,保证只运行一次,或者报错——TaskPool已经启动 - 2. sad - - CLOSING state -> start error,多次运行结果一致 - STOPPED state -> start error多次运行结果一致 - */ - - t.Parallel() - - pool, _ := NewBlockQueueTaskPool(1, 1) - assert.Equal(t, stateCreated, pool.internalState()) - - // 第一次调用 - assert.NoError(t, pool.Start()) - assert.Equal(t, stateRunning, pool.internalState()) - - // 重复调用 - assert.ErrorIs(t, pool.Start(), errTaskPoolIsStarted) - assert.Equal(t, stateRunning, pool.internalState()) -} - -func TestTaskPool_Submit(t *testing.T) { - t.Parallel() - // todo: Submit - // 在TaskPool所有状态中都可以提交,有的成功/阻塞,有的立即失败。 - // [x]ctx时间内提交成功 - // [x]ctx超时,提交失败给出错误信息 - // [x] 监听状态变化,从running->closing/stopped - // [x] Shutdown后,状态变迁,需要检查出并报错 errTaskPoolIsClosing - // [x] ShutdownNow后状态变迁,需要检查出并报错,errTaskPoolIsStopped - - t.Run("提交Task阻塞", func(t *testing.T) { + t.Run("New", func(t *testing.T) { t.Parallel() - t.Run("TaskPool状态由Created变为Running", func(t *testing.T) { - t.Parallel() - - pool, _ := NewBlockQueueTaskPool(1, 1) + pool, err := NewBlockQueueTaskPool(1, -1) + assert.ErrorIs(t, err, errInvalidArgument) + assert.Nil(t, pool) - // 与下方 testSubmitBlockingAndTimeout 并发执行 - errChan := make(chan error) - go func() { - time.Sleep(1 * time.Millisecond) - errChan <- pool.Start() - }() - - assert.Equal(t, stateCreated, pool.internalState()) - - testSubmitBlockingAndTimeout(t, pool) - - assert.NoError(t, <-errChan) - assert.Equal(t, stateRunning, pool.internalState()) - }) - - t.Run("TaskPool状态由Running变为Closing", func(t *testing.T) { - // t.Skip() - t.Parallel() - - pool := testNewRunningStateTaskPool(t, 1, 2) - - // 模拟阻塞提交 - n := 10 - firstSubmitErrChan := make(chan error, 1) - for i := 0; i < n; i++ { - go func() { - err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - time.Sleep(10 * time.Millisecond) - return nil - })) - if err != nil { - firstSubmitErrChan <- err - } - }() - } - - // 调用Shutdown使TaskPool状态发生迁移 - type ShutdownResult struct { - done <-chan struct{} - err error - } - resultChan := make(chan ShutdownResult) - go func() { - time.Sleep(time.Millisecond) - done, err := pool.Shutdown() - resultChan <- ShutdownResult{done: done, err: err} - }() - r := <-resultChan - - // 阻塞中的任务报错,证明处于TaskPool处于StateClosing状态 - assert.ErrorIs(t, <-firstSubmitErrChan, errTaskPoolIsClosing) - - // Shutdown调用成功 - assert.NoError(t, r.err) - <-r.done - // 等待状态迁移完成,并最终进入StateStopped状态 - time.Sleep(100 * time.Millisecond) - - assert.Equal(t, stateStopped, pool.internalState()) - }) - - t.Run("TaskPool状态由Running变为Stopped", func(t *testing.T) { - t.Parallel() - - pool := testNewRunningStateTaskPool(t, 1, 2) - - // 模拟阻塞提交 - n := 5 - firstSubmitErrChan := make(chan error, 1) - for i := 0; i < n; i++ { - go func() { - err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - time.Sleep(10 * time.Millisecond) - return nil - })) - if err != nil { - firstSubmitErrChan <- err - } - }() - } - - // 并发调用ShutdownNow + pool, err = NewBlockQueueTaskPool(1, 0) + assert.NoError(t, err) + assert.NotNil(t, pool) - result := make(chan ShutdownNowResult, 1) - go func() { - time.Sleep(time.Millisecond) - tasks, err := pool.ShutdownNow() - result <- ShutdownNowResult{tasks: tasks, err: err} - }() + pool, err = NewBlockQueueTaskPool(1, 1) + assert.NoError(t, err) + assert.NotNil(t, pool) - r := <-result - assert.NoError(t, r.err) - assert.NotEmpty(t, r.tasks) + pool, err = NewBlockQueueTaskPool(-1, 1) + assert.ErrorIs(t, err, errInvalidArgument) + assert.Nil(t, pool) - assert.ErrorIs(t, <-firstSubmitErrChan, errTaskPoolIsStopped) - assert.Equal(t, stateStopped, pool.internalState()) - }) + pool, err = NewBlockQueueTaskPool(0, 1) + assert.ErrorIs(t, err, errInvalidArgument) + assert.Nil(t, pool) + pool, err = NewBlockQueueTaskPool(1, 1) + assert.NoError(t, err) + assert.NotNil(t, pool) }) -} - -func TestTaskPool_Shutdown(t *testing.T) { - t.Parallel() - pool := testNewRunningStateTaskPool(t, 1, 1) - - // 第一次调用 - done, err := pool.Shutdown() - assert.NoError(t, err) - - select { - case <-done: - break - default: - // 第二次调用 - done2, err2 := pool.Shutdown() - assert.Nil(t, done2) - assert.ErrorIs(t, err2, errTaskPoolIsClosing) - assert.Equal(t, stateClosing, pool.internalState()) - } - - <-done - assert.Equal(t, stateStopped, pool.internalState()) - - // 第一个Shutdown将状态迁移至StateStopped - // 第三次调用 - done, err = pool.Shutdown() - assert.Nil(t, done) - assert.ErrorIs(t, err, errTaskPoolIsStopped) -} - -func TestTestPool_ShutdownNow(t *testing.T) { - - t.Parallel() - - pool := testNewRunningStateTaskPool(t, 1, 1) - - tasks, err := pool.ShutdownNow() - assert.Nil(t, tasks) - assert.NoError(t, err) - assert.Equal(t, stateStopped, pool.internalState()) - - tasks, err = pool.ShutdownNow() - assert.Nil(t, tasks) - assert.ErrorIs(t, err, errTaskPoolIsStopped) - assert.Equal(t, stateStopped, pool.internalState()) -} - -func TestTaskPool__Created_(t *testing.T) { - t.Parallel() - - pool, err := NewBlockQueueTaskPool(1, 1) - assert.NoError(t, err) - assert.NotNil(t, pool) - assert.Equal(t, stateCreated, pool.internalState()) + // Start()导致TaskPool状态迁移,测试见TestTaskPool_In_Running_State/Start t.Run("Submit", func(t *testing.T) { t.Parallel() @@ -298,6 +103,10 @@ func TestTaskPool__Created_(t *testing.T) { t.Run("Shutdown", func(t *testing.T) { t.Parallel() + pool, err := NewBlockQueueTaskPool(1, 1) + assert.NoError(t, err) + assert.Equal(t, stateCreated, pool.internalState()) + done, err := pool.Shutdown() assert.Nil(t, done) assert.ErrorIs(t, err, errTaskPoolIsNotRunning) @@ -307,17 +116,44 @@ func TestTaskPool__Created_(t *testing.T) { t.Run("ShutdownNow", func(t *testing.T) { t.Parallel() + pool, err := NewBlockQueueTaskPool(1, 1) + assert.NoError(t, err) + assert.Equal(t, stateCreated, pool.internalState()) + tasks, err := pool.ShutdownNow() assert.Nil(t, tasks) assert.ErrorIs(t, err, errTaskPoolIsNotRunning) assert.Equal(t, stateCreated, pool.internalState()) }) - } -func TestTaskPool__Running_(t *testing.T) { +func TestTaskPool_In_Running_State(t *testing.T) { t.Parallel() + t.Run("Start —— 使TaskPool状态由Created变为Running", func(t *testing.T) { + t.Parallel() + + pool, _ := NewBlockQueueTaskPool(1, 1) + + // 与下方 testSubmitBlockingAndTimeout 并发执行 + errChan := make(chan error) + go func() { + time.Sleep(1 * time.Millisecond) + errChan <- pool.Start() + }() + + assert.Equal(t, stateCreated, pool.internalState()) + + testSubmitBlockingAndTimeout(t, pool) + + assert.NoError(t, <-errChan) + assert.Equal(t, stateRunning, pool.internalState()) + + // 重复调用 + assert.ErrorIs(t, pool.Start(), errTaskPoolIsStarted) + assert.Equal(t, stateRunning, pool.internalState()) + }) + t.Run("Submit", func(t *testing.T) { t.Parallel() @@ -345,25 +181,91 @@ func TestTaskPool__Running_(t *testing.T) { assert.Equal(t, stateRunning, pool.internalState()) }) }) + + // Shutdown()导致TaskPool状态迁移,TestTaskPool_In_Closing_State/Shutdown + + // ShutdownNow()导致TaskPool状态迁移,TestTestPool_In_Stopped_State/ShutdownNow } -func TestTaskPool__Closing_(t *testing.T) { +func TestTaskPool_In_Closing_State(t *testing.T) { t.Parallel() + t.Run("Shutdown —— 使TaskPool状态由Running变为Closing", func(t *testing.T) { + t.Parallel() + + queueSize := 2 + pool := testNewRunningStateTaskPool(t, 1, queueSize) + + // 模拟阻塞提交 + n := queueSize * 5 + firstSubmitErrChan := make(chan error, 1) + for i := 0; i < n; i++ { + go func() { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + time.Sleep(10 * time.Millisecond) + return nil + })) + if err != nil { + firstSubmitErrChan <- err + } + }() + } + + // 调用Shutdown使TaskPool状态发生迁移 + type ShutdownResult struct { + done <-chan struct{} + err error + } + resultChan := make(chan ShutdownResult) + go func() { + time.Sleep(time.Millisecond) + done, err := pool.Shutdown() + resultChan <- ShutdownResult{done: done, err: err} + }() + r := <-resultChan + + // Closing过程中Submit会报错间接证明TaskPool处于StateClosing状态 + assert.ErrorIs(t, <-firstSubmitErrChan, errTaskPoolIsClosing) + + // Shutdown调用成功 + assert.NoError(t, r.err) + select { + case <-r.done: + break + default: + // 第二次调用 + done2, err2 := pool.Shutdown() + assert.Nil(t, done2) + assert.ErrorIs(t, err2, errTaskPoolIsClosing) + assert.Equal(t, stateClosing, pool.internalState()) + } + + <-r.done + assert.Equal(t, stateStopped, pool.internalState()) + + // 第一个Shutdown将状态迁移至StateStopped + // 第三次调用 + done, err := pool.Shutdown() + assert.Nil(t, done) + assert.ErrorIs(t, err, errTaskPoolIsStopped) + }) + t.Run("Start", func(t *testing.T) { t.Parallel() - num := 10 - pool := testNewRunningStateTaskPool(t, 2, num) + queueSize := 10 + pool := testNewRunningStateTaskPool(t, 2, queueSize) // 提交任务 - for i := 0; i < num; i++ { + for i := 0; i < queueSize; i++ { go func() { err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { time.Sleep(10 * time.Millisecond) return nil })) - t.Log(err) + if err != nil { + return + } }() } @@ -377,24 +279,28 @@ func TestTaskPool__Closing_(t *testing.T) { } <-done - time.Sleep(10 * time.Millisecond) assert.Equal(t, stateStopped, pool.internalState()) }) + // Submit()在状态中会报错,因为Closing是一个中间状态,故在上面的Shutdown间接测到了 + t.Run("ShutdownNow", func(t *testing.T) { t.Parallel() - pool := testNewRunningStateTaskPool(t, 1, 0) + concurrency := 2 + pool := testNewRunningStateTaskPool(t, concurrency, 0) // 提交任务 - num := 10 + num := concurrency * 5 for i := 0; i < num; i++ { go func() { err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { time.Sleep(10 * time.Millisecond) return nil })) - t.Log(err) + if err != nil { + return + } }() } @@ -410,41 +316,64 @@ func TestTaskPool__Closing_(t *testing.T) { } <-done - time.Sleep(50 * time.Millisecond) assert.Equal(t, stateStopped, pool.internalState()) }) - } -func TestTestPool__Stopped_(t *testing.T) { +func TestTestPool_In_Stopped_State(t *testing.T) { t.Parallel() - concurrency, n := 2, 20 - pool := testNewRunningStateTaskPool(t, concurrency, n) + t.Run("ShutdownNow —— 使TaskPool状态由Running变为Stopped", func(t *testing.T) { + t.Parallel() - // 模拟阻塞提交 - for i := 0; i < n; i++ { + concurrency, queueSize := 2, 4 + pool := testNewRunningStateTaskPool(t, concurrency, queueSize) + + // 模拟阻塞提交 + n := queueSize + 3 + firstSubmitErrChan := make(chan error, 1) + for i := 0; i < n; i++ { + go func() { + err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { + time.Sleep(10 * time.Millisecond) + return nil + })) + if err != nil { + firstSubmitErrChan <- err + } + }() + } + + time.Sleep(1 * time.Millisecond) + assert.Equal(t, int32(concurrency), pool.NumGo()) + + // 并发调用ShutdownNow + result := make(chan ShutdownNowResult, 1) go func() { - t.Log(pool.NumGo()) - err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { - time.Sleep(time.Second) - return nil - })) - t.Log(err) + time.Sleep(time.Millisecond) + tasks, err := pool.ShutdownNow() + result <- ShutdownNowResult{tasks: tasks, err: err} }() - } - time.Sleep(100 * time.Millisecond) - assert.Equal(t, int32(concurrency), pool.NumGo()) + r := <-result + assert.NoError(t, r.err) + assert.Equal(t, queueSize, len(r.tasks)) - tasks, err := pool.ShutdownNow() - assert.NoError(t, err) - assert.NotEmpty(t, tasks) - assert.Equal(t, stateStopped, pool.internalState()) + // 阻塞的Submit在ShutdownNow后会报错间接证明TaskPool处于StateStopped状态 + assert.ErrorIs(t, <-firstSubmitErrChan, errTaskPoolIsStopped) + assert.Equal(t, stateStopped, pool.internalState()) + + // 重复调用 + tasks, err := pool.ShutdownNow() + assert.Nil(t, tasks) + assert.ErrorIs(t, err, errTaskPoolIsStopped) + assert.Equal(t, stateStopped, pool.internalState()) + }) t.Run("Start", func(t *testing.T) { t.Parallel() + pool := testNewStoppedStateTaskPool(t, 1, 1) assert.ErrorIs(t, pool.Start(), errTaskPoolIsStopped) assert.Equal(t, stateStopped, pool.internalState()) }) @@ -452,6 +381,7 @@ func TestTestPool__Stopped_(t *testing.T) { t.Run("Submit", func(t *testing.T) { t.Parallel() + pool := testNewStoppedStateTaskPool(t, 1, 1) err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { return nil })) assert.ErrorIs(t, err, errTaskPoolIsStopped) assert.Equal(t, stateStopped, pool.internalState()) @@ -460,6 +390,7 @@ func TestTestPool__Stopped_(t *testing.T) { t.Run("Shutdown", func(t *testing.T) { t.Parallel() + pool := testNewStoppedStateTaskPool(t, 1, 1) done, err := pool.Shutdown() assert.Nil(t, done) assert.ErrorIs(t, err, errTaskPoolIsStopped) @@ -475,7 +406,7 @@ func testSubmitBlockingAndTimeout(t *testing.T, pool *BlockQueueTaskPool) { })) assert.NoError(t, err) - n := 2 + n := len(pool.queue) + 1 errChan := make(chan error, n) for i := 0; i < n; i++ { go func() { @@ -519,6 +450,14 @@ func testNewRunningStateTaskPool(t *testing.T, concurrency int, queueSize int) * return pool } +func testNewStoppedStateTaskPool(t *testing.T, concurrency int, queueSize int) *BlockQueueTaskPool { + pool := testNewRunningStateTaskPool(t, concurrency, queueSize) + _, err := pool.ShutdownNow() + assert.NoError(t, err) + assert.Equal(t, stateStopped, pool.internalState()) + return pool +} + type FakeTask struct{} func (f *FakeTask) Run(_ context.Context) error { return nil } From 8f1ce93df9c84f34f504b2c7255d59c7c405f2d1 Mon Sep 17 00:00:00 2001 From: longyue0521 Date: Wed, 31 Aug 2022 02:17:42 +0800 Subject: [PATCH 14/15] =?UTF-8?q?=E6=94=B9=E4=B8=BA=E6=9B=B4=E5=AE=BD?= =?UTF-8?q?=E6=B3=9B=E7=9A=84=E6=96=AD=E8=A8=80?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pool/task_pool_test.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pool/task_pool_test.go b/pool/task_pool_test.go index 7c988261..7f95c249 100644 --- a/pool/task_pool_test.go +++ b/pool/task_pool_test.go @@ -357,7 +357,7 @@ func TestTestPool_In_Stopped_State(t *testing.T) { r := <-result assert.NoError(t, r.err) - assert.Equal(t, queueSize, len(r.tasks)) + assert.Greater(t, len(r.tasks), 0) // 阻塞的Submit在ShutdownNow后会报错间接证明TaskPool处于StateStopped状态 assert.ErrorIs(t, <-firstSubmitErrChan, errTaskPoolIsStopped) From 2f91e077058b1bf579114f1109adfaf5122fd8b7 Mon Sep 17 00:00:00 2001 From: longyue0521 Date: Wed, 31 Aug 2022 02:35:53 +0800 Subject: [PATCH 15/15] =?UTF-8?q?=E4=BF=AE=E6=94=B9=E6=B5=8B=E8=AF=95?= =?UTF-8?q?=E7=94=A8=E4=BE=8B?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- pool/task_pool_test.go | 5 ++--- 1 file changed, 2 insertions(+), 3 deletions(-) diff --git a/pool/task_pool_test.go b/pool/task_pool_test.go index 7f95c249..f6d9b9f0 100644 --- a/pool/task_pool_test.go +++ b/pool/task_pool_test.go @@ -330,8 +330,8 @@ func TestTestPool_In_Stopped_State(t *testing.T) { pool := testNewRunningStateTaskPool(t, concurrency, queueSize) // 模拟阻塞提交 - n := queueSize + 3 - firstSubmitErrChan := make(chan error, 1) + n := queueSize + 6 + firstSubmitErrChan := make(chan error, concurrency) for i := 0; i < n; i++ { go func() { err := pool.Submit(context.Background(), TaskFunc(func(ctx context.Context) error { @@ -350,7 +350,6 @@ func TestTestPool_In_Stopped_State(t *testing.T) { // 并发调用ShutdownNow result := make(chan ShutdownNowResult, 1) go func() { - time.Sleep(time.Millisecond) tasks, err := pool.ShutdownNow() result <- ShutdownNowResult{tasks: tasks, err: err} }()