Refactor worker pool for better scalability and readability
- Split `executeTask` function into `executeTaskWithTimeout` and `executeTaskWithoutTimeout` for better readability.
- Move worker pool scaling logic into a separate goroutine that runs periodically, improving scalability and making the `dispatch` function simpler.
- Add `retryCount` and `adjustInterval` fields to the `goPool` struct to support task retry and adjustable worker scaling intervals.
- Update tests and README to reflect these changes.

Signed-off-by: Daniel Hu <tao.hu@merico.dev>
daniel-hutao committed Jul 25, 2023
1 parent 5a7dd13 commit 4974182
Showing 5 changed files with 107 additions and 54 deletions.
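For context, here is a minimal usage sketch of the pool after this change. Only `NewGoPool`, `AddTask`, `Release`, `WithRetryCount`, and `WithErrorCallback` appear in this commit; the import path and pool size below are assumptions.

```go
package main

import (
	"errors"
	"fmt"

	"github.com/devchat-ai/gopool" // assumed module path, not shown in this diff
)

func main() {
	// With a retry count of 3, a failing task is attempted up to 4 times
	// before its error reaches the error callback.
	pool := gopool.NewGoPool(100,
		gopool.WithRetryCount(3),
		gopool.WithErrorCallback(func(err error) {
			fmt.Println("task failed:", err)
		}),
	)

	pool.AddTask(func() (interface{}, error) {
		return nil, errors.New("transient failure")
	})

	// Release closes the task queue and now waits for the worker stack to
	// drain to minWorkers rather than maxWorkers (see gopool.go below).
	pool.Release()
}
```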
2 changes: 0 additions & 2 deletions README.md
@@ -22,8 +22,6 @@ GoPool is a high-performance, feature-rich, and easy-to-use worker pool library
 
 - **Task Retry**: GoPool provides a retry mechanism for failed tasks.
 
-- **Task Progress Tracking**: GoPool provides task progress tracking.
-
 - **Concurrency Control**: GoPool can control the number of concurrent tasks to prevent system overload.
 
 - **Lock Customization**: GoPool supports different types of locks. You can use the built-in `sync.Mutex` or a custom lock such as `spinlock.SpinLock`.
55 changes: 45 additions & 10 deletions gopool.go
@@ -16,11 +16,13 @@ type goPool struct {
     minWorkers     int
     workerStack    []int
     taskQueue      chan task
+    retryCount     int
     lock           sync.Locker
     cond           *sync.Cond
     timeout        time.Duration
     resultCallback func(interface{})
     errorCallback  func(error)
+    adjustInterval time.Duration
 }
 
 // NewGoPool creates a new pool of workers.
@@ -31,8 +33,10 @@ func NewGoPool(maxWorkers int, opts ...Option) *goPool {
         workers:        make([]*worker, maxWorkers),
         workerStack:    make([]int, maxWorkers),
         taskQueue:      make(chan task, 1e6),
+        retryCount:     0,
         lock:           new(sync.Mutex),
         timeout:        0,
+        adjustInterval: 1 * time.Second,
     }
     for _, opt := range opts {
         opt(pool)
@@ -46,6 +50,7 @@ func NewGoPool(maxWorkers int, opts ...Option) *goPool {
         pool.workerStack[i] = i
         worker.start(pool, i)
     }
+    go pool.adjustWorkers()
     go pool.dispatch()
     return pool
 }
@@ -59,7 +64,7 @@ func (p *goPool) AddTask(t task) {
 func (p *goPool) Release() {
     close(p.taskQueue)
     p.cond.L.Lock()
-    for len(p.workerStack) != p.maxWorkers {
+    for len(p.workerStack) != p.minWorkers {
         p.cond.Wait()
     }
     p.cond.L.Unlock()
@@ -85,6 +90,31 @@ func (p *goPool) pushWorker(workerIndex int) {
     p.cond.Signal()
 }
 
+func (p *goPool) adjustWorkers() {
+    ticker := time.NewTicker(p.adjustInterval)
+    defer ticker.Stop()
+
+    for range ticker.C {
+        p.cond.L.Lock()
+        if len(p.taskQueue) > (p.maxWorkers-p.minWorkers)/2+p.minWorkers && len(p.workerStack) < p.maxWorkers {
+            // Double the number of workers until it reaches the maximum
+            newWorkers := min(len(p.workerStack)*2, p.maxWorkers) - len(p.workerStack)
+            for i := 0; i < newWorkers; i++ {
+                worker := newWorker()
+                p.workers = append(p.workers, worker)
+                p.workerStack = append(p.workerStack, len(p.workers)-1)
+                worker.start(p, len(p.workers)-1)
+            }
+        } else if len(p.taskQueue) < p.minWorkers && len(p.workerStack) > p.minWorkers {
+            // Halve the number of workers until it reaches the minimum
+            removeWorkers := max((len(p.workerStack)-p.minWorkers)/2, p.minWorkers)
+            p.workers = p.workers[:len(p.workers)-removeWorkers]
+            p.workerStack = p.workerStack[:len(p.workerStack)-removeWorkers]
+        }
+        p.cond.L.Unlock()
+    }
+}
+
 func (p *goPool) dispatch() {
     for t := range p.taskQueue {
         p.cond.L.Lock()
@@ -94,14 +124,19 @@ func (p *goPool) dispatch() {
         p.cond.L.Unlock()
         workerIndex := p.popWorker()
         p.workers[workerIndex].taskQueue <- t
-        if len(p.taskQueue) > (p.maxWorkers-p.minWorkers)/2+p.minWorkers && len(p.workerStack) < p.maxWorkers {
-            worker := newWorker()
-            p.workers = append(p.workers, worker)
-            p.workerStack = append(p.workerStack, len(p.workers)-1)
-            worker.start(p, len(p.workers)-1)
-        } else if len(p.taskQueue) < p.minWorkers && len(p.workerStack) > p.minWorkers {
-            p.workers = p.workers[:len(p.workers)-1]
-            p.workerStack = p.workerStack[:len(p.workerStack)-1]
-        }
     }
 }
+
+func min(a, b int) int {
+    if a < b {
+        return a
+    }
+    return b
+}
+
+func max(a, b int) int {
+    if a > b {
+        return a
+    }
+    return b
+}
4 changes: 2 additions & 2 deletions gopool_test.go
@@ -34,7 +34,7 @@ func TestGoPoolWithSpinLock(t *testing.T) {
 func BenchmarkGoPoolWithMutex(b *testing.B) {
     var wg sync.WaitGroup
     var taskNum = int(1e6)
-    pool := NewGoPool(5e4, WithLock(new(sync.Mutex)))
+    pool := NewGoPool(1e4, WithLock(new(sync.Mutex)))
 
     b.ResetTimer()
     for i := 0; i < b.N; i++ {
@@ -55,7 +55,7 @@ func BenchmarkGoPoolWithMutex(b *testing.B) {
 func BenchmarkGoPoolWithSpinLock(b *testing.B) {
     var wg sync.WaitGroup
     var taskNum = int(1e6)
-    pool := NewGoPool(5e4, WithLock(new(spinlock.SpinLock)))
+    pool := NewGoPool(1e4, WithLock(new(spinlock.SpinLock)))
 
     b.ResetTimer()
     for i := 0; i < b.N; i++ {
7 changes: 7 additions & 0 deletions option.go
@@ -43,3 +43,10 @@ func WithErrorCallback(callback func(error)) Option {
         p.errorCallback = callback
     }
 }
+
+// WithRetryCount sets the retry count for the pool.
+func WithRetryCount(retryCount int) Option {
+    return func(p *goPool) {
+        p.retryCount = retryCount
+    }
+}
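`WithRetryCount` follows the functional-options pattern already used by the other `With...` helpers. A self-contained model of the pattern, with the struct pared down to the one field this example needs:

```go
package main

import "fmt"

// goPool is reduced to a single configuration field for illustration.
type goPool struct {
	retryCount int
}

// Option mutates the pool's configuration before the pool starts.
type Option func(*goPool)

func WithRetryCount(retryCount int) Option {
	return func(p *goPool) {
		p.retryCount = retryCount
	}
}

func NewGoPool(opts ...Option) *goPool {
	pool := &goPool{retryCount: 0} // defaults first
	for _, opt := range opts {
		opt(pool) // each option overrides one default
	}
	return pool
}

func main() {
	fmt.Println(NewGoPool(WithRetryCount(3)).retryCount) // 3
}
```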
93 changes: 53 additions & 40 deletions worker.go
@@ -20,49 +20,62 @@ func (w *worker) start(pool *goPool, workerIndex int) {
     go func() {
         for t := range w.taskQueue {
             if t != nil {
-                var result interface{}
-                var err error
-                if pool.timeout > 0 {
-                    // Create a context with timeout
-                    ctx, cancel := context.WithTimeout(context.Background(), pool.timeout)
-                    defer cancel()
-
-                    // Create a channel to receive the result of the task
-                    done := make(chan struct{})
-
-                    // Run the task in a separate goroutine
-                    go func() {
-                        result, err = t()
-                        close(done)
-                    }()
-
-                    // Wait for the task to finish or for the context to timeout
-                    select {
-                    case <-done:
-                        // The task finished successfully
-                        if err != nil && pool.errorCallback != nil {
-                            pool.errorCallback(err)
-                        } else if pool.resultCallback != nil {
-                            pool.resultCallback(result)
-                        }
-                    case <-ctx.Done():
-                        // The context timed out, the task took too long
-                        if pool.errorCallback != nil {
-                            pool.errorCallback(fmt.Errorf("Task timed out"))
-                        }
-                    }
-                } else {
-                    // If timeout is not set or is zero, just run the task
-                    result, err = t()
-                    if err != nil && pool.errorCallback != nil {
-                        pool.errorCallback(err)
-                    } else if pool.resultCallback != nil {
-                        pool.resultCallback(result)
-                    }
-                }
+                result, err := w.executeTask(t, pool)
+                w.handleResult(result, err, pool)
             }
             pool.pushWorker(workerIndex)
         }
     }()
 }
+
+func (w *worker) executeTask(t task, pool *goPool) (result interface{}, err error) {
+    for i := 0; i <= pool.retryCount; i++ {
+        if pool.timeout > 0 {
+            result, err = w.executeTaskWithTimeout(t, pool)
+        } else {
+            result, err = w.executeTaskWithoutTimeout(t, pool)
+        }
+        if err == nil || i == pool.retryCount {
+            return result, err
+        }
+    }
+    return
+}
+
+func (w *worker) executeTaskWithTimeout(t task, pool *goPool) (result interface{}, err error) {
+    // Create a context with timeout
+    ctx, cancel := context.WithTimeout(context.Background(), pool.timeout)
+    defer cancel()
+
+    // Create a channel to receive the result of the task
+    done := make(chan struct{})
+
+    // Run the task in a separate goroutine
+    go func() {
+        result, err = t()
+        close(done)
+    }()
+
+    // Wait for the task to finish or for the context to timeout
+    select {
+    case <-done:
+        // The task finished successfully
+        return result, err
+    case <-ctx.Done():
+        // The context timed out, the task took too long
+        return nil, fmt.Errorf("Task timed out")
+    }
+}
+
+func (w *worker) executeTaskWithoutTimeout(t task, pool *goPool) (result interface{}, err error) {
+    // If timeout is not set or is zero, just run the task
+    return t()
+}
+
+func (w *worker) handleResult(result interface{}, err error, pool *goPool) {
+    if err != nil && pool.errorCallback != nil {
+        pool.errorCallback(err)
+    } else if pool.resultCallback != nil {
+        pool.resultCallback(result)
+    }
+}
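The loop in `executeTask` runs a task up to `retryCount+1` times, returning early on the first success. A standalone sketch of the same loop shape, with a hypothetical flaky task:

```go
package main

import (
	"errors"
	"fmt"
)

func main() {
	retryCount := 2 // mirrors pool.retryCount: up to 3 attempts in total

	attempts := 0
	task := func() (interface{}, error) {
		attempts++
		if attempts < 3 {
			return nil, errors.New("transient failure")
		}
		return "ok", nil
	}

	var result interface{}
	var err error
	for i := 0; i <= retryCount; i++ {
		result, err = task()
		if err == nil || i == retryCount {
			break
		}
	}
	fmt.Println(result, err, attempts) // ok <nil> 3
}
```

Note that `executeTaskWithTimeout` does not cancel the task itself on timeout: the goroutine running `t()` keeps running in the background, which is the usual trade-off of this select-on-done pattern.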
