Skip to content

Commit

Permalink
code refactoring. improve tests
Browse files Browse the repository at this point in the history
  • Loading branch information
o1egl committed Apr 10, 2017
1 parent 2f9cf65 commit 402a813
Show file tree
Hide file tree
Showing 5 changed files with 73 additions and 69 deletions.
49 changes: 25 additions & 24 deletions README.md
Expand Up @@ -20,34 +20,35 @@ $ go get -u github.com/o1egl/syncx

AWG - advanced version of wait group

Added features:
Added features:

* thread safe
* integrated context.Context
* execution timeout
* ability to return errors
* panic handling from goroutines

```go
wg := awg.AdvancedWaitGroup{}
// Add function one

// Add first task
wg.Add(func() error {
//Logic
return nil
})
// Add function two

// Add second task
wg.Add(func() error {
//Another Logic
return nil
})

wg.Start()

var err error
// Taking one error make sense if you use *.SetStopOnError(true)* option - see below
err = wg.SetStopOnError(true).Start().GetLastError()

// Taking all errors
var errs []error
errs = wg.Start().GetAllErrors()
Expand All @@ -59,28 +60,28 @@ Integrated with context.Context. It gives you ability to set timeouts and regist
```go
// SetTimeout defines timeout for all tasks
SetTimeout(t time.Duration)

// SetContext defines Context
SetContext(t context.Context)
// SetStopOnError stops wiatgroup if any task returns error

// SetStopOnError stops waitgroup if any task returns error
SetStopOnError(b bool)

// Add adds new tasks into waitgroup
Add(funcs ...WaitgroupFunc)

// // Start runs tasks in separate goroutines and waits for their completion
Start()

// Reset performs cleanup task queue and reset state
Reset()

// // LastError returns last error that caught by execution process
LastError()

// AllErrors returns all errors that caught by execution process
AllErrors()

// Status returns result state
Status()
```
Expand All @@ -91,24 +92,24 @@ Integrated with context.Context. It gives you ability to set timeouts and regist
```go
// NewSemaphore returns new Semaphore instance
NewSemaphore(10)

//Acquire acquires one permit, if its not available the goroutine will block till its available or Context.Done() occurs.
//You can pass context.WithTimeout() to support timeoutable acquire.
Acquire(ctx context.Context)

//AcquireMany is similar to Acquire() but for many permits
//Returns successfully acquired permits.
AcquireMany(ctx context.Context, n int) (int, error)

//Release releases one permit
Release()

//ReleaseMany releases many permits
ReleaseMany(n int) error

//AvailablePermits returns number of available unacquired permits
AvailablePermits() int

