Skip to content

Commit

Permalink
Make serveMessage return func
Browse files Browse the repository at this point in the history
  • Loading branch information
Omar Qurie committed Oct 26, 2020
1 parent 9d978f3 commit b673cbe
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 51 deletions.
88 changes: 45 additions & 43 deletions transport/awssqs/consumer.go
Original file line number Diff line number Diff line change
Expand Up @@ -129,57 +129,59 @@ func ConsumerDeleteMessageAfter() ConsumerOption {
}

// ServeMessage serves an SQS message.
func (c Consumer) ServeMessage(ctx context.Context, msg *sqs.Message) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

if len(c.finalizer) > 0 {
defer func() {
for _, f := range c.finalizer {
f(ctx, msg)
}
}()
}
func (c Consumer) ServeMessage(ctx context.Context) func(msg *sqs.Message) error {
return func(msg *sqs.Message) error {
ctx, cancel := context.WithCancel(ctx)
defer cancel()

if len(c.finalizer) > 0 {
defer func() {
for _, f := range c.finalizer {
f(ctx, msg)
}
}()
}

for _, f := range c.before {
ctx = f(ctx, cancel, msg)
}
for _, f := range c.before {
ctx = f(ctx, cancel, msg)
}

req, err := c.dec(ctx, msg)
if err != nil {
c.errorHandler.Handle(ctx, err)
c.errorEncoder(ctx, err, msg, c.sqsClient)
return err
}
req, err := c.dec(ctx, msg)
if err != nil {
c.errorHandler.Handle(ctx, err)
c.errorEncoder(ctx, err, msg, c.sqsClient)
return err
}

response, err := c.e(ctx, req)
if err != nil {
c.errorHandler.Handle(ctx, err)
c.errorEncoder(ctx, err, msg, c.sqsClient)
return err
}
response, err := c.e(ctx, req)
if err != nil {
c.errorHandler.Handle(ctx, err)
c.errorEncoder(ctx, err, msg, c.sqsClient)
return err
}

responseMsg := sqs.SendMessageInput{}
for _, f := range c.after {
ctx = f(ctx, cancel, msg, &responseMsg)
}
responseMsg := sqs.SendMessageInput{}
for _, f := range c.after {
ctx = f(ctx, cancel, msg, &responseMsg)
}

if !c.wantRep(ctx, msg) {
return nil
}
if !c.wantRep(ctx, msg) {
return nil
}

if err := c.enc(ctx, &responseMsg, response); err != nil {
c.errorHandler.Handle(ctx, err)
c.errorEncoder(ctx, err, msg, c.sqsClient)
return err
}
if err := c.enc(ctx, &responseMsg, response); err != nil {
c.errorHandler.Handle(ctx, err)
c.errorEncoder(ctx, err, msg, c.sqsClient)
return err
}

if _, err := c.sqsClient.SendMessageWithContext(ctx, &responseMsg); err != nil {
c.errorHandler.Handle(ctx, err)
c.errorEncoder(ctx, err, msg, c.sqsClient)
return err
if _, err := c.sqsClient.SendMessageWithContext(ctx, &responseMsg); err != nil {
c.errorHandler.Handle(ctx, err)
c.errorEncoder(ctx, err, msg, c.sqsClient)
return err
}
return nil
}
return nil
}

// ErrorEncoder is responsible for encoding an error to the consumer's reply.
Expand Down
16 changes: 8 additions & 8 deletions transport/awssqs/consumer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ func TestConsumerDeleteBefore(t *testing.T) {
awssqs.ConsumerDeleteMessageBefore(),
)

consumer.ServeMessage(context.Background(), &sqs.Message{
consumer.ServeMessage(context.Background())(&sqs.Message{
Body: aws.String("MessageBody"),
MessageId: aws.String("fakeMsgID"),
})
Expand Down Expand Up @@ -113,7 +113,7 @@ func TestConsumerBadDecode(t *testing.T) {
awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }),
)

consumer.ServeMessage(context.Background(), &sqs.Message{
consumer.ServeMessage(context.Background())(&sqs.Message{
Body: aws.String("MessageBody"),
MessageId: aws.String("fakeMsgID"),
})
Expand Down Expand Up @@ -162,7 +162,7 @@ func TestConsumerBadEndpoint(t *testing.T) {
awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }),
)

consumer.ServeMessage(context.Background(), &sqs.Message{
consumer.ServeMessage(context.Background())(&sqs.Message{
Body: aws.String("MessageBody"),
MessageId: aws.String("fakeMsgID"),
})
Expand Down Expand Up @@ -211,7 +211,7 @@ func TestConsumerBadEncoder(t *testing.T) {
awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }),
)

consumer.ServeMessage(context.Background(), &sqs.Message{
consumer.ServeMessage(context.Background())(&sqs.Message{
Body: aws.String("MessageBody"),
MessageId: aws.String("fakeMsgID"),
})
Expand Down Expand Up @@ -255,7 +255,7 @@ func TestConsumerSuccess(t *testing.T) {
awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }),
)

consumer.ServeMessage(context.Background(), &sqs.Message{
consumer.ServeMessage(context.Background())(&sqs.Message{
Body: aws.String(string(b)),
MessageId: aws.String("fakeMsgID"),
})
Expand Down Expand Up @@ -303,7 +303,7 @@ func TestConsumerSuccessNoReply(t *testing.T) {
queueURL,
)

consumer.ServeMessage(context.Background(), &sqs.Message{
consumer.ServeMessage(context.Background())(&sqs.Message{
Body: aws.String(string(b)),
MessageId: aws.String("fakeMsgID"),
})
Expand Down Expand Up @@ -366,7 +366,7 @@ func TestConsumerBeforeAddValueToContext(t *testing.T) {
awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }),
)
ctx := context.Background()
err := consumer.ServeMessage(ctx, msg)
err := consumer.ServeMessage(ctx)(msg)
if err != nil {
t.Errorf("got err %s", err)
}
Expand Down Expand Up @@ -445,7 +445,7 @@ func TestConsumerAfter(t *testing.T) {
awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }),
)
ctx := context.Background()
consumer.ServeMessage(ctx, msg)
consumer.ServeMessage(ctx)(msg)

var receiveOutput *sqs.ReceiveMessageOutput
select {
Expand Down

0 comments on commit b673cbe

Please sign in to comment.