Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

pool: TaskPool 实现 #50 #57

Merged
merged 17 commits into from
Aug 31, 2022
Merged
Show file tree
Hide file tree
Changes from 11 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
13 changes: 13 additions & 0 deletions .CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,16 @@
# 开发中
[ekit: add ToPtr function](https://github.com/gotomicro/ekit/pull/6)
[sql: 支持 JsonColumn](https://github.com/gotomicro/ekit/pull/7)
[ekit: 支持 ArrayList Range()](https://github.com/gotomicro/ekit/pull/12)
[ekit: 支持 ArrayList Get()](https://github.com/gotomicro/ekit/pull/18)
[ekit: 支持 ArrayList Len()](https://github.com/gotomicro/ekit/pull/18)
[ekit: 实现了 LinkedList Add() NewLinkedListOf](https://github.com/gotomicro/ekit/pull/26)
[ekit: 实现了 LinkedList Get](https://github.com/gotomicro/ekit/pull/31)
[ekit: 实现了 LinkedList Append](https://github.com/gotomicro/ekit/pull/34)
[ekit: 实现了 LinkedList Delete](https://github.com/gotomicro/ekit/pull/38)
[ekit: 修复 Pool TestPool](https://github.com/gotomicro/ekit/pull/40)
[ekit: 实现了 LinkedList Range](https://github.com/gotomicro/ekit/pull/46)
[ekit: 实现了 TaskPool](https://github.com/gotomicro/ekit/pull/57)
flycash marked this conversation as resolved.
Show resolved Hide resolved
flycash marked this conversation as resolved.
Show resolved Hide resolved
- [ekit: add ToPtr function](https://github.com/gotomicro/ekit/pull/6)
- [sql: 支持 JsonColumn](https://github.com/gotomicro/ekit/pull/7)
- [bean/copier: 实现了基于反射的 ReflectCopier](https://github.com/gotomicro/ekit/pull/47)
Expand All @@ -16,3 +28,4 @@
- [Range](https://github.com/gotomicro/ekit/pull/46)
- [pool: 修复 Pool TestPool 测试不稳定的问题](https://github.com/gotomicro/ekit/pull/40)
- [ekit:引入 golangci-lint 和 goimports](https://github.com/gotomicro/ekit/pull/54)
- [ekit: 实现了 TaskPool](https://github.com/gotomicro/ekit/pull/57)
283 changes: 270 additions & 13 deletions pool/task_pool.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,35 @@

package pool

import "context"
import (
"context"
"errors"
"fmt"
"reflect"
"runtime"
"sync"
"sync/atomic"
"time"
)

var (
stateCreated int32 = 1
stateRunning int32 = 2
stateClosing int32 = 3
stateStopped int32 = 4
stateLocked int32 = 5

errTaskPoolIsNotRunning = errors.New("ekit: TaskPool未运行")
errTaskPoolIsClosing = errors.New("ekit:TaskPool关闭中")
errTaskPoolIsStopped = errors.New("ekit: TaskPool已停止")
errTaskIsInvalid = errors.New("ekit: Task非法")
errTaskRunningPanic = errors.New("ekit: Task运行时异常")

errInvalidArgument = errors.New("ekit: 参数非法")

_ TaskPool = &BlockQueueTaskPool{}
panicBuffLen = 2048
)

// TaskPool 任务池
type TaskPool interface {
Expand Down Expand Up @@ -52,48 +80,277 @@ type TaskFunc func(ctx context.Context) error

// Run 执行任务
// 超时控制取决于衍生出 TaskFunc 的方法
func (t TaskFunc) Run(ctx context.Context) error {
return t(ctx)
func (t TaskFunc) Run(ctx context.Context) error { return t(ctx) }

// taskWrapper 是Task的装饰器
type taskWrapper struct {
t Task
}

func (tw *taskWrapper) Run(ctx context.Context) (err error) {
defer func() {
// 处理 panic
flycash marked this conversation as resolved.
Show resolved Hide resolved
if r := recover(); r != nil {
buf := make([]byte, panicBuffLen)
buf = buf[:runtime.Stack(buf, false)]
err = fmt.Errorf("%w:%s", errTaskRunningPanic, fmt.Sprintf("[PANIC]:\t%+v\n%s\n", r, buf))
}
}()
return tw.t.Run(ctx)
}

// BlockQueueTaskPool 并发阻塞的任务池
type BlockQueueTaskPool struct {
// TaskPool内部状态
state int32

queue chan Task
token chan struct{}
num int32

// 外部信号
done chan struct{}
// 内部中断信号
ctx context.Context
cancelFunc context.CancelFunc
// 缓存
mux sync.RWMutex
submittedTasks []Task
}

// NewBlockQueueTaskPool 创建一个新的 BlockQueueTaskPool
// concurrency 是并发数,即最多允许多少个 goroutine 执行任务
// queueSize 是队列大小,即最多有多少个任务在等待调度
func NewBlockQueueTaskPool(concurrency int, queueSize int) (*BlockQueueTaskPool, error) {
return &BlockQueueTaskPool{}, nil
if concurrency < 1 {
return nil, fmt.Errorf("%w:concurrency应该大于0", errInvalidArgument)
}
if queueSize < 0 {
return nil, fmt.Errorf("%w:queueSize应该大于等于0", errInvalidArgument)
}
b := &BlockQueueTaskPool{
queue: make(chan Task, queueSize),
token: make(chan struct{}, concurrency),
done: make(chan struct{}),
}
b.ctx, b.cancelFunc = context.WithCancel(context.Background())
atomic.StoreInt32(&b.state, stateCreated)
return b, nil
}

// Submit 提交一个任务
// 如果此时队列已满,那么将会阻塞调用者。
// 如果因为 ctx 的原因返回,那么将会返回 ctx.Err()
// 在调用 Start 前后都可以调用 Submit
func (b *BlockQueueTaskPool) Submit(ctx context.Context, task func()) error {
// TODO implement me
panic("implement me")
func (b *BlockQueueTaskPool) Submit(ctx context.Context, task Task) error {
if task == nil || reflect.ValueOf(task).IsNil() {
flycash marked this conversation as resolved.
Show resolved Hide resolved
return fmt.Errorf("%w", errTaskIsInvalid)
}
// todo: 用户未设置超时,可以考虑内部给个超时提交
for {

if atomic.LoadInt32(&b.state) == stateClosing {
return fmt.Errorf("%w", errTaskPoolIsClosing)
}

if atomic.LoadInt32(&b.state) == stateStopped {
return fmt.Errorf("%w", errTaskPoolIsStopped)
}

task = &taskWrapper{t: task}

ok, err := b.trySubmit(ctx, task, stateCreated)
if ok || err != nil {
return err
}

ok, err = b.trySubmit(ctx, task, stateRunning)
if ok || err != nil {
return err
}
}
}

func (b *BlockQueueTaskPool) trySubmit(ctx context.Context, task Task, state int32) (bool, error) {
// 进入临界区
if atomic.CompareAndSwapInt32(&b.state, state, stateLocked) {
flycash marked this conversation as resolved.
Show resolved Hide resolved
defer atomic.CompareAndSwapInt32(&b.state, stateLocked, state)

// 此处b.queue <- task不会因为b.queue被关闭而panic
// 代码执行到trySubmit时TaskPool处于lock状态
// 要关闭b.queue需要TaskPool处于RUNNING状态,Shutdown/ShutdownNow才能成功
select {
case <-ctx.Done():
return false, fmt.Errorf("%w", ctx.Err())
case b.queue <- task:
return true, nil
default:
// 不能阻塞在临界区
}
return false, nil
}
return false, nil
}

// Start 开始调度任务执行
// Start 之后,调用者可以继续使用 Submit 提交任务
func (b *BlockQueueTaskPool) Start() error {
// TODO implement me
panic("implement me")

for {

if atomic.LoadInt32(&b.state) == stateClosing {
return fmt.Errorf("%w", errTaskPoolIsClosing)
}

if atomic.LoadInt32(&b.state) == stateStopped {
return fmt.Errorf("%w", errTaskPoolIsStopped)
}

if atomic.LoadInt32(&b.state) == stateRunning {
// 重复调用,不予处理
return nil
}

if atomic.CompareAndSwapInt32(&b.state, stateCreated, stateRunning) {
go b.startTasks()
flycash marked this conversation as resolved.
Show resolved Hide resolved
return nil
}
}
}

func (b *BlockQueueTaskPool) startTasks() {
defer close(b.token)

for {
select {
case <-b.ctx.Done():
return
case b.token <- struct{}{}:

task := <-b.queue
// handle close(b.queue)
if task == nil {
flycash marked this conversation as resolved.
Show resolved Hide resolved
return
}

go func() {

atomic.AddInt32(&b.num, 1)
defer func() {
atomic.AddInt32(&b.num, -1)
<-b.token
}()

// todo: handle err
err := task.Run(b.ctx)
if err != nil {
return
}
}()
}
}
}

// Shutdown 将会拒绝提交新的任务,但是会继续执行已提交任务
// 当执行完毕后,会往返回的 chan 中丢入信号
// Shutdown 会负责关闭返回的 chan
// Shutdown 无法中断正在执行的任务
func (b *BlockQueueTaskPool) Shutdown() (<-chan struct{}, error) {
// TODO implement me
panic("implement me")

for {

if atomic.LoadInt32(&b.state) == stateCreated {
return nil, fmt.Errorf("%w", errTaskPoolIsNotRunning)
}

if atomic.LoadInt32(&b.state) == stateStopped {
// 重复调用时,恰好前一个Shutdown调用将状态迁移为StateStopped
// 这种情况与先调用ShutdownNow状态迁移为StateStopped再调用Shutdown效果一样
return nil, fmt.Errorf("%w", errTaskPoolIsStopped)
}

if atomic.LoadInt32(&b.state) == stateClosing {
// 重复调用
return b.done, nil
}

if atomic.CompareAndSwapInt32(&b.state, stateRunning, stateClosing) {
// 目标:不但希望正在运行中的任务自然退出,还希望队列中等待的任务也能启动执行并自然退出
// 策略:先将队列中的任务启动并执行(清空队列),再等待全部运行中的任务自然退出

// 先关闭等待队列不再允许提交
// 同时任务启动循环能够通过Task==nil来终止循环
close(b.queue)

go func() {
// 等待运行中的Task自然结束
for atomic.LoadInt32(&b.num) != 0 {
time.Sleep(time.Second)
}
// 通知外部调用者
close(b.done)
// 完成最终的状态迁移
atomic.CompareAndSwapInt32(&b.state, stateClosing, stateStopped)
flycash marked this conversation as resolved.
Show resolved Hide resolved
}()
return b.done, nil
}

}
}

// ShutdownNow 立刻关闭任务池,并且返回所有剩余未执行的任务(不包含正在执行的任务)
func (b *BlockQueueTaskPool) ShutdownNow() ([]Task, error) {
// TODO implement me
panic("implement me")

for {

if atomic.LoadInt32(&b.state) == stateCreated {
return nil, fmt.Errorf("%w", errTaskPoolIsNotRunning)
}

if atomic.LoadInt32(&b.state) == stateClosing {
return nil, fmt.Errorf("%w", errTaskPoolIsClosing)
}

if atomic.LoadInt32(&b.state) == stateStopped {
flycash marked this conversation as resolved.
Show resolved Hide resolved
// 重复调用,返回缓存结果
b.mux.RLock()
tasks := append([]Task(nil), b.submittedTasks...)
b.mux.RUnlock()
return tasks, nil
}
if atomic.CompareAndSwapInt32(&b.state, stateRunning, stateStopped) {
// 目标:立刻关闭并且返回所有剩下未执行的任务
// 策略:关闭等待队列不再接受新任务,中断任务启动循环,清空等待队列并保存返回

close(b.queue)

// 发送中断信号,中断任务启动循环
b.cancelFunc()

b.mux.Lock()
// 清空队列并保存
var tasks []Task
flycash marked this conversation as resolved.
Show resolved Hide resolved
for task := range b.queue {
flycash marked this conversation as resolved.
Show resolved Hide resolved
b.submittedTasks = append(b.submittedTasks, task)
tasks = append(tasks, task)
}
b.mux.Unlock()

return tasks, nil
}
}
}

// internalState 用于查看TaskPool状态
func (b *BlockQueueTaskPool) internalState() int32 {
for {
state := atomic.LoadInt32(&b.state)
if state != stateLocked {
return state
}
}
}

func (b *BlockQueueTaskPool) NumGo() int32 {
return atomic.LoadInt32(&b.num)
}