//DrainPermits acquires all available permits and return the number of permits acquired
DrainPermits() (int, error)
```
Expand All @@ -123,4 +124,4 @@ Integrated with context.Context. It gives you ability to set timeouts and regist
2. Open a [Pull Request](https://github.com/o1egl/syncx/pulls)
3. Enjoy a refreshing Diet Coke and wait

SyncX is released under the MIT license. See [LICENSE](LICENSE)
SyncX is released under the MIT license. See [LICENSE](LICENSE)
15 changes: 8 additions & 7 deletions semaphore.go
Expand Up @@ -5,6 +5,7 @@ import (
"errors"
)

// Semaphore defines semaphore interface
type Semaphore interface {
//Acquire acquires one permit, if its not available the goroutine will block till its available or Context.Done() occurs.
//You can pass context.WithTimeout() to support timeoutable acquire.
Expand All @@ -30,7 +31,7 @@ type Semaphore interface {
// NewSemaphore returns new Semaphore instance
func NewSemaphore(permits int) (*semaphore, error) {
if permits < 1 {
return nil, errors.New("Invalid number of permits. Less than 1")
return nil, errors.New("invalid number of permits. Less than 1")
}
return &semaphore{
channel: make(chan struct{}, permits),
Expand All @@ -46,17 +47,17 @@ func (s *semaphore) Acquire(ctx context.Context) error {
case s.channel <- struct{}{}:
return nil
case <-ctx.Done():
return errors.New("Acquire canceled.")
return errors.New("acquire canceled")
}

}

func (s *semaphore) AcquireMany(ctx context.Context, n int) (int, error) {
if n < 0 {
return 0, errors.New("Acquir count coundn't be negative")
return 0, errors.New("acquir count coundn't be negative")
}
if n > s.totalPermits() {
return 0, errors.New("To many requested permits")
return 0, errors.New("too many requested permits")
}
acquired := 0
for ; n > 0; n-- {
Expand All @@ -65,7 +66,7 @@ func (s *semaphore) AcquireMany(ctx context.Context, n int) (int, error) {
acquired++
continue
case <-ctx.Done():
return acquired, errors.New("Acquire canceled.")
return acquired, errors.New("acquire canceled")
}

}
Expand All @@ -90,10 +91,10 @@ func (s *semaphore) Release() {

func (s *semaphore) ReleaseMany(n int) error {
if n < 0 {
return errors.New("Release count coundn't be negative")
return errors.New("release count coundn't be negative")
}
if n > s.totalPermits() {
return errors.New("Too many requested releases")
return errors.New("too many requested releases")
}
for ; n > 0; n-- {
s.Release()
Expand Down
3 changes: 2 additions & 1 deletion semaphore_test.go
Expand Up @@ -2,10 +2,11 @@ package syncx

import (
"context"
"github.com/stretchr/testify/assert"
"math"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

func TestNewSemaphore(t *testing.T) {
Expand Down
18 changes: 9 additions & 9 deletions wait_group.go
Expand Up @@ -2,13 +2,13 @@ package syncx

import (
"context"
"errors"
"fmt"
"runtime"
"sync"
"time"
)

// Status represents AdvancedWaitGroup status
type Status int

const (
Expand All @@ -27,11 +27,11 @@ const (
// WaitgroupFunc func
type WaitgroupFunc func() error

// AdvancedWaitGroup enhanced wait group struct
// AdvancedWaitGroup enhanced wait group struct. You can use it in different goroutines. It's thread safe.
type AdvancedWaitGroup struct {
sync.RWMutex
context context.Context
stack []WaitgroupFunc
tasks []WaitgroupFunc
stopOnError bool
status Status
errors []error
Expand All @@ -57,7 +57,7 @@ func (wg *AdvancedWaitGroup) SetContext(t context.Context) *AdvancedWaitGroup {
return wg
}

// SetStopOnError stops wiatgroup if any task returns error
// SetStopOnError stops waitgroup if any task returns error
func (wg *AdvancedWaitGroup) SetStopOnError(b bool) *AdvancedWaitGroup {
wg.Lock()
wg.stopOnError = b
Expand All @@ -68,7 +68,7 @@ func (wg *AdvancedWaitGroup) SetStopOnError(b bool) *AdvancedWaitGroup {
// Add adds new tasks into waitgroup
func (wg *AdvancedWaitGroup) Add(funcs ...WaitgroupFunc) *AdvancedWaitGroup {
wg.Lock()
wg.stack = append(wg.stack, funcs...)
wg.tasks = append(wg.tasks, funcs...)
wg.Unlock()
return wg
}
Expand All @@ -79,12 +79,12 @@ func (wg *AdvancedWaitGroup) Start() *AdvancedWaitGroup {
defer wg.Unlock()
wg.status = StatusSuccess

if taskCount := len(wg.stack); taskCount > 0 {
if taskCount := len(wg.tasks); taskCount > 0 {
failed := make(chan error, taskCount)
done := make(chan bool, taskCount)

StarterLoop:
for _, f := range wg.stack {
for _, f := range wg.tasks {
// check if context is canceled
select {
case <-wg.doneChannel():
Expand All @@ -98,7 +98,7 @@ func (wg *AdvancedWaitGroup) Start() *AdvancedWaitGroup {
if r := recover(); r != nil {
buf := make([]byte, 1000)
runtime.Stack(buf, false)
failed <- errors.New(fmt.Sprintf("Panic handeled\n%v\n%s", r, string(buf)))
failed <- fmt.Errorf("Panic handeled\n%v\n%s", r, string(buf))
}
}()

Expand Down Expand Up @@ -146,7 +146,7 @@ func (wg *AdvancedWaitGroup) doneChannel() <-chan struct{} {
// Reset performs cleanup task queue and reset state
func (wg *AdvancedWaitGroup) Reset() {
wg.Lock()
wg.stack = nil
wg.tasks = nil
wg.stopOnError = false
wg.status = StatusIdle
wg.errors = nil
Expand Down
57 changes: 29 additions & 28 deletions wait_group_test.go
Expand Up @@ -3,13 +3,14 @@ package syncx
import (
"context"
"errors"
"github.com/stretchr/testify/assert"
"runtime"
"testing"
"time"

"github.com/stretchr/testify/assert"
)

var testError = errors.New("Test error")
var errTest = errors.New("test error")

func slowFunc() error {
time.Sleep(2 * time.Second)
Expand All @@ -23,28 +24,33 @@ func fastFunc() error {
}

func errorFunc() error {
return testError
return errTest
}

func panicFunc() error {
panic("Test panic")
return nil
panic("test panic")
}

func incFunc(v *int) func() error {
return func() error {
*v++
return nil
}
}

// Test_AdvancedWaitGroup_Success test for success case
func Test_AdvancedWaitGroup_Success(t *testing.T) {
var wg AdvancedWaitGroup
var r1, r2, r3, r4 int
wg.Add(incFunc(&r1))
wg.Add(incFunc(&r2))
wg.Add(incFunc(&r3))
wg.Add(incFunc(&r4))

wg.Add(fastFunc)
wg.Add(fastFunc)
wg.Add(slowFunc)
wg.Add(slowFunc)

start := time.Now()
wg.Start()
diff := time.Now().Sub(start).Nanoseconds()
sum := r1 + r2 + r3 + r4

assert.True(t, diff >= (2*time.Second).Nanoseconds(), "AWG should wait all goroutines")
assert.Equal(t, 4, sum)
assert.Equal(t, StatusSuccess, wg.Status())
assert.NoError(t, wg.LastError())
assert.Len(t, wg.AllErrors(), 0)
Expand All @@ -54,20 +60,18 @@ func Test_AdvancedWaitGroup_Success(t *testing.T) {
func Test_AdvancedWaitGroup_SuccessWithErrors(t *testing.T) {
var wg AdvancedWaitGroup

wg.Add(fastFunc)
wg.Add(fastFunc)
var r1, r2 int
wg.Add(errorFunc)
wg.Add(incFunc(&r1))
wg.Add(incFunc(&r2))
wg.Add(errorFunc)
wg.Add(slowFunc)
wg.Add(slowFunc)

start := time.Now()
wg.Start()
diff := time.Now().Sub(start).Nanoseconds()
sum := r1 + r2

assert.True(t, diff >= (2*time.Second).Nanoseconds(), "AWG should wait all goroutines")
assert.Equal(t, 2, sum)
assert.Equal(t, StatusSuccess, wg.Status())
assert.Error(t, testError, wg.LastError())
assert.Error(t, errTest, wg.LastError())
assert.Len(t, wg.AllErrors(), 2)
}

Expand Down Expand Up @@ -122,11 +126,8 @@ func Test_AdvancedWaitGroup_DontExecCanceled(t *testing.T) {
var wg AdvancedWaitGroup
ctx, cancelFunc := context.WithCancel(context.Background())

i := 0
wg.Add(func() error {
i++
return nil
})
var i int
wg.Add(incFunc(&i))

cancelFunc()

Expand All @@ -150,7 +151,7 @@ func Test_AdvancedWaitGroup_StopOnError(t *testing.T) {
diff := time.Now().Sub(start).Nanoseconds()

assert.True(t, diff < (time.Second).Nanoseconds(), "AWG should be canceled immediately")
assert.Equal(t, testError, wg.LastError())
assert.Equal(t, errTest, wg.LastError())
assert.Equal(t, StatusError, wg.Status(), "AWG status should be StatusError!")
}

Expand All @@ -165,7 +166,7 @@ func Test_AdvancedWaitGroup_Panic(t *testing.T) {
Start()

assert.Equal(t, StatusError, wg.Status())
assert.Contains(t, wg.LastError().Error(), "Test panic")
assert.Contains(t, wg.LastError().Error(), "test panic")
}

// Test_AdvancedWaitGroup_Reset test for reset
Expand Down

0 comments on commit 402a813

Please sign in to comment.