diff --git a/.gitignore b/.gitignore index 66fd13c..2d267a4 100644 --- a/.gitignore +++ b/.gitignore @@ -13,3 +13,6 @@ # Dependency directories (remove the comment below to include it) # vendor/ + +cover-profile.out +cover-coverage.html diff --git a/Makefile b/Makefile index af5547a..1bec379 100644 --- a/Makefile +++ b/Makefile @@ -3,6 +3,11 @@ test: clear go test -timeout 30s -count=1 -cover ./... +cover: + clear + go test -count=1 -timeout 10s -coverprofile=cover-profile.out -covermode=set -coverpkg=./... ./...; \ + go tool cover -html=cover-profile.out -o cover-coverage.html + lint: clear golangci-lint run ./... diff --git a/default-runner_test.go b/default-runner_test.go index 69ab64e..03c9a39 100644 --- a/default-runner_test.go +++ b/default-runner_test.go @@ -4,41 +4,52 @@ import ( "testing" "time" - "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/suite" ) -var _ Runner = new(defaultRunner) +var ( + _defaultRunner *defaultRunner + _ Runner = _defaultRunner +) func Test_defaultRunner(t *testing.T) { - t.Run(`should run the job`, func(t *testing.T) { - sut := newDefaultRunner() - - done := make(chan struct{}) - sut.Run(func() { close(done) }) - - assert.Eventually(t, func() bool { - select { - case <-done: - return true - default: - } - return false - }, time.Millisecond*100, time.Millisecond*10) - }) - - t.Run(`should run the job concurrently`, func(t *testing.T) { - sut := newDefaultRunner() - - done := make(chan struct{}) - sut.Run(func() { done <- struct{}{} }) - - assert.Eventually(t, func() bool { - select { - case <-done: - return true - default: - } - return false - }, time.Millisecond*100, time.Millisecond*10) - }) + suite.Run(t, new(suiteDefaultRunner)) +} + +type suiteDefaultRunner struct { + suite.Suite + + sut *defaultRunner +} + +func (obj *suiteDefaultRunner) SetupTest() { + obj.sut = newDefaultRunner() +} + +func (obj *suiteDefaultRunner) Test_Run_should_run_the_job() { + done := make(chan struct{}) + obj.sut.Run(func() { close(done) }) + + obj.Eventually(func() bool { + select { + case <-done: + return true + default: + } + return false + }, time.Millisecond*100, time.Millisecond*10) +} + +func (obj *suiteDefaultRunner) Test_Run_should_run_the_job_concurrently() { + done := make(chan struct{}) + obj.sut.Run(func() { done <- struct{}{} }) + + obj.Eventually(func() bool { + select { + case <-done: + return true + default: + } + return false + }, time.Millisecond*100, time.Millisecond*10) } diff --git a/subscriber.go b/subscriber.go index d548313..3989422 100644 --- a/subscriber.go +++ b/subscriber.go @@ -98,30 +98,33 @@ func (obj *Subscriber) Shutdown() { obj.cancel() } func (obj *Subscriber) runHandler(ctx context.Context, msg types.Message) { obj.Runner.Run(func() { - ctx = obj.runBefore(ctx, msg) + scopedCtx, cancel := context.WithCancel(ctx) + defer cancel() - req, err := obj.DecodeRequest(ctx, msg) + scopedCtx = obj.runBefore(scopedCtx, msg) + + req, err := obj.DecodeRequest(scopedCtx, msg) if err != nil { err := &DecoderError{ Err: err, Msg: msg, } - obj.notifyError(ctx, err) + obj.notifyError(scopedCtx, err) return } - resp, err := obj.Handler(ctx, req) + resp, err := obj.Handler(scopedCtx, req) if err != nil { err := &HandlerError{ Err: err, Request: req, Msg: msg, } - obj.notifyError(ctx, err) + obj.notifyError(scopedCtx, err) return } - obj.runResponseHandler(ctx, msg, resp) + obj.runResponseHandler(scopedCtx, msg, resp) }) } diff --git a/subscriber_test.go b/subscriber_test.go index 9cd5527..31c52f7 100644 --- a/subscriber_test.go +++ b/subscriber_test.go @@ -553,6 +553,8 @@ func Test_Subscriber_should_call_AfterBatch_after_calling_the_handler_for_receiv } func Test_Subscriber_should_panic_if_any_before_function_returns_a_nil_context(t *testing.T) { + t.Parallel() + sut := &Subscriber{ Before: []RequestFunc{ func(ctx context.Context, msg types.Message) context.Context { @@ -564,6 +566,56 @@ func Test_Subscriber_should_panic_if_any_before_function_returns_a_nil_context(t assert.Panics(t, func() { sut.runBefore(context.Background(), types.Message{}) }) } +func Test_Subscriber_should_panic_if_any_response_handler_function_returns_a_nil_context(t *testing.T) { + t.Parallel() + + sut := &Subscriber{ + ResponseHandler: []ResponseFunc{ + func(ctx context.Context, msg types.Message, response interface{}) context.Context { + return nil + }, + }, + } + + assert.Panics(t, func() { sut.runResponseHandler(context.Background(), types.Message{}, nil) }) +} + +func Test_Subscriber_runHandler_should_create_a_request_scoped_context(t *testing.T) { + t.Parallel() + + gotCtx := make(chan context.Context, 1) + + sut := &Subscriber{ + DecodeRequest: func(c context.Context, m types.Message) (request interface{}, err error) { return m, nil }, + Handler: func(ctx context.Context, request interface{}) (response interface{}, err error) { + return "OK", nil + }, + ResponseHandler: []ResponseFunc{ + func(ctx context.Context, msg types.Message, response interface{}) context.Context { + gotCtx <- ctx + return ctx + }, + }, + InputFactory: defaultInputFactory, + } + _ = sut.init() + + msg := types.Message{ + Body: aws.String("a message"), + } + + sut.runHandler(context.Background(), msg) + + assert.Eventually(t, func() bool { + select { + case ctx := <-gotCtx: + return ctx.Err() == context.Canceled + default: + } + return false + }, time.Millisecond*300, time.Millisecond*20) +} + func Test_Subscriber_init(t *testing.T) { t.Parallel()