From ba659e3efd650d45ac0abf7989f6d4f71a154c18 Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Wed, 7 Oct 2020 11:48:39 +0200 Subject: [PATCH 01/18] Add AWS' SQS support for transport --- transport/awssqs/consumer.go | 242 ++++++++++ transport/awssqs/consumer_test.go | 540 ++++++++++++++++++++++ transport/awssqs/encode_decode.go | 23 + transport/awssqs/publisher.go | 128 +++++ transport/awssqs/publisher_test.go | 373 +++++++++++++++ transport/awssqs/request_response_func.go | 39 ++ 6 files changed, 1345 insertions(+) create mode 100644 transport/awssqs/consumer.go create mode 100644 transport/awssqs/consumer_test.go create mode 100644 transport/awssqs/encode_decode.go create mode 100644 transport/awssqs/publisher.go create mode 100644 transport/awssqs/publisher_test.go create mode 100644 transport/awssqs/request_response_func.go diff --git a/transport/awssqs/consumer.go b/transport/awssqs/consumer.go new file mode 100644 index 000000000..37f40a71a --- /dev/null +++ b/transport/awssqs/consumer.go @@ -0,0 +1,242 @@ +package awssqs + +import ( + "context" + "encoding/json" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/go-kit/kit/endpoint" + "github.com/go-kit/kit/log" + "github.com/go-kit/kit/transport" +) + +// Consumer wraps an endpoint and provides and provides a handler for sqs msgs +type Consumer struct { + sqsClient SQSClient + e endpoint.Endpoint + dec DecodeRequestFunc + enc EncodeResponseFunc + wantRep WantReplyFunc + queueURL *string + dlQueueURL *string + visibilityTimeout int64 + visibilityTimeoutFunc VisibilityTimeoutFunc + before []ConsumerRequestFunc + after []ConsumerResponseFunc + errorEncoder ErrorEncoder + finalizer []ConsumerFinalizerFunc + errorHandler transport.ErrorHandler +} + +// NewConsumer constructs a new Consumer, which provides a Consume method +// and message handlers that wrap the provided endpoint. +func NewConsumer( + sqsClient SQSClient, + e endpoint.Endpoint, + dec DecodeRequestFunc, + enc EncodeResponseFunc, + wantRep WantReplyFunc, + queueURL *string, + dlQueueURL *string, + visibilityTimeout int64, + options ...ConsumerOption, +) *Consumer { + s := &Consumer{ + sqsClient: sqsClient, + e: e, + dec: dec, + enc: enc, + wantRep: wantRep, + queueURL: queueURL, + dlQueueURL: dlQueueURL, + visibilityTimeout: visibilityTimeout, + visibilityTimeoutFunc: DoNotExtendVisibilityTimeout, + errorEncoder: DefaultErrorEncoder, + errorHandler: transport.NewLogErrorHandler(log.NewNopLogger()), + } + for _, option := range options { + option(s) + } + return s +} + +// ConsumerOption sets an optional parameter for consumers. +type ConsumerOption func(*Consumer) + +// ConsumerBefore functions are executed on the publisher request object before the +// request is decoded. +func ConsumerBefore(before ...ConsumerRequestFunc) ConsumerOption { + return func(c *Consumer) { c.before = append(c.before, before...) } +} + +// ConsumerAfter functions are executed on the consumer reply after the +// endpoint is invoked, but before anything is published to the reply. +func ConsumerAfter(after ...ConsumerResponseFunc) ConsumerOption { + return func(c *Consumer) { c.after = append(c.after, after...) } +} + +// ConsumerErrorEncoder is used to encode errors to the consumer reply +// whenever they're encountered in the processing of a request. Clients can +// use this to provide custom error formatting. By default, +// errors will be published with the DefaultErrorEncoder. +func ConsumerErrorEncoder(ee ErrorEncoder) ConsumerOption { + return func(c *Consumer) { c.errorEncoder = ee } +} + +// ConsumerVisbilityTimeOutFunc is used to extend the visibility timeout +// for messages during when processing them. Clients can +// use this to provide custom visibility timeout extension. By default, +// visibility timeout are not extend. +func ConsumerVisbilityTimeOutFunc(vtFunc VisibilityTimeoutFunc) ConsumerOption { + return func(c *Consumer) { c.visibilityTimeoutFunc = vtFunc } +} + +// ConsumerErrorLogger is used to log non-terminal errors. By default, no errors +// are logged. This is intended as a diagnostic measure. Finer-grained control +// of error handling, including logging in more detail, should be performed in a +// custom ConsumerErrorEncoder which has access to the context. +// Deprecated: Use ConsumerErrorHandler instead. +func ConsumerErrorLogger(logger log.Logger) ConsumerOption { + return func(c *Consumer) { c.errorHandler = transport.NewLogErrorHandler(logger) } +} + +// ConsumerErrorHandler is used to handle non-terminal errors. By default, non-terminal errors +// are ignored. This is intended as a diagnostic measure. Finer-grained control +// of error handling, including logging in more detail, should be performed in a +// custom ConsumerErrorEncoder which has access to the context. +func ConsumerErrorHandler(errorHandler transport.ErrorHandler) ConsumerOption { + return func(c *Consumer) { c.errorHandler = errorHandler } +} + +// ConsumerFinalizer is executed at the end of every request from a publisher through SQS. +// By default, no finalizer is registered. +func ConsumerFinalizer(f ...ConsumerFinalizerFunc) ConsumerOption { + return func(c *Consumer) { c.finalizer = f } +} + +// Consume calls ReceiveMessageWithContext and handles messages +// having receiveMsgInput as param allows each user to have his own receive config +func (c Consumer) Consume(ctx context.Context, receiveMsgInput *sqs.ReceiveMessageInput) error { + receiveMsgInput.QueueUrl = c.queueURL + out, err := c.sqsClient.ReceiveMessageWithContext(ctx, receiveMsgInput) + if err != nil { + return err + } + return c.HandleMessages(ctx, out.Messages) +} + +// HandleMessages handles the consumed messages +func (c Consumer) HandleMessages(ctx context.Context, msgs []*sqs.Message) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + // Copy msgs slice in leftMsgs + leftMsgs := []*sqs.Message{} + leftMsgs = append(leftMsgs, msgs...) + + // this func allows us to extend visibility timeout to give use + // time to process the messages in leftMsgs + go c.visibilityTimeoutFunc(ctx, c.sqsClient, c.queueURL, c.visibilityTimeout, &leftMsgs) + + if len(c.finalizer) > 0 { + defer func() { + for _, f := range c.finalizer { + f(ctx, &msgs) + } + }() + } + + for _, f := range c.before { + ctx = f(ctx, &msgs) + } + + for _, msg := range msgs { + if err := c.HandleSingleMessage(ctx, msg, &leftMsgs); err != nil { + return err + } + } + return nil +} + +// HandleSingleMessage handles a single sqs message +func (c Consumer) HandleSingleMessage(ctx context.Context, msg *sqs.Message, leftMsgs *[]*sqs.Message) error { + req, err := c.dec(ctx, msg) + if err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + return err + } + + response, err := c.e(ctx, req) + if err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + return err + } + + responseMsg := sqs.SendMessageInput{} + for _, f := range c.after { + ctx = f(ctx, msg, &responseMsg, leftMsgs) + } + + if !c.wantRep(ctx, msg) { + // Message does not expect answer + return nil + } + + if err := c.enc(ctx, &responseMsg, response); err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + return err + } + + if _, err := c.sqsClient.SendMessageWithContext(ctx, &responseMsg); err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + return err + } + return nil +} + +// ErrorEncoder is responsible for encoding an error to the consumer reply. +// Users are encouraged to use custom ErrorEncoders to encode errors to +// their replies, and will likely want to pass and check for their own error +// types. +type ErrorEncoder func(ctx context.Context, err error, req *sqs.Message, sqsClient SQSClient) + +// ConsumerFinalizerFunc can be used to perform work at the end of a request +// from a publisher, after the response has been written to the publisher. The +// principal intended use is for request logging. +// Can also be used to delete messages once fully proccessed +type ConsumerFinalizerFunc func(ctx context.Context, msg *[]*sqs.Message) + +// DefaultErrorEncoder simply ignores the message. It does not reply +// nor Ack/Nack the message. +func DefaultErrorEncoder(context.Context, error, *sqs.Message, SQSClient) { +} + +// DoNotExtendVisibilityTimeout is the default value for visibilityTimeoutFunc +// It returns no error and does nothing +func DoNotExtendVisibilityTimeout(context.Context, SQSClient, *string, int64, *[]*sqs.Message) error { + return nil +} + +// EncodeJSONResponse marshals response as json and loads it into input MessageBody +func EncodeJSONResponse(_ context.Context, input *sqs.SendMessageInput, response interface{}) error { + payload, err := json.Marshal(response) + if err != nil { + return err + } + input.MessageBody = aws.String(string(payload)) + return nil +} + +// SQSClient is an interface to make testing possible. +// It is highly recommended to use *sqs.SQS as the interface implementation. +type SQSClient interface { + SendMessageWithContext(ctx context.Context, input *sqs.SendMessageInput, opts ...request.Option) (*sqs.SendMessageOutput, error) + ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, opts ...request.Option) (*sqs.ReceiveMessageOutput, error) + ChangeMessageVisibilityWithContext(ctx aws.Context, input *sqs.ChangeMessageVisibilityInput, opts ...request.Option) (*sqs.ChangeMessageVisibilityOutput, error) +} diff --git a/transport/awssqs/consumer_test.go b/transport/awssqs/consumer_test.go new file mode 100644 index 000000000..34d0ed1ac --- /dev/null +++ b/transport/awssqs/consumer_test.go @@ -0,0 +1,540 @@ +package awssqs_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/go-kit/kit/transport/awssqs" + "github.com/pborman/uuid" +) + +var ( + errTypeAssertion = errors.New("type assertion error") +) + +func (mock *mockSQSClient) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, opts ...request.Option) (*sqs.ReceiveMessageOutput, error) { + // Add logic to allow context errors + for { + select { + case d := <-mock.receiveOuputChan: + return d, mock.err + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +// TestConsumerBadDecode checks if decoder errors are handled properly. +func TestConsumerBadDecode(t *testing.T) { + queueURL := "someURL" + dlQueueURL := "somedlURL" + mock := &mockSQSClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + } + go func() { + mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ + Messages: []*sqs.Message{ + { + Body: aws.String("MessageBody"), + MessageId: aws.String("fakeMsgID"), + }, + }, + } + }() + errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient awssqs.SQSClient) { + publishError := sqsError{ + Err: err.Error(), + MsgID: *req.MessageId, + } + payload, _ := json.Marshal(publishError) + + sqsClient.SendMessageWithContext(ctx, &sqs.SendMessageInput{ + MessageBody: aws.String(string(payload)), + }) + }) + consumer := awssqs.NewConsumer(mock, + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *sqs.Message) (interface{}, error) { return nil, errors.New("err!") }, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, + func(context.Context, *sqs.Message) bool { return true }, + &queueURL, &dlQueueURL, int64(5), + errEncoder, + ) + + consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeConsumerError(receiveOutput) + if err != nil { + t.Fatal(err) + } + if want, have := "err!", res.Err; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestConsumerBadEndpoint checks if endpoint errors are handled properly. +func TestConsumerBadEndpoint(t *testing.T) { + queueURL := "someURL" + dlQueueURL := "somedlURL" + mock := &mockSQSClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + } + go func() { + mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ + Messages: []*sqs.Message{ + { + Body: aws.String("MessageBody"), + MessageId: aws.String("fakeMsgID"), + }, + }, + } + }() + errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient awssqs.SQSClient) { + publishError := sqsError{ + Err: err.Error(), + MsgID: *req.MessageId, + } + payload, _ := json.Marshal(publishError) + + sqsClient.SendMessageWithContext(ctx, &sqs.SendMessageInput{ + MessageBody: aws.String(string(payload)), + }) + }) + consumer := awssqs.NewConsumer(mock, + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("err!") }, + func(context.Context, *sqs.Message) (interface{}, error) { return nil, nil }, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, + func(context.Context, *sqs.Message) bool { return true }, + &queueURL, &dlQueueURL, int64(5), + errEncoder, + ) + + consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeConsumerError(receiveOutput) + if err != nil { + t.Fatal(err) + } + if want, have := "err!", res.Err; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestConsumerBadEncoder checks if encoder errors are handled properly. +func TestConsumerBadEncoder(t *testing.T) { + queueURL := "someURL" + dlQueueURL := "somedlURL" + mock := &mockSQSClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + } + go func() { + mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ + Messages: []*sqs.Message{ + { + Body: aws.String("MessageBody"), + MessageId: aws.String("fakeMsgID"), + }, + }, + } + }() + errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient awssqs.SQSClient) { + publishError := sqsError{ + Err: err.Error(), + MsgID: *req.MessageId, + } + payload, _ := json.Marshal(publishError) + + sqsClient.SendMessageWithContext(ctx, &sqs.SendMessageInput{ + MessageBody: aws.String(string(payload)), + }) + }) + consumer := awssqs.NewConsumer(mock, + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *sqs.Message) (interface{}, error) { return nil, nil }, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return errors.New("err!") }, + func(context.Context, *sqs.Message) bool { return true }, + &queueURL, &dlQueueURL, int64(5), + errEncoder, + ) + + consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeConsumerError(receiveOutput) + if err != nil { + t.Fatal(err) + } + if want, have := "err!", res.Err; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestConsumerSuccess checks if consumer responds correctly to message. +func TestConsumerSuccess(t *testing.T) { + obj := testReq{ + Squadron: 436, + } + b, err := json.Marshal(obj) + if err != nil { + t.Fatal(err) + } + queueURL := "someURL" + dlQueueURL := "somedlURL" + mock := &mockSQSClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + } + go func() { + mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ + Messages: []*sqs.Message{ + { + Body: aws.String(string(b)), + MessageId: aws.String("fakeMsgID"), + }, + }, + } + }() + consumer := awssqs.NewConsumer(mock, + testEndpoint, + testReqDecoderfunc, + awssqs.EncodeJSONResponse, + // wants response + func(context.Context, *sqs.Message) bool { return true }, + // queue, dlqueue, vibilityTimeout + &queueURL, &dlQueueURL, int64(5), + ) + + consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeResponse(receiveOutput) + if err != nil { + t.Fatal(err) + } + want := testRes{ + Squadron: 436, + Name: "tusker", + } + if have := res; want != have { + t.Errorf("want %v, have %v", want, have) + } +} + +// TestConsumerSuccessNoReply checks if consumer processes correctly message +// without sending response +func TestConsumerSuccessNoReply(t *testing.T) { + obj := testReq{ + Squadron: 436, + } + b, err := json.Marshal(obj) + if err != nil { + t.Fatal(err) + } + queueURL := "someURL" + dlQueueURL := "somedlURL" + mock := &mockSQSClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + } + go func() { + mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ + Messages: []*sqs.Message{ + { + Body: aws.String(string(b)), + MessageId: aws.String("fakeMsgID"), + }, + }, + } + }() + consumer := awssqs.NewConsumer(mock, + testEndpoint, + testReqDecoderfunc, + awssqs.EncodeJSONResponse, + // wants response + func(context.Context, *sqs.Message) bool { return false }, + // queue, dlqueue, vibilityTimeout + &queueURL, &dlQueueURL, int64(5), + ) + + consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + t.Errorf("received output when none was expected, have %v", receiveOutput) + return + + case <-time.After(200 * time.Millisecond): + // As expected, we did not receive any response from consumer + return + } +} + +// TestConsumerBeforeFilterMessages checks if consumer before is called as expected. +// Here before is used to filter messages before processing +func TestConsumerBeforeFilterMessages(t *testing.T) { + obj1 := testReq{ + Squadron: 436, + } + b1, _ := json.Marshal(obj1) + obj2 := testReq{ + Squadron: 4, + } + b2, _ := json.Marshal(obj2) + obj3 := testReq{ + Squadron: 1, + } + b3, _ := json.Marshal(obj3) + queueURL := "someURL" + dlQueueURL := "somedlURL" + mock := &mockSQSClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + } + expectedMsgs := []*sqs.Message{ + { + Body: aws.String(string(b1)), + MessageId: aws.String("fakeMsgID1"), + MessageAttributes: map[string]*sqs.MessageAttributeValue{ + "recipient": { + DataType: aws.String("String"), + StringValue: aws.String("me"), + }, + }, + }, + { + Body: aws.String(string(b2)), + MessageId: aws.String("fakeMsgID2"), + MessageAttributes: map[string]*sqs.MessageAttributeValue{ + "recipient": { + DataType: aws.String("String"), + StringValue: aws.String("not me"), + }, + }, + }, + { + Body: aws.String(string(b3)), + MessageId: aws.String("fakeMsgID3"), + }, + } + go func() { + mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ + Messages: expectedMsgs, + } + }() + type ctxKey struct { + key string + } + consumer := awssqs.NewConsumer(mock, + testEndpoint, + testReqDecoderfunc, + awssqs.EncodeJSONResponse, + // wants response + func(context.Context, *sqs.Message) bool { return true }, + // queue, dlqueue, vibilityTimeout + &queueURL, &dlQueueURL, int64(5), + awssqs.ConsumerBefore(func(ctx context.Context, msgs *[]*sqs.Message) context.Context { + // delete a message that is not destined to the consumer + msgsCopy := *msgs + for index, msg := range *msgs { + if recipient, exists := msg.MessageAttributes["recipient"]; !exists || *recipient.StringValue != "me" { + msgsCopy = append(msgsCopy[:index], msgsCopy[index:]...) + } + } + *msgs = msgsCopy + return ctx + }), + ) + ctx := context.Background() + consumer.Consume(ctx, &sqs.ReceiveMessageInput{}) + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeResponse(receiveOutput) + if err != nil { + t.Fatal(err) + } + want := testRes{ + Squadron: 436, + Name: "tusker", + } + if have := res; want != have { + t.Errorf("want %v, have %v", want, have) + } + // Try fetching responses again + select { + case receiveOutput = <-mock.receiveOuputChan: + t.Errorf("received second output when only one was expected, have %v", receiveOutput) + return + + case <-time.After(200 * time.Millisecond): + // As expected, we did not receive a second response from consumer + return + } +} + +// TestConsumerAfter checks if consumer after is called as expected. +// Here after is used to transfer some info from received message in response +func TestConsumerAfter(t *testing.T) { + obj1 := testReq{ + Squadron: 436, + } + b1, _ := json.Marshal(obj1) + queueURL := "someURL" + dlQueueURL := "somedlURL" + mock := &mockSQSClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + } + correlationID := uuid.NewRandom().String() + expectedMsgs := []*sqs.Message{ + { + Body: aws.String(string(b1)), + MessageId: aws.String("fakeMsgID1"), + MessageAttributes: map[string]*sqs.MessageAttributeValue{ + "correlationID": { + DataType: aws.String("String"), + StringValue: &correlationID, + }, + }, + }, + } + go func() { + mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ + Messages: expectedMsgs, + } + }() + type ctxKey struct { + key string + } + consumer := awssqs.NewConsumer(mock, + testEndpoint, + testReqDecoderfunc, + awssqs.EncodeJSONResponse, + // wants response + func(context.Context, *sqs.Message) bool { return true }, + // queue, dlqueue, vibilityTimeout + &queueURL, &dlQueueURL, int64(5), + awssqs.ConsumerAfter(func(ctx context.Context, msg *sqs.Message, resp *sqs.SendMessageInput, leftMsgs *[]*sqs.Message) context.Context { + if correlationIDAttribute, exists := msg.MessageAttributes["correlationID"]; exists { + if resp.MessageAttributes == nil { + resp.MessageAttributes = make(map[string]*sqs.MessageAttributeValue) + } + resp.MessageAttributes["correlationID"] = &sqs.MessageAttributeValue{ + DataType: aws.String("String"), + StringValue: correlationIDAttribute.StringValue, + } + } + return ctx + }), + ) + ctx := context.Background() + consumer.Consume(ctx, &sqs.ReceiveMessageInput{}) + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + if len(receiveOutput.Messages) != 1 { + t.Errorf("received %d messages instead of 1", len(receiveOutput.Messages)) + } + if correlationIDAttribute, exists := receiveOutput.Messages[0].MessageAttributes["correlationID"]; exists { + if have := correlationIDAttribute.StringValue; *have != correlationID { + t.Errorf("have %s, want %s", *have, correlationID) + } + } else { + t.Errorf("expected message attribute with key correlationID in response, but it was not found") + } +} + +type sqsError struct { + Err string `json:"err"` + MsgID string `json:"msgID"` +} + +func decodeConsumerError(receiveOutput *sqs.ReceiveMessageOutput) (sqsError, error) { + receivedError := sqsError{} + err := json.Unmarshal([]byte(*receiveOutput.Messages[0].Body), &receivedError) + return receivedError, err +} + +func testEndpoint(ctx context.Context, request interface{}) (interface{}, error) { + req, ok := request.(testReq) + if !ok { + return nil, errTypeAssertion + } + name, prs := names[req.Squadron] + if !prs { + return nil, errors.New("unknown squadron name") + } + res := testRes{ + Squadron: req.Squadron, + Name: name, + } + return res, nil +} + +func testReqDecoderfunc(_ context.Context, msg *sqs.Message) (interface{}, error) { + var obj testReq + err := json.Unmarshal([]byte(*msg.Body), &obj) + return obj, err +} + +func decodeResponse(receiveOutput *sqs.ReceiveMessageOutput) (interface{}, error) { + if len(receiveOutput.Messages) != 1 { + return nil, fmt.Errorf("Error : received %d messages instead of 1", len(receiveOutput.Messages)) + } + resp := testRes{} + err := json.Unmarshal([]byte(*receiveOutput.Messages[0].Body), &resp) + return resp, err +} diff --git a/transport/awssqs/encode_decode.go b/transport/awssqs/encode_decode.go new file mode 100644 index 000000000..28c284602 --- /dev/null +++ b/transport/awssqs/encode_decode.go @@ -0,0 +1,23 @@ +package awssqs + +import ( + "context" + + "github.com/aws/aws-sdk-go/service/sqs" +) + +// DecodeRequestFunc extracts a user-domain request object from +// an sqs message object. It is designed to be used in sqs Subscribers. +type DecodeRequestFunc func(context.Context, *sqs.Message) (request interface{}, err error) + +// EncodeRequestFunc encodes the passed payload object into +// an sqs message object. It is designed to be used in sqs Publishers. +type EncodeRequestFunc func(context.Context, *sqs.SendMessageInput, interface{}) error + +// EncodeResponseFunc encodes the passed response object to +// an sqs message object. It is designed to be used in sqs Subscribers. +type EncodeResponseFunc func(context.Context, *sqs.SendMessageInput, interface{}) error + +// DecodeResponseFunc extracts a user-domain response object from +// an sqs message object. It is designed to be used in sqs Publishers. +type DecodeResponseFunc func(context.Context, *sqs.Message) (response interface{}, err error) diff --git a/transport/awssqs/publisher.go b/transport/awssqs/publisher.go new file mode 100644 index 000000000..dfc6384b1 --- /dev/null +++ b/transport/awssqs/publisher.go @@ -0,0 +1,128 @@ +package awssqs + +import ( + "context" + "encoding/json" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/go-kit/kit/endpoint" +) + +// Publisher wraps an sqs client and queue, and provides a method that +// implements endpoint.Endpoint. +type Publisher struct { + sqsClient SQSClient + queueURL *string + responseQueueURL *string + enc EncodeRequestFunc + dec DecodeResponseFunc + before []PublisherRequestFunc + after []PublisherResponseFunc + timeout time.Duration +} + +// NewPublisher constructs a usable Publisher for a single remote method. +func NewPublisher( + sqsClient SQSClient, + queueURL *string, + responseQueueURL *string, + enc EncodeRequestFunc, + dec DecodeResponseFunc, + options ...PublisherOption, +) *Publisher { + p := &Publisher{ + sqsClient: sqsClient, + queueURL: queueURL, + responseQueueURL: responseQueueURL, + enc: enc, + dec: dec, + timeout: 20 * time.Second, + } + for _, option := range options { + option(p) + } + return p +} + +// PublisherOption sets an optional parameter for clients. +type PublisherOption func(*Publisher) + +// PublisherBefore sets the RequestFuncs that are applied to the outgoing sqs +// request before it's invoked. +func PublisherBefore(before ...PublisherRequestFunc) PublisherOption { + return func(p *Publisher) { p.before = append(p.before, before...) } +} + +// PublisherAfter sets the ClientResponseFuncs applied to the incoming sqs +// request prior to it being decoded. This is useful for obtaining anything off +// of the response and adding onto the context prior to decoding. +func PublisherAfter(after ...PublisherResponseFunc) PublisherOption { + return func(p *Publisher) { p.after = append(p.after, after...) } +} + +// PublisherTimeout sets the available timeout for an sqs request. +func PublisherTimeout(timeout time.Duration) PublisherOption { + return func(p *Publisher) { p.timeout = timeout } +} + +// Endpoint returns a usable endpoint that invokes the remote endpoint. +func (p Publisher) Endpoint() endpoint.Endpoint { + return func(ctx context.Context, request interface{}) (interface{}, error) { + ctx, cancel := context.WithTimeout(ctx, p.timeout) + defer cancel() + msgInput := sqs.SendMessageInput{ + QueueUrl: p.queueURL, + } + if err := p.enc(ctx, &msgInput, request); err != nil { + return nil, err + } + + for _, f := range p.before { + // Affect only msgInput + ctx = f(ctx, &msgInput) + } + + output, err := p.sqsClient.SendMessageWithContext(ctx, &msgInput) + if err != nil { + return nil, err + } + + var responseMsg *sqs.Message + for _, f := range p.after { + ctx, responseMsg, err = f(ctx, output) + if err != nil { + return nil, err + } + } + + response, err := p.dec(ctx, responseMsg) + if err != nil { + return nil, err + } + + return response, nil + } +} + +// EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a +// JSON object to the MessageBody of the Msg. +// Depending on your needs (if you don't need to add MessageAttributes or GroupID), +// this can be enough. +func EncodeJSONRequest(_ context.Context, msg *sqs.SendMessageInput, request interface{}) error { + b, err := json.Marshal(request) + if err != nil { + return err + } + + msg.MessageBody = aws.String(string(b)) + + return nil +} + +// NoResponseDecode is a DecodeResponseFunc that can be used when no response is needed. +// It returns nil value and nil error +func NoResponseDecode(_ context.Context, _ *sqs.Message) (interface{}, error) { + return nil, nil +} diff --git a/transport/awssqs/publisher_test.go b/transport/awssqs/publisher_test.go new file mode 100644 index 000000000..5b4c8a271 --- /dev/null +++ b/transport/awssqs/publisher_test.go @@ -0,0 +1,373 @@ +package awssqs_test + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "testing" + "time" + + "github.com/aws/aws-sdk-go/aws" + "github.com/aws/aws-sdk-go/aws/request" + "github.com/aws/aws-sdk-go/service/sqs" + "github.com/go-kit/kit/transport/awssqs" +) + +type testReq struct { + Squadron int `json:"s"` +} + +type testRes struct { + Squadron int `json:"s"` + Name string `json:"n"` +} + +var names = map[int]string{ + 424: "tiger", + 426: "thunderbird", + 429: "bison", + 436: "tusker", + 437: "husky", +} + +// mockSQSClient is a mock of *sqs.SQS. +type mockSQSClient struct { + err error + sendOutputChan chan *sqs.SendMessageOutput + receiveOuputChan chan *sqs.ReceiveMessageOutput + sendMsgID string +} + +func (mock *mockSQSClient) SendMessageWithContext(ctx context.Context, input *sqs.SendMessageInput, opts ...request.Option) (*sqs.SendMessageOutput, error) { + if input != nil && input.MessageBody != nil && *input.MessageBody != "" { + go func() { + mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ + Messages: []*sqs.Message{ + { + MessageAttributes: input.MessageAttributes, + Body: input.MessageBody, + MessageId: aws.String(mock.sendMsgID), + }, + }, + } + }() + return &sqs.SendMessageOutput{MessageId: aws.String(mock.sendMsgID)}, nil + } + // Add logic to allow context errors + for { + select { + case d := <-mock.sendOutputChan: + return d, mock.err + case <-ctx.Done(): + return nil, ctx.Err() + } + } +} + +func (mock *mockSQSClient) ChangeMessageVisibilityWithContext(ctx aws.Context, input *sqs.ChangeMessageVisibilityInput, opts ...request.Option) (*sqs.ChangeMessageVisibilityOutput, error) { + return nil, nil +} + +// TestBadEncode tests if encode errors are handled properly. +func TestBadEncode(t *testing.T) { + queueURL := "someURL" + responseQueueURL := "someOtherURL" + mock := &mockSQSClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + } + pub := awssqs.NewPublisher( + mock, + &queueURL, + &responseQueueURL, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return errors.New("err!") }, + func(context.Context, *sqs.Message) (response interface{}, err error) { return struct{}{}, nil }, + ) + errChan := make(chan error, 1) + var err error + go func() { + _, err := pub.Endpoint()(context.Background(), struct{}{}) + errChan <- err + + }() + select { + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for result") + } + if err == nil { + t.Error("expected error") + } + if want, have := "err!", err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestBadDecode tests if decode errors are handled properly. +func TestBadDecode(t *testing.T) { + mock := &mockSQSClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + } + go func() { + mock.sendOutputChan <- &sqs.SendMessageOutput{ + MessageId: aws.String("someMsgID"), + } + }() + + queueURL := "someURL" + responseQueueURL := "someOtherURL" + pub := awssqs.NewPublisher( + mock, + &queueURL, + &responseQueueURL, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, + func(context.Context, *sqs.Message) (response interface{}, err error) { + return struct{}{}, errors.New("err!") + }, + awssqs.PublisherAfter(func(ctx context.Context, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { + // Set the actual response for the request + return ctx, &sqs.Message{Body: aws.String("someMsgContent")}, nil + }), + ) + + var err error + errChan := make(chan error, 1) + go func() { + _, err := pub.Endpoint()(context.Background(), struct{}{}) + errChan <- err + }() + + select { + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("Timed out waiting for result") + } + + if err == nil { + t.Error("expected error") + } + if want, have := "err!", err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +// TestPublisherTimeout ensures that the publisher timeout mechanism works. +func TestPublisherTimeout(t *testing.T) { + sendOutputChan := make(chan *sqs.SendMessageOutput) + mock := &mockSQSClient{ + sendOutputChan: sendOutputChan, + } + queueURL := "someURL" + responseQueueURL := "someOtherURL" + pub := awssqs.NewPublisher( + mock, + &queueURL, + &responseQueueURL, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, + func(context.Context, *sqs.Message) (response interface{}, err error) { + return struct{}{}, nil + }, + awssqs.PublisherTimeout(50*time.Millisecond), + ) + + var err error + errChan := make(chan error, 1) + go func() { + _, err := pub.Endpoint()(context.Background(), struct{}{}) + errChan <- err + + }() + + select { + case err = <-errChan: + break + + case <-time.After(1000 * time.Millisecond): + t.Fatal("timed out waiting for result") + } + + if err == nil { + t.Error("expected error") + return + } + if want, have := context.DeadlineExceeded.Error(), err.Error(); want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +func TestSuccessfulPublisher(t *testing.T) { + mockReq := testReq{437} + mockRes := testRes{ + Squadron: mockReq.Squadron, + Name: names[mockReq.Squadron], + } + b, err := json.Marshal(mockRes) + if err != nil { + t.Fatal(err) + } + mock := &mockSQSClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + sendMsgID: "someMsgID", + } + go func() { + mock.sendOutputChan <- &sqs.SendMessageOutput{ + MessageId: aws.String("someMsgID"), + } + }() + + queueURL := "someURL" + responseQueueURL := "someOtherURL" + pub := awssqs.NewPublisher( + mock, + &queueURL, + &responseQueueURL, + awssqs.EncodeJSONRequest, + func(_ context.Context, msg *sqs.Message) (interface{}, error) { + response := testRes{} + err := json.Unmarshal([]byte(*msg.Body), &response) + return response, err + }, + awssqs.PublisherAfter(func(ctx context.Context, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { + // Sets the actual response for the request + if *msg.MessageId == "someMsgID" { + // Here should contain logic to consume msgs and check if response was provided + return ctx, &sqs.Message{Body: aws.String(string(b))}, nil + } + return nil, nil, fmt.Errorf("Did not receive expected SendMessageOutput") + }), + ) + var res testRes + var ok bool + resChan := make(chan interface{}, 1) + errChan := make(chan error, 1) + go func() { + res, err := pub.Endpoint()(context.Background(), mockReq) + if err != nil { + errChan <- err + } else { + resChan <- res + } + }() + + select { + case response := <-resChan: + res, ok = response.(testRes) + if !ok { + t.Error("failed to assert endpoint response type") + } + break + + case err = <-errChan: + break + + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for result") + } + + if err != nil { + t.Fatal(err) + } + if want, have := mockRes.Name, res.Name; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + +func TestSuccessfulPublisherNoResponse(t *testing.T) { + mock := &mockSQSClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + sendMsgID: "someMsgID", + } + + queueURL := "someURL" + responseQueueURL := "someOtherURL" + pub := awssqs.NewPublisher( + mock, + &queueURL, + &responseQueueURL, + awssqs.EncodeJSONRequest, + awssqs.NoResponseDecode, + ) + var err error + errChan := make(chan error, 1) + finishChan := make(chan bool, 1) + go func() { + _, err := pub.Endpoint()(context.Background(), struct{}{}) + if err != nil { + errChan <- err + } else { + finishChan <- true + } + }() + + select { + case <-finishChan: + break + case err = <-errChan: + t.Errorf("unexpected error %s", err) + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for result") + } +} + +func TestPublisherWithBefore(t *testing.T) { + mock := &mockSQSClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + sendMsgID: "someMsgID", + } + + queueURL := "someURL" + responseQueueURL := "someOtherURL" + pub := awssqs.NewPublisher( + mock, + &queueURL, + &responseQueueURL, + awssqs.EncodeJSONRequest, + awssqs.NoResponseDecode, + awssqs.PublisherBefore(func(c context.Context, s *sqs.SendMessageInput) context.Context { + if s.MessageAttributes == nil { + s.MessageAttributes = make(map[string]*sqs.MessageAttributeValue) + } + s.MessageAttributes["responseQueueURL"] = &sqs.MessageAttributeValue{ + DataType: aws.String("String"), + StringValue: &responseQueueURL, + } + return c + }), + ) + var err error + errChan := make(chan error, 1) + go func() { + _, err := pub.Endpoint()(context.Background(), struct{}{}) + if err != nil { + errChan <- err + } + }() + + want := sqs.MessageAttributeValue{ + DataType: aws.String("String"), + StringValue: &responseQueueURL, + } + + select { + case receiveOutput := <-mock.receiveOuputChan: + if len(receiveOutput.Messages) != 1 { + t.Errorf("published %d messages instead of 1", len(receiveOutput.Messages)) + } + if have, exists := receiveOutput.Messages[0].MessageAttributes["responseQueueURL"]; !exists { + t.Errorf("expected MessageAttributes responseQueueURL not found") + } else if *have.StringValue != responseQueueURL || *have.DataType != "String" { + t.Errorf("want %s, have %s", want, *have) + } + break + case err = <-errChan: + t.Errorf("unexpected error %s", err) + case <-time.After(100 * time.Millisecond): + t.Fatal("timed out waiting for result") + } +} diff --git a/transport/awssqs/request_response_func.go b/transport/awssqs/request_response_func.go new file mode 100644 index 000000000..d9a1c5eb8 --- /dev/null +++ b/transport/awssqs/request_response_func.go @@ -0,0 +1,39 @@ +package awssqs + +import ( + "context" + + "github.com/aws/aws-sdk-go/service/sqs" +) + +// ConsumerRequestFunc may take information from a consumer request result and +// put it into a request context. In Consumers, RequestFuncs are executed prior +// to invoking the endpoint. +// use cases eg. in Consumer : extract message info to context, or sort messages +type ConsumerRequestFunc func(context.Context, *[]*sqs.Message) context.Context + +// PublisherRequestFunc may take information from a publisher request and put it into a +// request context, or add some informations to SendMessageInput. In Publishers, +// RequestFuncs are executed prior to publishing the msg but after encoding. +// use cases eg. in Publisher : add message attributes to SendMessageInput +type PublisherRequestFunc func(context.Context, *sqs.SendMessageInput) context.Context + +// ConsumerResponseFunc may take information from a request context and use it to +// manipulate a Publisher. ConsumerResponseFunc are only executed in +// consumers, after invoking the endpoint but prior to publishing a reply. +// eg. Pipe information from req message to response MessageInput or delete msg from queue +// Should also delete message from leftMsgs slice +type ConsumerResponseFunc func(context.Context, *sqs.Message, *sqs.SendMessageInput, *[]*sqs.Message) context.Context + +// PublisherResponseFunc may take information from an sqs send message output and +// ask for response. SQS is not req-reply out-of-the-box. Response needs to be fetched. +// PublisherResponseFunc are only executed in publishers, after a request has been made, +// but prior to its resp being decoded. So this is the perfect place to fetch actual response. +type PublisherResponseFunc func(context.Context, *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) + +// WantReplyFunc encapsulates logic to check whether message awaits response or not +// eg. Check for a given attribute value +type WantReplyFunc func(context.Context, *sqs.Message) bool + +// VisibilityTimeoutFunc encapsulates logic to extend messages visibility timeout +type VisibilityTimeoutFunc func(context.Context, SQSClient, *string, int64, *[]*sqs.Message) error From 1ae7ac34269d4815a149dde2777465ed10ea382f Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Wed, 7 Oct 2020 12:21:15 +0200 Subject: [PATCH 02/18] Add doc.go --- transport/awssqs/doc.go | 2 ++ 1 file changed, 2 insertions(+) create mode 100644 transport/awssqs/doc.go diff --git a/transport/awssqs/doc.go b/transport/awssqs/doc.go new file mode 100644 index 000000000..b6b5355d5 --- /dev/null +++ b/transport/awssqs/doc.go @@ -0,0 +1,2 @@ +// Package awssqs implements an AWS Simple Query Service transport. +package awssqs From 66dc418890410f4853244b967916f6bb0ca0527f Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Wed, 7 Oct 2020 18:27:16 +0200 Subject: [PATCH 03/18] Review after first pull request pass --- transport/awssqs/consumer.go | 106 +++++++++++++--------- transport/awssqs/consumer_test.go | 68 ++++++-------- transport/awssqs/encode_decode.go | 8 +- transport/awssqs/publisher.go | 24 +++-- transport/awssqs/publisher_test.go | 59 ++++++------ transport/awssqs/request_response_func.go | 25 ++--- 6 files changed, 142 insertions(+), 148 deletions(-) diff --git a/transport/awssqs/consumer.go b/transport/awssqs/consumer.go index 37f40a71a..3d17ed926 100644 --- a/transport/awssqs/consumer.go +++ b/transport/awssqs/consumer.go @@ -12,15 +12,14 @@ import ( "github.com/go-kit/kit/transport" ) -// Consumer wraps an endpoint and provides and provides a handler for sqs msgs +// Consumer wraps an endpoint and provides a handler for sqs messages. type Consumer struct { - sqsClient SQSClient + sqsClient Client e endpoint.Endpoint dec DecodeRequestFunc enc EncodeResponseFunc wantRep WantReplyFunc - queueURL *string - dlQueueURL *string + queueURL string visibilityTimeout int64 visibilityTimeoutFunc VisibilityTimeoutFunc before []ConsumerRequestFunc @@ -33,14 +32,11 @@ type Consumer struct { // NewConsumer constructs a new Consumer, which provides a Consume method // and message handlers that wrap the provided endpoint. func NewConsumer( - sqsClient SQSClient, + sqsClient Client, e endpoint.Endpoint, dec DecodeRequestFunc, enc EncodeResponseFunc, - wantRep WantReplyFunc, - queueURL *string, - dlQueueURL *string, - visibilityTimeout int64, + queueURL string, options ...ConsumerOption, ) *Consumer { s := &Consumer{ @@ -48,10 +44,9 @@ func NewConsumer( e: e, dec: dec, enc: enc, - wantRep: wantRep, + wantRep: DoNotRespond, queueURL: queueURL, - dlQueueURL: dlQueueURL, - visibilityTimeout: visibilityTimeout, + visibilityTimeout: int64(30), visibilityTimeoutFunc: DoNotExtendVisibilityTimeout, errorEncoder: DefaultErrorEncoder, errorHandler: transport.NewLogErrorHandler(log.NewNopLogger()), @@ -86,20 +81,23 @@ func ConsumerErrorEncoder(ee ErrorEncoder) ConsumerOption { } // ConsumerVisbilityTimeOutFunc is used to extend the visibility timeout -// for messages during when processing them. Clients can -// use this to provide custom visibility timeout extension. By default, -// visibility timeout are not extend. +// for messages while the consumer processes them. +// VisibilityTimeoutFunc will need to check that the provided context is not done. +// By default, visibility timeout are not extended. func ConsumerVisbilityTimeOutFunc(vtFunc VisibilityTimeoutFunc) ConsumerOption { return func(c *Consumer) { c.visibilityTimeoutFunc = vtFunc } } -// ConsumerErrorLogger is used to log non-terminal errors. By default, no errors -// are logged. This is intended as a diagnostic measure. Finer-grained control -// of error handling, including logging in more detail, should be performed in a -// custom ConsumerErrorEncoder which has access to the context. -// Deprecated: Use ConsumerErrorHandler instead. -func ConsumerErrorLogger(logger log.Logger) ConsumerOption { - return func(c *Consumer) { c.errorHandler = transport.NewLogErrorHandler(logger) } +// ConsumerVisibilityTimeout overrides the default value for the consumer's +// visibilityTimeout field. +func ConsumerVisibilityTimeout(visibilityTimeout int64) ConsumerOption { + return func(c *Consumer) { c.visibilityTimeout = visibilityTimeout } +} + +// ConsumerWantReplyFunc overrides the default value for the consumer's +// wantRep field. +func ConsumerWantReplyFunc(replyFunc WantReplyFunc) ConsumerOption { + return func(c *Consumer) { c.wantRep = replyFunc } } // ConsumerErrorHandler is used to handle non-terminal errors. By default, non-terminal errors @@ -110,16 +108,18 @@ func ConsumerErrorHandler(errorHandler transport.ErrorHandler) ConsumerOption { return func(c *Consumer) { c.errorHandler = errorHandler } } -// ConsumerFinalizer is executed at the end of every request from a publisher through SQS. +// ConsumerFinalizer is executed once all the received SQS messages are done being processed. // By default, no finalizer is registered. func ConsumerFinalizer(f ...ConsumerFinalizerFunc) ConsumerOption { return func(c *Consumer) { c.finalizer = f } } -// Consume calls ReceiveMessageWithContext and handles messages -// having receiveMsgInput as param allows each user to have his own receive config +// Consume calls ReceiveMessageWithContext and handles messages having an +// sqs.ReceiveMessageInput as parameter allows each user to have his own receive configuration. +// That said, this method overrides the queueURL for the provided ReceiveMessageInput to ensure +// the messages are retrieved from the consumer's configured queue. func (c Consumer) Consume(ctx context.Context, receiveMsgInput *sqs.ReceiveMessageInput) error { - receiveMsgInput.QueueUrl = c.queueURL + receiveMsgInput.QueueUrl = &c.queueURL out, err := c.sqsClient.ReceiveMessageWithContext(ctx, receiveMsgInput) if err != nil { return err @@ -127,18 +127,20 @@ func (c Consumer) Consume(ctx context.Context, receiveMsgInput *sqs.ReceiveMessa return c.HandleMessages(ctx, out.Messages) } -// HandleMessages handles the consumed messages +// HandleMessages handles the consumed messages. func (c Consumer) HandleMessages(ctx context.Context, msgs []*sqs.Message) error { ctx, cancel := context.WithCancel(ctx) defer cancel() - // Copy msgs slice in leftMsgs + // Copy received messages slice in leftMsgs slice + // leftMsgs will be used by the consumer's visibilityTimeoutFunc to extend the + // visibility timeout for the messages that have not been processed yet. leftMsgs := []*sqs.Message{} leftMsgs = append(leftMsgs, msgs...) - // this func allows us to extend visibility timeout to give use - // time to process the messages in leftMsgs - go c.visibilityTimeoutFunc(ctx, c.sqsClient, c.queueURL, c.visibilityTimeout, &leftMsgs) + visibilityTimeoutCtx, cancel := context.WithCancel(ctx) + defer cancel() + go c.visibilityTimeoutFunc(visibilityTimeoutCtx, c.sqsClient, c.queueURL, c.visibilityTimeout, &leftMsgs) if len(c.finalizer) > 0 { defer func() { @@ -160,7 +162,7 @@ func (c Consumer) HandleMessages(ctx context.Context, msgs []*sqs.Message) error return nil } -// HandleSingleMessage handles a single sqs message +// HandleSingleMessage handles a single sqs message. func (c Consumer) HandleSingleMessage(ctx context.Context, msg *sqs.Message, leftMsgs *[]*sqs.Message) error { req, err := c.dec(ctx, msg) if err != nil { @@ -182,7 +184,6 @@ func (c Consumer) HandleSingleMessage(ctx context.Context, msg *sqs.Message, lef } if !c.wantRep(ctx, msg) { - // Message does not expect answer return nil } @@ -200,30 +201,45 @@ func (c Consumer) HandleSingleMessage(ctx context.Context, msg *sqs.Message, lef return nil } -// ErrorEncoder is responsible for encoding an error to the consumer reply. +// ErrorEncoder is responsible for encoding an error to the consumer's reply. // Users are encouraged to use custom ErrorEncoders to encode errors to // their replies, and will likely want to pass and check for their own error // types. -type ErrorEncoder func(ctx context.Context, err error, req *sqs.Message, sqsClient SQSClient) +type ErrorEncoder func(ctx context.Context, err error, req *sqs.Message, sqsClient Client) // ConsumerFinalizerFunc can be used to perform work at the end of a request // from a publisher, after the response has been written to the publisher. The // principal intended use is for request logging. -// Can also be used to delete messages once fully proccessed +// Can also be used to delete messages once fully proccessed. type ConsumerFinalizerFunc func(ctx context.Context, msg *[]*sqs.Message) -// DefaultErrorEncoder simply ignores the message. It does not reply -// nor Ack/Nack the message. -func DefaultErrorEncoder(context.Context, error, *sqs.Message, SQSClient) { +// VisibilityTimeoutFunc encapsulates logic to extend messages visibility timeout. +// this can be used to provide custom visibility timeout extension such as doubling it everytime +// it gets close to being reached. +// VisibilityTimeoutFunc will need to check that the provided context is not done and return once it is. +type VisibilityTimeoutFunc func(context.Context, Client, string, int64, *[]*sqs.Message) error + +// WantReplyFunc encapsulates logic to check whether message awaits response or not +// for example check for a given message attribute value. +type WantReplyFunc func(context.Context, *sqs.Message) bool + +// DefaultErrorEncoder simply ignores the message. It does not reply. +func DefaultErrorEncoder(context.Context, error, *sqs.Message, Client) { } -// DoNotExtendVisibilityTimeout is the default value for visibilityTimeoutFunc +// DoNotExtendVisibilityTimeout is the default value for the consumer's visibilityTimeoutFunc. // It returns no error and does nothing -func DoNotExtendVisibilityTimeout(context.Context, SQSClient, *string, int64, *[]*sqs.Message) error { +func DoNotExtendVisibilityTimeout(context.Context, Client, string, int64, *[]*sqs.Message) error { return nil } -// EncodeJSONResponse marshals response as json and loads it into input MessageBody +// DoNotRespond is a WantReplyFunc and is the default value for consumer's wantRep field. +// It indicates that the message do not expect a response. +func DoNotRespond(context.Context, *sqs.Message) bool { + return false +} + +// EncodeJSONResponse marshals response as json and loads it into an sqs.SendMessageInput MessageBody. func EncodeJSONResponse(_ context.Context, input *sqs.SendMessageInput, response interface{}) error { payload, err := json.Marshal(response) if err != nil { @@ -233,9 +249,9 @@ func EncodeJSONResponse(_ context.Context, input *sqs.SendMessageInput, response return nil } -// SQSClient is an interface to make testing possible. -// It is highly recommended to use *sqs.SQS as the interface implementation. -type SQSClient interface { +// Client is consumer contract for the Producer and Consumer. +// It models methods of the AWS *sqs.SQS type. +type Client interface { SendMessageWithContext(ctx context.Context, input *sqs.SendMessageInput, opts ...request.Option) (*sqs.SendMessageOutput, error) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, opts ...request.Option) (*sqs.ReceiveMessageOutput, error) ChangeMessageVisibilityWithContext(ctx aws.Context, input *sqs.ChangeMessageVisibilityInput, opts ...request.Option) (*sqs.ChangeMessageVisibilityOutput, error) diff --git a/transport/awssqs/consumer_test.go b/transport/awssqs/consumer_test.go index 34d0ed1ac..5e3d65161 100644 --- a/transport/awssqs/consumer_test.go +++ b/transport/awssqs/consumer_test.go @@ -19,7 +19,7 @@ var ( errTypeAssertion = errors.New("type assertion error") ) -func (mock *mockSQSClient) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, opts ...request.Option) (*sqs.ReceiveMessageOutput, error) { +func (mock *mockClient) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, opts ...request.Option) (*sqs.ReceiveMessageOutput, error) { // Add logic to allow context errors for { select { @@ -34,8 +34,7 @@ func (mock *mockSQSClient) ReceiveMessageWithContext(ctx context.Context, input // TestConsumerBadDecode checks if decoder errors are handled properly. func TestConsumerBadDecode(t *testing.T) { queueURL := "someURL" - dlQueueURL := "somedlURL" - mock := &mockSQSClient{ + mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), } @@ -49,7 +48,7 @@ func TestConsumerBadDecode(t *testing.T) { }, } }() - errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient awssqs.SQSClient) { + errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient awssqs.Client) { publishError := sqsError{ Err: err.Error(), MsgID: *req.MessageId, @@ -64,9 +63,9 @@ func TestConsumerBadDecode(t *testing.T) { func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, func(context.Context, *sqs.Message) (interface{}, error) { return nil, errors.New("err!") }, func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, - func(context.Context, *sqs.Message) bool { return true }, - &queueURL, &dlQueueURL, int64(5), + queueURL, errEncoder, + awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) @@ -91,8 +90,7 @@ func TestConsumerBadDecode(t *testing.T) { // TestConsumerBadEndpoint checks if endpoint errors are handled properly. func TestConsumerBadEndpoint(t *testing.T) { queueURL := "someURL" - dlQueueURL := "somedlURL" - mock := &mockSQSClient{ + mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), } @@ -106,7 +104,7 @@ func TestConsumerBadEndpoint(t *testing.T) { }, } }() - errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient awssqs.SQSClient) { + errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient awssqs.Client) { publishError := sqsError{ Err: err.Error(), MsgID: *req.MessageId, @@ -121,9 +119,9 @@ func TestConsumerBadEndpoint(t *testing.T) { func(context.Context, interface{}) (interface{}, error) { return struct{}{}, errors.New("err!") }, func(context.Context, *sqs.Message) (interface{}, error) { return nil, nil }, func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, - func(context.Context, *sqs.Message) bool { return true }, - &queueURL, &dlQueueURL, int64(5), + queueURL, errEncoder, + awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) @@ -148,8 +146,7 @@ func TestConsumerBadEndpoint(t *testing.T) { // TestConsumerBadEncoder checks if encoder errors are handled properly. func TestConsumerBadEncoder(t *testing.T) { queueURL := "someURL" - dlQueueURL := "somedlURL" - mock := &mockSQSClient{ + mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), } @@ -163,7 +160,7 @@ func TestConsumerBadEncoder(t *testing.T) { }, } }() - errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient awssqs.SQSClient) { + errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient awssqs.Client) { publishError := sqsError{ Err: err.Error(), MsgID: *req.MessageId, @@ -178,9 +175,9 @@ func TestConsumerBadEncoder(t *testing.T) { func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, func(context.Context, *sqs.Message) (interface{}, error) { return nil, nil }, func(context.Context, *sqs.SendMessageInput, interface{}) error { return errors.New("err!") }, - func(context.Context, *sqs.Message) bool { return true }, - &queueURL, &dlQueueURL, int64(5), + queueURL, errEncoder, + awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) @@ -212,8 +209,7 @@ func TestConsumerSuccess(t *testing.T) { t.Fatal(err) } queueURL := "someURL" - dlQueueURL := "somedlURL" - mock := &mockSQSClient{ + mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), } @@ -231,10 +227,8 @@ func TestConsumerSuccess(t *testing.T) { testEndpoint, testReqDecoderfunc, awssqs.EncodeJSONResponse, - // wants response - func(context.Context, *sqs.Message) bool { return true }, - // queue, dlqueue, vibilityTimeout - &queueURL, &dlQueueURL, int64(5), + queueURL, + awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) @@ -261,7 +255,7 @@ func TestConsumerSuccess(t *testing.T) { } // TestConsumerSuccessNoReply checks if consumer processes correctly message -// without sending response +// without sending response. func TestConsumerSuccessNoReply(t *testing.T) { obj := testReq{ Squadron: 436, @@ -271,8 +265,7 @@ func TestConsumerSuccessNoReply(t *testing.T) { t.Fatal(err) } queueURL := "someURL" - dlQueueURL := "somedlURL" - mock := &mockSQSClient{ + mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), } @@ -290,10 +283,7 @@ func TestConsumerSuccessNoReply(t *testing.T) { testEndpoint, testReqDecoderfunc, awssqs.EncodeJSONResponse, - // wants response - func(context.Context, *sqs.Message) bool { return false }, - // queue, dlqueue, vibilityTimeout - &queueURL, &dlQueueURL, int64(5), + queueURL, ) consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) @@ -326,8 +316,7 @@ func TestConsumerBeforeFilterMessages(t *testing.T) { } b3, _ := json.Marshal(obj3) queueURL := "someURL" - dlQueueURL := "somedlURL" - mock := &mockSQSClient{ + mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), } @@ -369,10 +358,7 @@ func TestConsumerBeforeFilterMessages(t *testing.T) { testEndpoint, testReqDecoderfunc, awssqs.EncodeJSONResponse, - // wants response - func(context.Context, *sqs.Message) bool { return true }, - // queue, dlqueue, vibilityTimeout - &queueURL, &dlQueueURL, int64(5), + queueURL, awssqs.ConsumerBefore(func(ctx context.Context, msgs *[]*sqs.Message) context.Context { // delete a message that is not destined to the consumer msgsCopy := *msgs @@ -384,6 +370,7 @@ func TestConsumerBeforeFilterMessages(t *testing.T) { *msgs = msgsCopy return ctx }), + awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) ctx := context.Background() consumer.Consume(ctx, &sqs.ReceiveMessageInput{}) @@ -420,15 +407,14 @@ func TestConsumerBeforeFilterMessages(t *testing.T) { } // TestConsumerAfter checks if consumer after is called as expected. -// Here after is used to transfer some info from received message in response +// Here after is used to transfer some info from received message in response. func TestConsumerAfter(t *testing.T) { obj1 := testReq{ Squadron: 436, } b1, _ := json.Marshal(obj1) queueURL := "someURL" - dlQueueURL := "somedlURL" - mock := &mockSQSClient{ + mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), } @@ -457,10 +443,7 @@ func TestConsumerAfter(t *testing.T) { testEndpoint, testReqDecoderfunc, awssqs.EncodeJSONResponse, - // wants response - func(context.Context, *sqs.Message) bool { return true }, - // queue, dlqueue, vibilityTimeout - &queueURL, &dlQueueURL, int64(5), + queueURL, awssqs.ConsumerAfter(func(ctx context.Context, msg *sqs.Message, resp *sqs.SendMessageInput, leftMsgs *[]*sqs.Message) context.Context { if correlationIDAttribute, exists := msg.MessageAttributes["correlationID"]; exists { if resp.MessageAttributes == nil { @@ -473,6 +456,7 @@ func TestConsumerAfter(t *testing.T) { } return ctx }), + awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) ctx := context.Background() consumer.Consume(ctx, &sqs.ReceiveMessageInput{}) diff --git a/transport/awssqs/encode_decode.go b/transport/awssqs/encode_decode.go index 28c284602..654700cff 100644 --- a/transport/awssqs/encode_decode.go +++ b/transport/awssqs/encode_decode.go @@ -7,17 +7,17 @@ import ( ) // DecodeRequestFunc extracts a user-domain request object from -// an sqs message object. It is designed to be used in sqs Subscribers. +// an sqs message object. It is designed to be used in Consumers. type DecodeRequestFunc func(context.Context, *sqs.Message) (request interface{}, err error) // EncodeRequestFunc encodes the passed payload object into -// an sqs message object. It is designed to be used in sqs Publishers. +// an sqs message object. It is designed to be used in Publishers. type EncodeRequestFunc func(context.Context, *sqs.SendMessageInput, interface{}) error // EncodeResponseFunc encodes the passed response object to -// an sqs message object. It is designed to be used in sqs Subscribers. +// an sqs message object. It is designed to be used in Consumers. type EncodeResponseFunc func(context.Context, *sqs.SendMessageInput, interface{}) error // DecodeResponseFunc extracts a user-domain response object from -// an sqs message object. It is designed to be used in sqs Publishers. +// an sqs message object. It is designed to be used in Publishers. type DecodeResponseFunc func(context.Context, *sqs.Message) (response interface{}, err error) diff --git a/transport/awssqs/publisher.go b/transport/awssqs/publisher.go index dfc6384b1..273f0fae1 100644 --- a/transport/awssqs/publisher.go +++ b/transport/awssqs/publisher.go @@ -13,9 +13,9 @@ import ( // Publisher wraps an sqs client and queue, and provides a method that // implements endpoint.Endpoint. type Publisher struct { - sqsClient SQSClient - queueURL *string - responseQueueURL *string + sqsClient Client + queueURL string + responseQueueURL string enc EncodeRequestFunc dec DecodeResponseFunc before []PublisherRequestFunc @@ -25,9 +25,9 @@ type Publisher struct { // NewPublisher constructs a usable Publisher for a single remote method. func NewPublisher( - sqsClient SQSClient, - queueURL *string, - responseQueueURL *string, + sqsClient Client, + queueURL string, + responseQueueURL string, enc EncodeRequestFunc, dec DecodeResponseFunc, options ...PublisherOption, @@ -73,14 +73,13 @@ func (p Publisher) Endpoint() endpoint.Endpoint { ctx, cancel := context.WithTimeout(ctx, p.timeout) defer cancel() msgInput := sqs.SendMessageInput{ - QueueUrl: p.queueURL, + QueueUrl: &p.queueURL, } if err := p.enc(ctx, &msgInput, request); err != nil { return nil, err } for _, f := range p.before { - // Affect only msgInput ctx = f(ctx, &msgInput) } @@ -91,7 +90,7 @@ func (p Publisher) Endpoint() endpoint.Endpoint { var responseMsg *sqs.Message for _, f := range p.after { - ctx, responseMsg, err = f(ctx, output) + ctx, responseMsg, err = f(ctx, p.sqsClient, output) if err != nil { return nil, err } @@ -107,9 +106,8 @@ func (p Publisher) Endpoint() endpoint.Endpoint { } // EncodeJSONRequest is an EncodeRequestFunc that serializes the request as a -// JSON object to the MessageBody of the Msg. -// Depending on your needs (if you don't need to add MessageAttributes or GroupID), -// this can be enough. +// JSON object and loads it as the MessageBody of the sqs.SendMessageInput. +// This can be enough for most JSON over SQS communications. func EncodeJSONRequest(_ context.Context, msg *sqs.SendMessageInput, request interface{}) error { b, err := json.Marshal(request) if err != nil { @@ -122,7 +120,7 @@ func EncodeJSONRequest(_ context.Context, msg *sqs.SendMessageInput, request int } // NoResponseDecode is a DecodeResponseFunc that can be used when no response is needed. -// It returns nil value and nil error +// It returns nil value and nil error. func NoResponseDecode(_ context.Context, _ *sqs.Message) (interface{}, error) { return nil, nil } diff --git a/transport/awssqs/publisher_test.go b/transport/awssqs/publisher_test.go index 5b4c8a271..0a438a8a9 100644 --- a/transport/awssqs/publisher_test.go +++ b/transport/awssqs/publisher_test.go @@ -31,15 +31,15 @@ var names = map[int]string{ 437: "husky", } -// mockSQSClient is a mock of *sqs.SQS. -type mockSQSClient struct { +// mockClient is a mock of *sqs.SQS. +type mockClient struct { err error sendOutputChan chan *sqs.SendMessageOutput receiveOuputChan chan *sqs.ReceiveMessageOutput sendMsgID string } -func (mock *mockSQSClient) SendMessageWithContext(ctx context.Context, input *sqs.SendMessageInput, opts ...request.Option) (*sqs.SendMessageOutput, error) { +func (mock *mockClient) SendMessageWithContext(ctx context.Context, input *sqs.SendMessageInput, opts ...request.Option) (*sqs.SendMessageOutput, error) { if input != nil && input.MessageBody != nil && *input.MessageBody != "" { go func() { mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ @@ -54,7 +54,7 @@ func (mock *mockSQSClient) SendMessageWithContext(ctx context.Context, input *sq }() return &sqs.SendMessageOutput{MessageId: aws.String(mock.sendMsgID)}, nil } - // Add logic to allow context errors + // Add logic to allow context errors. for { select { case d := <-mock.sendOutputChan: @@ -65,7 +65,7 @@ func (mock *mockSQSClient) SendMessageWithContext(ctx context.Context, input *sq } } -func (mock *mockSQSClient) ChangeMessageVisibilityWithContext(ctx aws.Context, input *sqs.ChangeMessageVisibilityInput, opts ...request.Option) (*sqs.ChangeMessageVisibilityOutput, error) { +func (mock *mockClient) ChangeMessageVisibilityWithContext(ctx aws.Context, input *sqs.ChangeMessageVisibilityInput, opts ...request.Option) (*sqs.ChangeMessageVisibilityOutput, error) { return nil, nil } @@ -73,13 +73,13 @@ func (mock *mockSQSClient) ChangeMessageVisibilityWithContext(ctx aws.Context, i func TestBadEncode(t *testing.T) { queueURL := "someURL" responseQueueURL := "someOtherURL" - mock := &mockSQSClient{ + mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), } pub := awssqs.NewPublisher( mock, - &queueURL, - &responseQueueURL, + queueURL, + responseQueueURL, func(context.Context, *sqs.SendMessageInput, interface{}) error { return errors.New("err!") }, func(context.Context, *sqs.Message) (response interface{}, err error) { return struct{}{}, nil }, ) @@ -107,7 +107,7 @@ func TestBadEncode(t *testing.T) { // TestBadDecode tests if decode errors are handled properly. func TestBadDecode(t *testing.T) { - mock := &mockSQSClient{ + mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), } go func() { @@ -120,14 +120,14 @@ func TestBadDecode(t *testing.T) { responseQueueURL := "someOtherURL" pub := awssqs.NewPublisher( mock, - &queueURL, - &responseQueueURL, + queueURL, + responseQueueURL, func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, func(context.Context, *sqs.Message) (response interface{}, err error) { return struct{}{}, errors.New("err!") }, - awssqs.PublisherAfter(func(ctx context.Context, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { - // Set the actual response for the request + awssqs.PublisherAfter(func(ctx context.Context, client awssqs.Client, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { + // Set the actual response for the request. return ctx, &sqs.Message{Body: aws.String("someMsgContent")}, nil }), ) @@ -158,15 +158,15 @@ func TestBadDecode(t *testing.T) { // TestPublisherTimeout ensures that the publisher timeout mechanism works. func TestPublisherTimeout(t *testing.T) { sendOutputChan := make(chan *sqs.SendMessageOutput) - mock := &mockSQSClient{ + mock := &mockClient{ sendOutputChan: sendOutputChan, } queueURL := "someURL" responseQueueURL := "someOtherURL" pub := awssqs.NewPublisher( mock, - &queueURL, - &responseQueueURL, + queueURL, + responseQueueURL, func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, func(context.Context, *sqs.Message) (response interface{}, err error) { return struct{}{}, nil @@ -199,6 +199,7 @@ func TestPublisherTimeout(t *testing.T) { } } +// TestSuccessfulPublisher ensures that the publisher mechanisms work. func TestSuccessfulPublisher(t *testing.T) { mockReq := testReq{437} mockRes := testRes{ @@ -209,7 +210,7 @@ func TestSuccessfulPublisher(t *testing.T) { if err != nil { t.Fatal(err) } - mock := &mockSQSClient{ + mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), sendMsgID: "someMsgID", } @@ -223,18 +224,17 @@ func TestSuccessfulPublisher(t *testing.T) { responseQueueURL := "someOtherURL" pub := awssqs.NewPublisher( mock, - &queueURL, - &responseQueueURL, + queueURL, + responseQueueURL, awssqs.EncodeJSONRequest, func(_ context.Context, msg *sqs.Message) (interface{}, error) { response := testRes{} err := json.Unmarshal([]byte(*msg.Body), &response) return response, err }, - awssqs.PublisherAfter(func(ctx context.Context, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { - // Sets the actual response for the request + awssqs.PublisherAfter(func(ctx context.Context, client awssqs.Client, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { + // Sets the actual response for the request. if *msg.MessageId == "someMsgID" { - // Here should contain logic to consume msgs and check if response was provided return ctx, &sqs.Message{Body: aws.String(string(b))}, nil } return nil, nil, fmt.Errorf("Did not receive expected SendMessageOutput") @@ -276,8 +276,9 @@ func TestSuccessfulPublisher(t *testing.T) { } } +// TestSuccessfulPublisherNoResponse ensures that the publisher response mechanism works. func TestSuccessfulPublisherNoResponse(t *testing.T) { - mock := &mockSQSClient{ + mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), sendMsgID: "someMsgID", @@ -287,8 +288,8 @@ func TestSuccessfulPublisherNoResponse(t *testing.T) { responseQueueURL := "someOtherURL" pub := awssqs.NewPublisher( mock, - &queueURL, - &responseQueueURL, + queueURL, + responseQueueURL, awssqs.EncodeJSONRequest, awssqs.NoResponseDecode, ) @@ -314,8 +315,10 @@ func TestSuccessfulPublisherNoResponse(t *testing.T) { } } +// TestPublisherWithBefore adds a PublisherBefore function that adds a message attribute. +// This test ensures that the the before functions work as expected. func TestPublisherWithBefore(t *testing.T) { - mock := &mockSQSClient{ + mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), sendMsgID: "someMsgID", @@ -325,8 +328,8 @@ func TestPublisherWithBefore(t *testing.T) { responseQueueURL := "someOtherURL" pub := awssqs.NewPublisher( mock, - &queueURL, - &responseQueueURL, + queueURL, + responseQueueURL, awssqs.EncodeJSONRequest, awssqs.NoResponseDecode, awssqs.PublisherBefore(func(c context.Context, s *sqs.SendMessageInput) context.Context { diff --git a/transport/awssqs/request_response_func.go b/transport/awssqs/request_response_func.go index d9a1c5eb8..5e29d1ce5 100644 --- a/transport/awssqs/request_response_func.go +++ b/transport/awssqs/request_response_func.go @@ -9,31 +9,24 @@ import ( // ConsumerRequestFunc may take information from a consumer request result and // put it into a request context. In Consumers, RequestFuncs are executed prior // to invoking the endpoint. -// use cases eg. in Consumer : extract message info to context, or sort messages +// use cases eg. in Consumer : extract message into context, or filter received messages. type ConsumerRequestFunc func(context.Context, *[]*sqs.Message) context.Context // PublisherRequestFunc may take information from a publisher request and put it into a // request context, or add some informations to SendMessageInput. In Publishers, -// RequestFuncs are executed prior to publishing the msg but after encoding. -// use cases eg. in Publisher : add message attributes to SendMessageInput +// RequestFuncs are executed prior to publishing the message but after encoding. +// use cases eg. in Publisher : enforce some message attributes to SendMessageInput type PublisherRequestFunc func(context.Context, *sqs.SendMessageInput) context.Context // ConsumerResponseFunc may take information from a request context and use it to // manipulate a Publisher. ConsumerResponseFunc are only executed in // consumers, after invoking the endpoint but prior to publishing a reply. -// eg. Pipe information from req message to response MessageInput or delete msg from queue -// Should also delete message from leftMsgs slice +// use cases eg. : Pipe information from request message to response MessageInput, +// delete msg from queue or update leftMsgs slice type ConsumerResponseFunc func(context.Context, *sqs.Message, *sqs.SendMessageInput, *[]*sqs.Message) context.Context -// PublisherResponseFunc may take information from an sqs send message output and -// ask for response. SQS is not req-reply out-of-the-box. Response needs to be fetched. +// PublisherResponseFunc may take information from an sqs.SendMessageOutput and +// fetch response using the Client. SQS is not req-reply out-of-the-box. Responses need to be fetched. // PublisherResponseFunc are only executed in publishers, after a request has been made, -// but prior to its resp being decoded. So this is the perfect place to fetch actual response. -type PublisherResponseFunc func(context.Context, *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) - -// WantReplyFunc encapsulates logic to check whether message awaits response or not -// eg. Check for a given attribute value -type WantReplyFunc func(context.Context, *sqs.Message) bool - -// VisibilityTimeoutFunc encapsulates logic to extend messages visibility timeout -type VisibilityTimeoutFunc func(context.Context, SQSClient, *string, int64, *[]*sqs.Message) error +// but prior to its response being decoded. So this is the perfect place to fetch actual response. +type PublisherResponseFunc func(context.Context, Client, *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) From 98eea25d2105d3e0def2172a92a4b8e313c67dea Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Wed, 7 Oct 2020 18:40:05 +0200 Subject: [PATCH 04/18] Add mutex to sync use of left messages var --- transport/awssqs/consumer.go | 11 +++++++---- transport/awssqs/consumer_test.go | 5 ++++- transport/awssqs/request_response_func.go | 3 ++- 3 files changed, 13 insertions(+), 6 deletions(-) diff --git a/transport/awssqs/consumer.go b/transport/awssqs/consumer.go index 3d17ed926..a0e1040be 100644 --- a/transport/awssqs/consumer.go +++ b/transport/awssqs/consumer.go @@ -3,6 +3,7 @@ package awssqs import ( "context" "encoding/json" + "sync" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" @@ -22,6 +23,7 @@ type Consumer struct { queueURL string visibilityTimeout int64 visibilityTimeoutFunc VisibilityTimeoutFunc + leftMsgsMux *sync.Mutex before []ConsumerRequestFunc after []ConsumerResponseFunc errorEncoder ErrorEncoder @@ -48,6 +50,7 @@ func NewConsumer( queueURL: queueURL, visibilityTimeout: int64(30), visibilityTimeoutFunc: DoNotExtendVisibilityTimeout, + leftMsgsMux: &sync.Mutex{}, errorEncoder: DefaultErrorEncoder, errorHandler: transport.NewLogErrorHandler(log.NewNopLogger()), } @@ -140,7 +143,7 @@ func (c Consumer) HandleMessages(ctx context.Context, msgs []*sqs.Message) error visibilityTimeoutCtx, cancel := context.WithCancel(ctx) defer cancel() - go c.visibilityTimeoutFunc(visibilityTimeoutCtx, c.sqsClient, c.queueURL, c.visibilityTimeout, &leftMsgs) + go c.visibilityTimeoutFunc(visibilityTimeoutCtx, c.sqsClient, c.queueURL, c.visibilityTimeout, &leftMsgs, c.leftMsgsMux) if len(c.finalizer) > 0 { defer func() { @@ -180,7 +183,7 @@ func (c Consumer) HandleSingleMessage(ctx context.Context, msg *sqs.Message, lef responseMsg := sqs.SendMessageInput{} for _, f := range c.after { - ctx = f(ctx, msg, &responseMsg, leftMsgs) + ctx = f(ctx, msg, &responseMsg, leftMsgs, c.leftMsgsMux) } if !c.wantRep(ctx, msg) { @@ -217,7 +220,7 @@ type ConsumerFinalizerFunc func(ctx context.Context, msg *[]*sqs.Message) // this can be used to provide custom visibility timeout extension such as doubling it everytime // it gets close to being reached. // VisibilityTimeoutFunc will need to check that the provided context is not done and return once it is. -type VisibilityTimeoutFunc func(context.Context, Client, string, int64, *[]*sqs.Message) error +type VisibilityTimeoutFunc func(context.Context, Client, string, int64, *[]*sqs.Message, *sync.Mutex) error // WantReplyFunc encapsulates logic to check whether message awaits response or not // for example check for a given message attribute value. @@ -229,7 +232,7 @@ func DefaultErrorEncoder(context.Context, error, *sqs.Message, Client) { // DoNotExtendVisibilityTimeout is the default value for the consumer's visibilityTimeoutFunc. // It returns no error and does nothing -func DoNotExtendVisibilityTimeout(context.Context, Client, string, int64, *[]*sqs.Message) error { +func DoNotExtendVisibilityTimeout(context.Context, Client, string, int64, *[]*sqs.Message, *sync.Mutex) error { return nil } diff --git a/transport/awssqs/consumer_test.go b/transport/awssqs/consumer_test.go index 5e3d65161..05ff2b0cb 100644 --- a/transport/awssqs/consumer_test.go +++ b/transport/awssqs/consumer_test.go @@ -5,6 +5,7 @@ import ( "encoding/json" "errors" "fmt" + "sync" "testing" "time" @@ -444,7 +445,9 @@ func TestConsumerAfter(t *testing.T) { testReqDecoderfunc, awssqs.EncodeJSONResponse, queueURL, - awssqs.ConsumerAfter(func(ctx context.Context, msg *sqs.Message, resp *sqs.SendMessageInput, leftMsgs *[]*sqs.Message) context.Context { + awssqs.ConsumerAfter(func(ctx context.Context, msg *sqs.Message, resp *sqs.SendMessageInput, leftMsgs *[]*sqs.Message, mux *sync.Mutex) context.Context { + mux.Lock() + defer mux.Unlock() if correlationIDAttribute, exists := msg.MessageAttributes["correlationID"]; exists { if resp.MessageAttributes == nil { resp.MessageAttributes = make(map[string]*sqs.MessageAttributeValue) diff --git a/transport/awssqs/request_response_func.go b/transport/awssqs/request_response_func.go index 5e29d1ce5..cdc3030e7 100644 --- a/transport/awssqs/request_response_func.go +++ b/transport/awssqs/request_response_func.go @@ -2,6 +2,7 @@ package awssqs import ( "context" + "sync" "github.com/aws/aws-sdk-go/service/sqs" ) @@ -23,7 +24,7 @@ type PublisherRequestFunc func(context.Context, *sqs.SendMessageInput) context.C // consumers, after invoking the endpoint but prior to publishing a reply. // use cases eg. : Pipe information from request message to response MessageInput, // delete msg from queue or update leftMsgs slice -type ConsumerResponseFunc func(context.Context, *sqs.Message, *sqs.SendMessageInput, *[]*sqs.Message) context.Context +type ConsumerResponseFunc func(context.Context, *sqs.Message, *sqs.SendMessageInput, *[]*sqs.Message, *sync.Mutex) context.Context // PublisherResponseFunc may take information from an sqs.SendMessageOutput and // fetch response using the Client. SQS is not req-reply out-of-the-box. Responses need to be fetched. From b8b823fa4e470faddbdf21464dabf11171982152 Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Mon, 12 Oct 2020 09:59:41 +0200 Subject: [PATCH 05/18] add publisher's responseQueueURL as parameter to before funcion --- transport/awssqs/publisher.go | 2 +- transport/awssqs/publisher_test.go | 2 +- transport/awssqs/request_response_func.go | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/transport/awssqs/publisher.go b/transport/awssqs/publisher.go index 273f0fae1..3bb0aeb52 100644 --- a/transport/awssqs/publisher.go +++ b/transport/awssqs/publisher.go @@ -80,7 +80,7 @@ func (p Publisher) Endpoint() endpoint.Endpoint { } for _, f := range p.before { - ctx = f(ctx, &msgInput) + ctx = f(ctx, &msgInput, p.responseQueueURL) } output, err := p.sqsClient.SendMessageWithContext(ctx, &msgInput) diff --git a/transport/awssqs/publisher_test.go b/transport/awssqs/publisher_test.go index 0a438a8a9..c2a434eb8 100644 --- a/transport/awssqs/publisher_test.go +++ b/transport/awssqs/publisher_test.go @@ -332,7 +332,7 @@ func TestPublisherWithBefore(t *testing.T) { responseQueueURL, awssqs.EncodeJSONRequest, awssqs.NoResponseDecode, - awssqs.PublisherBefore(func(c context.Context, s *sqs.SendMessageInput) context.Context { + awssqs.PublisherBefore(func(c context.Context, s *sqs.SendMessageInput, _ string) context.Context { if s.MessageAttributes == nil { s.MessageAttributes = make(map[string]*sqs.MessageAttributeValue) } diff --git a/transport/awssqs/request_response_func.go b/transport/awssqs/request_response_func.go index cdc3030e7..a34c0d691 100644 --- a/transport/awssqs/request_response_func.go +++ b/transport/awssqs/request_response_func.go @@ -17,7 +17,7 @@ type ConsumerRequestFunc func(context.Context, *[]*sqs.Message) context.Context // request context, or add some informations to SendMessageInput. In Publishers, // RequestFuncs are executed prior to publishing the message but after encoding. // use cases eg. in Publisher : enforce some message attributes to SendMessageInput -type PublisherRequestFunc func(context.Context, *sqs.SendMessageInput) context.Context +type PublisherRequestFunc func(context.Context, *sqs.SendMessageInput, string) context.Context // ConsumerResponseFunc may take information from a request context and use it to // manipulate a Publisher. ConsumerResponseFunc are only executed in From 7546abbfe88a757b0c49c510ee5502bbfaf692db Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Mon, 12 Oct 2020 11:05:58 +0200 Subject: [PATCH 06/18] Add attribute in consumer indicating when to delete message --- transport/awssqs/consumer.go | 44 +++++++++++++++++++++ transport/awssqs/consumer_test.go | 61 ++++++++++++++++++++++++++++++ transport/awssqs/publisher_test.go | 1 + 3 files changed, 106 insertions(+) diff --git a/transport/awssqs/consumer.go b/transport/awssqs/consumer.go index a0e1040be..4b2637f2f 100644 --- a/transport/awssqs/consumer.go +++ b/transport/awssqs/consumer.go @@ -13,6 +13,19 @@ import ( "github.com/go-kit/kit/transport" ) +// Delete is a type to indicate when the consumed message should be deleted +type Delete int + +const ( + // BeforeHandle deletes the message before starting to handle it. + BeforeHandle Delete = iota + // AfterHandle deletes the message once it has been fully processed. + // This is the consumer's default value. + AfterHandle + // Never does not delete the message. + Never +) + // Consumer wraps an endpoint and provides a handler for sqs messages. type Consumer struct { sqsClient Client @@ -29,6 +42,7 @@ type Consumer struct { errorEncoder ErrorEncoder finalizer []ConsumerFinalizerFunc errorHandler transport.ErrorHandler + deleteMessage Delete } // NewConsumer constructs a new Consumer, which provides a Consume method @@ -53,6 +67,7 @@ func NewConsumer( leftMsgsMux: &sync.Mutex{}, errorEncoder: DefaultErrorEncoder, errorHandler: transport.NewLogErrorHandler(log.NewNopLogger()), + deleteMessage: AfterHandle, } for _, option := range options { option(s) @@ -117,6 +132,12 @@ func ConsumerFinalizer(f ...ConsumerFinalizerFunc) ConsumerOption { return func(c *Consumer) { c.finalizer = f } } +// ConsumerDeleteMessage overrides the default value for the consumer's +// deleteMessage field to indicate when the consumed messages should be deleted. +func ConsumerDeleteMessage(delete Delete) ConsumerOption { + return func(c *Consumer) { c.deleteMessage = delete } +} + // Consume calls ReceiveMessageWithContext and handles messages having an // sqs.ReceiveMessageInput as parameter allows each user to have his own receive configuration. // That said, this method overrides the queueURL for the provided ReceiveMessageInput to ensure @@ -158,9 +179,31 @@ func (c Consumer) HandleMessages(ctx context.Context, msgs []*sqs.Message) error } for _, msg := range msgs { + if c.deleteMessage == BeforeHandle { + if _, err := c.sqsClient.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ + QueueUrl: &c.queueURL, + ReceiptHandle: msg.ReceiptHandle, + }); err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + return err + } + } + if err := c.HandleSingleMessage(ctx, msg, &leftMsgs); err != nil { return err } + + if c.deleteMessage == AfterHandle { + if _, err := c.sqsClient.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ + QueueUrl: &c.queueURL, + ReceiptHandle: msg.ReceiptHandle, + }); err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + return err + } + } } return nil } @@ -258,4 +301,5 @@ type Client interface { SendMessageWithContext(ctx context.Context, input *sqs.SendMessageInput, opts ...request.Option) (*sqs.SendMessageOutput, error) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, opts ...request.Option) (*sqs.ReceiveMessageOutput, error) ChangeMessageVisibilityWithContext(ctx aws.Context, input *sqs.ChangeMessageVisibilityInput, opts ...request.Option) (*sqs.ChangeMessageVisibilityOutput, error) + DeleteMessageWithContext(ctx context.Context, input *sqs.DeleteMessageInput, opts ...request.Option) (*sqs.DeleteMessageOutput, error) } diff --git a/transport/awssqs/consumer_test.go b/transport/awssqs/consumer_test.go index 05ff2b0cb..0251cc700 100644 --- a/transport/awssqs/consumer_test.go +++ b/transport/awssqs/consumer_test.go @@ -32,6 +32,67 @@ func (mock *mockClient) ReceiveMessageWithContext(ctx context.Context, input *sq } } +func (mock *mockClient) DeleteMessageWithContext(ctx context.Context, input *sqs.DeleteMessageInput, opts ...request.Option) (*sqs.DeleteMessageOutput, error) { + return nil, mock.deleteError +} + +// TestConsumerDeleteBefore checks if deleteMessage is set properly using consumer options. +func TestConsumerDeleteBefore(t *testing.T) { + queueURL := "someURL" + mock := &mockClient{ + sendOutputChan: make(chan *sqs.SendMessageOutput), + receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), + deleteError: fmt.Errorf("delete err!"), + } + go func() { + mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ + Messages: []*sqs.Message{ + { + Body: aws.String("MessageBody"), + MessageId: aws.String("fakeMsgID"), + }, + }, + } + }() + errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient awssqs.Client) { + publishError := sqsError{ + Err: err.Error(), + MsgID: *req.MessageId, + } + payload, _ := json.Marshal(publishError) + + sqsClient.SendMessageWithContext(ctx, &sqs.SendMessageInput{ + MessageBody: aws.String(string(payload)), + }) + }) + consumer := awssqs.NewConsumer(mock, + func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, + func(context.Context, *sqs.Message) (interface{}, error) { return nil, errors.New("decode err!") }, + func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, + queueURL, + errEncoder, + awssqs.ConsumerDeleteMessage(awssqs.BeforeHandle), + ) + + consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) + + var receiveOutput *sqs.ReceiveMessageOutput + select { + case receiveOutput = <-mock.receiveOuputChan: + break + + case <-time.After(200 * time.Millisecond): + t.Fatal("Timed out waiting for publishing") + } + res, err := decodeConsumerError(receiveOutput) + if err != nil { + t.Fatal(err) + } + if want, have := "delete err!", res.Err; want != have { + t.Errorf("want %s, have %s", want, have) + } +} + // TestConsumerBadDecode checks if decoder errors are handled properly. func TestConsumerBadDecode(t *testing.T) { queueURL := "someURL" diff --git a/transport/awssqs/publisher_test.go b/transport/awssqs/publisher_test.go index c2a434eb8..0a829279b 100644 --- a/transport/awssqs/publisher_test.go +++ b/transport/awssqs/publisher_test.go @@ -37,6 +37,7 @@ type mockClient struct { sendOutputChan chan *sqs.SendMessageOutput receiveOuputChan chan *sqs.ReceiveMessageOutput sendMsgID string + deleteError error } func (mock *mockClient) SendMessageWithContext(ctx context.Context, input *sqs.SendMessageInput, opts ...request.Option) (*sqs.SendMessageOutput, error) { From f839b88b6886a03817420166f158687e68cadfd9 Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Mon, 12 Oct 2020 11:14:11 +0200 Subject: [PATCH 07/18] Use responseQueueURL in publisher's after function to consume response --- transport/awssqs/publisher.go | 2 +- transport/awssqs/publisher_test.go | 4 ++-- transport/awssqs/request_response_func.go | 2 +- 3 files changed, 4 insertions(+), 4 deletions(-) diff --git a/transport/awssqs/publisher.go b/transport/awssqs/publisher.go index 3bb0aeb52..9fba6c6fe 100644 --- a/transport/awssqs/publisher.go +++ b/transport/awssqs/publisher.go @@ -90,7 +90,7 @@ func (p Publisher) Endpoint() endpoint.Endpoint { var responseMsg *sqs.Message for _, f := range p.after { - ctx, responseMsg, err = f(ctx, p.sqsClient, output) + ctx, responseMsg, err = f(ctx, p.sqsClient, p.responseQueueURL, output) if err != nil { return nil, err } diff --git a/transport/awssqs/publisher_test.go b/transport/awssqs/publisher_test.go index 0a829279b..849b6b4f9 100644 --- a/transport/awssqs/publisher_test.go +++ b/transport/awssqs/publisher_test.go @@ -127,7 +127,7 @@ func TestBadDecode(t *testing.T) { func(context.Context, *sqs.Message) (response interface{}, err error) { return struct{}{}, errors.New("err!") }, - awssqs.PublisherAfter(func(ctx context.Context, client awssqs.Client, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { + awssqs.PublisherAfter(func(ctx context.Context, client awssqs.Client, responseQueueURL string, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { // Set the actual response for the request. return ctx, &sqs.Message{Body: aws.String("someMsgContent")}, nil }), @@ -233,7 +233,7 @@ func TestSuccessfulPublisher(t *testing.T) { err := json.Unmarshal([]byte(*msg.Body), &response) return response, err }, - awssqs.PublisherAfter(func(ctx context.Context, client awssqs.Client, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { + awssqs.PublisherAfter(func(ctx context.Context, client awssqs.Client, responseQueueURL string, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { // Sets the actual response for the request. if *msg.MessageId == "someMsgID" { return ctx, &sqs.Message{Body: aws.String(string(b))}, nil diff --git a/transport/awssqs/request_response_func.go b/transport/awssqs/request_response_func.go index a34c0d691..e78d85b54 100644 --- a/transport/awssqs/request_response_func.go +++ b/transport/awssqs/request_response_func.go @@ -30,4 +30,4 @@ type ConsumerResponseFunc func(context.Context, *sqs.Message, *sqs.SendMessageIn // fetch response using the Client. SQS is not req-reply out-of-the-box. Responses need to be fetched. // PublisherResponseFunc are only executed in publishers, after a request has been made, // but prior to its response being decoded. So this is the perfect place to fetch actual response. -type PublisherResponseFunc func(context.Context, Client, *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) +type PublisherResponseFunc func(context.Context, Client, string, *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) From ec32deb0ff99c060b28653d464f75ca73bf77c60 Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Mon, 12 Oct 2020 16:25:09 +0200 Subject: [PATCH 08/18] Use SQS' official API interface instead of homemade one --- transport/awssqs/consumer.go | 23 +++++++---------------- transport/awssqs/consumer_test.go | 9 +++++---- transport/awssqs/publisher.go | 5 +++-- transport/awssqs/publisher_test.go | 6 ++++-- transport/awssqs/request_response_func.go | 3 ++- 5 files changed, 21 insertions(+), 25 deletions(-) diff --git a/transport/awssqs/consumer.go b/transport/awssqs/consumer.go index 4b2637f2f..8d5548c2f 100644 --- a/transport/awssqs/consumer.go +++ b/transport/awssqs/consumer.go @@ -6,8 +6,8 @@ import ( "sync" "github.com/aws/aws-sdk-go/aws" - "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" "github.com/go-kit/kit/endpoint" "github.com/go-kit/kit/log" "github.com/go-kit/kit/transport" @@ -28,7 +28,7 @@ const ( // Consumer wraps an endpoint and provides a handler for sqs messages. type Consumer struct { - sqsClient Client + sqsClient sqsiface.SQSAPI e endpoint.Endpoint dec DecodeRequestFunc enc EncodeResponseFunc @@ -48,7 +48,7 @@ type Consumer struct { // NewConsumer constructs a new Consumer, which provides a Consume method // and message handlers that wrap the provided endpoint. func NewConsumer( - sqsClient Client, + sqsClient sqsiface.SQSAPI, e endpoint.Endpoint, dec DecodeRequestFunc, enc EncodeResponseFunc, @@ -251,7 +251,7 @@ func (c Consumer) HandleSingleMessage(ctx context.Context, msg *sqs.Message, lef // Users are encouraged to use custom ErrorEncoders to encode errors to // their replies, and will likely want to pass and check for their own error // types. -type ErrorEncoder func(ctx context.Context, err error, req *sqs.Message, sqsClient Client) +type ErrorEncoder func(ctx context.Context, err error, req *sqs.Message, sqsClient sqsiface.SQSAPI) // ConsumerFinalizerFunc can be used to perform work at the end of a request // from a publisher, after the response has been written to the publisher. The @@ -263,19 +263,19 @@ type ConsumerFinalizerFunc func(ctx context.Context, msg *[]*sqs.Message) // this can be used to provide custom visibility timeout extension such as doubling it everytime // it gets close to being reached. // VisibilityTimeoutFunc will need to check that the provided context is not done and return once it is. -type VisibilityTimeoutFunc func(context.Context, Client, string, int64, *[]*sqs.Message, *sync.Mutex) error +type VisibilityTimeoutFunc func(context.Context, sqsiface.SQSAPI, string, int64, *[]*sqs.Message, *sync.Mutex) error // WantReplyFunc encapsulates logic to check whether message awaits response or not // for example check for a given message attribute value. type WantReplyFunc func(context.Context, *sqs.Message) bool // DefaultErrorEncoder simply ignores the message. It does not reply. -func DefaultErrorEncoder(context.Context, error, *sqs.Message, Client) { +func DefaultErrorEncoder(context.Context, error, *sqs.Message, sqsiface.SQSAPI) { } // DoNotExtendVisibilityTimeout is the default value for the consumer's visibilityTimeoutFunc. // It returns no error and does nothing -func DoNotExtendVisibilityTimeout(context.Context, Client, string, int64, *[]*sqs.Message, *sync.Mutex) error { +func DoNotExtendVisibilityTimeout(context.Context, sqsiface.SQSAPI, string, int64, *[]*sqs.Message, *sync.Mutex) error { return nil } @@ -294,12 +294,3 @@ func EncodeJSONResponse(_ context.Context, input *sqs.SendMessageInput, response input.MessageBody = aws.String(string(payload)) return nil } - -// Client is consumer contract for the Producer and Consumer. -// It models methods of the AWS *sqs.SQS type. -type Client interface { - SendMessageWithContext(ctx context.Context, input *sqs.SendMessageInput, opts ...request.Option) (*sqs.SendMessageOutput, error) - ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, opts ...request.Option) (*sqs.ReceiveMessageOutput, error) - ChangeMessageVisibilityWithContext(ctx aws.Context, input *sqs.ChangeMessageVisibilityInput, opts ...request.Option) (*sqs.ChangeMessageVisibilityOutput, error) - DeleteMessageWithContext(ctx context.Context, input *sqs.DeleteMessageInput, opts ...request.Option) (*sqs.DeleteMessageOutput, error) -} diff --git a/transport/awssqs/consumer_test.go b/transport/awssqs/consumer_test.go index 0251cc700..00bf1c4ef 100644 --- a/transport/awssqs/consumer_test.go +++ b/transport/awssqs/consumer_test.go @@ -12,6 +12,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" "github.com/go-kit/kit/transport/awssqs" "github.com/pborman/uuid" ) @@ -54,7 +55,7 @@ func TestConsumerDeleteBefore(t *testing.T) { }, } }() - errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient awssqs.Client) { + errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient sqsiface.SQSAPI) { publishError := sqsError{ Err: err.Error(), MsgID: *req.MessageId, @@ -110,7 +111,7 @@ func TestConsumerBadDecode(t *testing.T) { }, } }() - errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient awssqs.Client) { + errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient sqsiface.SQSAPI) { publishError := sqsError{ Err: err.Error(), MsgID: *req.MessageId, @@ -166,7 +167,7 @@ func TestConsumerBadEndpoint(t *testing.T) { }, } }() - errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient awssqs.Client) { + errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient sqsiface.SQSAPI) { publishError := sqsError{ Err: err.Error(), MsgID: *req.MessageId, @@ -222,7 +223,7 @@ func TestConsumerBadEncoder(t *testing.T) { }, } }() - errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient awssqs.Client) { + errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient sqsiface.SQSAPI) { publishError := sqsError{ Err: err.Error(), MsgID: *req.MessageId, diff --git a/transport/awssqs/publisher.go b/transport/awssqs/publisher.go index 9fba6c6fe..b08efeed3 100644 --- a/transport/awssqs/publisher.go +++ b/transport/awssqs/publisher.go @@ -7,13 +7,14 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" "github.com/go-kit/kit/endpoint" ) // Publisher wraps an sqs client and queue, and provides a method that // implements endpoint.Endpoint. type Publisher struct { - sqsClient Client + sqsClient sqsiface.SQSAPI queueURL string responseQueueURL string enc EncodeRequestFunc @@ -25,7 +26,7 @@ type Publisher struct { // NewPublisher constructs a usable Publisher for a single remote method. func NewPublisher( - sqsClient Client, + sqsClient sqsiface.SQSAPI, queueURL string, responseQueueURL string, enc EncodeRequestFunc, diff --git a/transport/awssqs/publisher_test.go b/transport/awssqs/publisher_test.go index 849b6b4f9..308a1e388 100644 --- a/transport/awssqs/publisher_test.go +++ b/transport/awssqs/publisher_test.go @@ -11,6 +11,7 @@ import ( "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/aws/request" "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" "github.com/go-kit/kit/transport/awssqs" ) @@ -33,6 +34,7 @@ var names = map[int]string{ // mockClient is a mock of *sqs.SQS. type mockClient struct { + sqsiface.SQSAPI err error sendOutputChan chan *sqs.SendMessageOutput receiveOuputChan chan *sqs.ReceiveMessageOutput @@ -127,7 +129,7 @@ func TestBadDecode(t *testing.T) { func(context.Context, *sqs.Message) (response interface{}, err error) { return struct{}{}, errors.New("err!") }, - awssqs.PublisherAfter(func(ctx context.Context, client awssqs.Client, responseQueueURL string, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { + awssqs.PublisherAfter(func(ctx context.Context, _ sqsiface.SQSAPI, responseQueueURL string, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { // Set the actual response for the request. return ctx, &sqs.Message{Body: aws.String("someMsgContent")}, nil }), @@ -233,7 +235,7 @@ func TestSuccessfulPublisher(t *testing.T) { err := json.Unmarshal([]byte(*msg.Body), &response) return response, err }, - awssqs.PublisherAfter(func(ctx context.Context, client awssqs.Client, responseQueueURL string, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { + awssqs.PublisherAfter(func(ctx context.Context, _ sqsiface.SQSAPI, responseQueueURL string, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { // Sets the actual response for the request. if *msg.MessageId == "someMsgID" { return ctx, &sqs.Message{Body: aws.String(string(b))}, nil diff --git a/transport/awssqs/request_response_func.go b/transport/awssqs/request_response_func.go index e78d85b54..502ddf2d6 100644 --- a/transport/awssqs/request_response_func.go +++ b/transport/awssqs/request_response_func.go @@ -5,6 +5,7 @@ import ( "sync" "github.com/aws/aws-sdk-go/service/sqs" + "github.com/aws/aws-sdk-go/service/sqs/sqsiface" ) // ConsumerRequestFunc may take information from a consumer request result and @@ -30,4 +31,4 @@ type ConsumerResponseFunc func(context.Context, *sqs.Message, *sqs.SendMessageIn // fetch response using the Client. SQS is not req-reply out-of-the-box. Responses need to be fetched. // PublisherResponseFunc are only executed in publishers, after a request has been made, // but prior to its response being decoded. So this is the perfect place to fetch actual response. -type PublisherResponseFunc func(context.Context, Client, string, *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) +type PublisherResponseFunc func(context.Context, sqsiface.SQSAPI, string, *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) From a102c1945ad82f371001dcf49572a2e491b3a04f Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Mon, 12 Oct 2020 16:36:11 +0200 Subject: [PATCH 09/18] Fix typo in doc.go --- transport/awssqs/doc.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport/awssqs/doc.go b/transport/awssqs/doc.go index b6b5355d5..779c77d69 100644 --- a/transport/awssqs/doc.go +++ b/transport/awssqs/doc.go @@ -1,2 +1,2 @@ -// Package awssqs implements an AWS Simple Query Service transport. +// Package awssqs implements an AWS Simple Queue Service transport. package awssqs From 35587632fa2f11736c793af802ddd306131cf92e Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Mon, 12 Oct 2020 17:03:17 +0200 Subject: [PATCH 10/18] make awssqs.HandleMessages and awssqs.HandleSingleMessage private --- transport/awssqs/consumer.go | 12 ++++++------ 1 file changed, 6 insertions(+), 6 deletions(-) diff --git a/transport/awssqs/consumer.go b/transport/awssqs/consumer.go index 8d5548c2f..d7fc2dbe0 100644 --- a/transport/awssqs/consumer.go +++ b/transport/awssqs/consumer.go @@ -148,11 +148,11 @@ func (c Consumer) Consume(ctx context.Context, receiveMsgInput *sqs.ReceiveMessa if err != nil { return err } - return c.HandleMessages(ctx, out.Messages) + return c.handleMessages(ctx, out.Messages) } -// HandleMessages handles the consumed messages. -func (c Consumer) HandleMessages(ctx context.Context, msgs []*sqs.Message) error { +// handleMessages handles the consumed messages. +func (c Consumer) handleMessages(ctx context.Context, msgs []*sqs.Message) error { ctx, cancel := context.WithCancel(ctx) defer cancel() @@ -190,7 +190,7 @@ func (c Consumer) HandleMessages(ctx context.Context, msgs []*sqs.Message) error } } - if err := c.HandleSingleMessage(ctx, msg, &leftMsgs); err != nil { + if err := c.handleSingleMessage(ctx, msg, &leftMsgs); err != nil { return err } @@ -208,8 +208,8 @@ func (c Consumer) HandleMessages(ctx context.Context, msgs []*sqs.Message) error return nil } -// HandleSingleMessage handles a single sqs message. -func (c Consumer) HandleSingleMessage(ctx context.Context, msg *sqs.Message, leftMsgs *[]*sqs.Message) error { +// handleSingleMessage handles a single sqs message. +func (c Consumer) handleSingleMessage(ctx context.Context, msg *sqs.Message, leftMsgs *[]*sqs.Message) error { req, err := c.dec(ctx, msg) if err != nil { c.errorHandler.Handle(ctx, err) From e1892b5e5f53d664365d9553d060e8881c922d96 Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Mon, 12 Oct 2020 17:40:46 +0200 Subject: [PATCH 11/18] Replace Publisher by Producer --- transport/awssqs/consumer.go | 4 +- transport/awssqs/encode_decode.go | 4 +- .../awssqs/{publisher.go => producer.go} | 42 +++++++++---------- .../{publisher_test.go => producer_test.go} | 36 ++++++++-------- transport/awssqs/request_response_func.go | 16 +++---- 5 files changed, 51 insertions(+), 51 deletions(-) rename transport/awssqs/{publisher.go => producer.go} (69%) rename transport/awssqs/{publisher_test.go => producer_test.go} (88%) diff --git a/transport/awssqs/consumer.go b/transport/awssqs/consumer.go index d7fc2dbe0..91e93dd21 100644 --- a/transport/awssqs/consumer.go +++ b/transport/awssqs/consumer.go @@ -78,7 +78,7 @@ func NewConsumer( // ConsumerOption sets an optional parameter for consumers. type ConsumerOption func(*Consumer) -// ConsumerBefore functions are executed on the publisher request object before the +// ConsumerBefore functions are executed on the producer request object before the // request is decoded. func ConsumerBefore(before ...ConsumerRequestFunc) ConsumerOption { return func(c *Consumer) { c.before = append(c.before, before...) } @@ -254,7 +254,7 @@ func (c Consumer) handleSingleMessage(ctx context.Context, msg *sqs.Message, lef type ErrorEncoder func(ctx context.Context, err error, req *sqs.Message, sqsClient sqsiface.SQSAPI) // ConsumerFinalizerFunc can be used to perform work at the end of a request -// from a publisher, after the response has been written to the publisher. The +// from a producer, after the response has been written to the producer. The // principal intended use is for request logging. // Can also be used to delete messages once fully proccessed. type ConsumerFinalizerFunc func(ctx context.Context, msg *[]*sqs.Message) diff --git a/transport/awssqs/encode_decode.go b/transport/awssqs/encode_decode.go index 654700cff..1777d690a 100644 --- a/transport/awssqs/encode_decode.go +++ b/transport/awssqs/encode_decode.go @@ -11,7 +11,7 @@ import ( type DecodeRequestFunc func(context.Context, *sqs.Message) (request interface{}, err error) // EncodeRequestFunc encodes the passed payload object into -// an sqs message object. It is designed to be used in Publishers. +// an sqs message object. It is designed to be used in Producers. type EncodeRequestFunc func(context.Context, *sqs.SendMessageInput, interface{}) error // EncodeResponseFunc encodes the passed response object to @@ -19,5 +19,5 @@ type EncodeRequestFunc func(context.Context, *sqs.SendMessageInput, interface{}) type EncodeResponseFunc func(context.Context, *sqs.SendMessageInput, interface{}) error // DecodeResponseFunc extracts a user-domain response object from -// an sqs message object. It is designed to be used in Publishers. +// an sqs message object. It is designed to be used in Producers. type DecodeResponseFunc func(context.Context, *sqs.Message) (response interface{}, err error) diff --git a/transport/awssqs/publisher.go b/transport/awssqs/producer.go similarity index 69% rename from transport/awssqs/publisher.go rename to transport/awssqs/producer.go index b08efeed3..4bf12958e 100644 --- a/transport/awssqs/publisher.go +++ b/transport/awssqs/producer.go @@ -11,29 +11,29 @@ import ( "github.com/go-kit/kit/endpoint" ) -// Publisher wraps an sqs client and queue, and provides a method that +// Producer wraps an sqs client and queue, and provides a method that // implements endpoint.Endpoint. -type Publisher struct { +type Producer struct { sqsClient sqsiface.SQSAPI queueURL string responseQueueURL string enc EncodeRequestFunc dec DecodeResponseFunc - before []PublisherRequestFunc - after []PublisherResponseFunc + before []ProducerRequestFunc + after []ProducerResponseFunc timeout time.Duration } -// NewPublisher constructs a usable Publisher for a single remote method. -func NewPublisher( +// NewProducer constructs a usable Producer for a single remote method. +func NewProducer( sqsClient sqsiface.SQSAPI, queueURL string, responseQueueURL string, enc EncodeRequestFunc, dec DecodeResponseFunc, - options ...PublisherOption, -) *Publisher { - p := &Publisher{ + options ...ProducerOption, +) *Producer { + p := &Producer{ sqsClient: sqsClient, queueURL: queueURL, responseQueueURL: responseQueueURL, @@ -47,29 +47,29 @@ func NewPublisher( return p } -// PublisherOption sets an optional parameter for clients. -type PublisherOption func(*Publisher) +// ProducerOption sets an optional parameter for clients. +type ProducerOption func(*Producer) -// PublisherBefore sets the RequestFuncs that are applied to the outgoing sqs +// ProducerBefore sets the RequestFuncs that are applied to the outgoing sqs // request before it's invoked. -func PublisherBefore(before ...PublisherRequestFunc) PublisherOption { - return func(p *Publisher) { p.before = append(p.before, before...) } +func ProducerBefore(before ...ProducerRequestFunc) ProducerOption { + return func(p *Producer) { p.before = append(p.before, before...) } } -// PublisherAfter sets the ClientResponseFuncs applied to the incoming sqs +// ProducerAfter sets the ClientResponseFuncs applied to the incoming sqs // request prior to it being decoded. This is useful for obtaining anything off // of the response and adding onto the context prior to decoding. -func PublisherAfter(after ...PublisherResponseFunc) PublisherOption { - return func(p *Publisher) { p.after = append(p.after, after...) } +func ProducerAfter(after ...ProducerResponseFunc) ProducerOption { + return func(p *Producer) { p.after = append(p.after, after...) } } -// PublisherTimeout sets the available timeout for an sqs request. -func PublisherTimeout(timeout time.Duration) PublisherOption { - return func(p *Publisher) { p.timeout = timeout } +// ProducerTimeout sets the available timeout for an sqs request. +func ProducerTimeout(timeout time.Duration) ProducerOption { + return func(p *Producer) { p.timeout = timeout } } // Endpoint returns a usable endpoint that invokes the remote endpoint. -func (p Publisher) Endpoint() endpoint.Endpoint { +func (p Producer) Endpoint() endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { ctx, cancel := context.WithTimeout(ctx, p.timeout) defer cancel() diff --git a/transport/awssqs/publisher_test.go b/transport/awssqs/producer_test.go similarity index 88% rename from transport/awssqs/publisher_test.go rename to transport/awssqs/producer_test.go index 308a1e388..3efbaab8f 100644 --- a/transport/awssqs/publisher_test.go +++ b/transport/awssqs/producer_test.go @@ -79,7 +79,7 @@ func TestBadEncode(t *testing.T) { mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), } - pub := awssqs.NewPublisher( + pub := awssqs.NewProducer( mock, queueURL, responseQueueURL, @@ -121,7 +121,7 @@ func TestBadDecode(t *testing.T) { queueURL := "someURL" responseQueueURL := "someOtherURL" - pub := awssqs.NewPublisher( + pub := awssqs.NewProducer( mock, queueURL, responseQueueURL, @@ -129,7 +129,7 @@ func TestBadDecode(t *testing.T) { func(context.Context, *sqs.Message) (response interface{}, err error) { return struct{}{}, errors.New("err!") }, - awssqs.PublisherAfter(func(ctx context.Context, _ sqsiface.SQSAPI, responseQueueURL string, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { + awssqs.ProducerAfter(func(ctx context.Context, _ sqsiface.SQSAPI, responseQueueURL string, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { // Set the actual response for the request. return ctx, &sqs.Message{Body: aws.String("someMsgContent")}, nil }), @@ -158,15 +158,15 @@ func TestBadDecode(t *testing.T) { } } -// TestPublisherTimeout ensures that the publisher timeout mechanism works. -func TestPublisherTimeout(t *testing.T) { +// TestProducerTimeout ensures that the producer timeout mechanism works. +func TestProducerTimeout(t *testing.T) { sendOutputChan := make(chan *sqs.SendMessageOutput) mock := &mockClient{ sendOutputChan: sendOutputChan, } queueURL := "someURL" responseQueueURL := "someOtherURL" - pub := awssqs.NewPublisher( + pub := awssqs.NewProducer( mock, queueURL, responseQueueURL, @@ -174,7 +174,7 @@ func TestPublisherTimeout(t *testing.T) { func(context.Context, *sqs.Message) (response interface{}, err error) { return struct{}{}, nil }, - awssqs.PublisherTimeout(50*time.Millisecond), + awssqs.ProducerTimeout(50*time.Millisecond), ) var err error @@ -202,8 +202,8 @@ func TestPublisherTimeout(t *testing.T) { } } -// TestSuccessfulPublisher ensures that the publisher mechanisms work. -func TestSuccessfulPublisher(t *testing.T) { +// TestSuccessfulProducer ensures that the producer mechanisms work. +func TestSuccessfulProducer(t *testing.T) { mockReq := testReq{437} mockRes := testRes{ Squadron: mockReq.Squadron, @@ -225,7 +225,7 @@ func TestSuccessfulPublisher(t *testing.T) { queueURL := "someURL" responseQueueURL := "someOtherURL" - pub := awssqs.NewPublisher( + pub := awssqs.NewProducer( mock, queueURL, responseQueueURL, @@ -235,7 +235,7 @@ func TestSuccessfulPublisher(t *testing.T) { err := json.Unmarshal([]byte(*msg.Body), &response) return response, err }, - awssqs.PublisherAfter(func(ctx context.Context, _ sqsiface.SQSAPI, responseQueueURL string, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { + awssqs.ProducerAfter(func(ctx context.Context, _ sqsiface.SQSAPI, responseQueueURL string, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { // Sets the actual response for the request. if *msg.MessageId == "someMsgID" { return ctx, &sqs.Message{Body: aws.String(string(b))}, nil @@ -279,8 +279,8 @@ func TestSuccessfulPublisher(t *testing.T) { } } -// TestSuccessfulPublisherNoResponse ensures that the publisher response mechanism works. -func TestSuccessfulPublisherNoResponse(t *testing.T) { +// TestSuccessfulProducerNoResponse ensures that the producer response mechanism works. +func TestSuccessfulProducerNoResponse(t *testing.T) { mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), @@ -289,7 +289,7 @@ func TestSuccessfulPublisherNoResponse(t *testing.T) { queueURL := "someURL" responseQueueURL := "someOtherURL" - pub := awssqs.NewPublisher( + pub := awssqs.NewProducer( mock, queueURL, responseQueueURL, @@ -318,9 +318,9 @@ func TestSuccessfulPublisherNoResponse(t *testing.T) { } } -// TestPublisherWithBefore adds a PublisherBefore function that adds a message attribute. +// TestProducerWithBefore adds a ProducerBefore function that adds a message attribute. // This test ensures that the the before functions work as expected. -func TestPublisherWithBefore(t *testing.T) { +func TestProducerWithBefore(t *testing.T) { mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), @@ -329,13 +329,13 @@ func TestPublisherWithBefore(t *testing.T) { queueURL := "someURL" responseQueueURL := "someOtherURL" - pub := awssqs.NewPublisher( + pub := awssqs.NewProducer( mock, queueURL, responseQueueURL, awssqs.EncodeJSONRequest, awssqs.NoResponseDecode, - awssqs.PublisherBefore(func(c context.Context, s *sqs.SendMessageInput, _ string) context.Context { + awssqs.ProducerBefore(func(c context.Context, s *sqs.SendMessageInput, _ string) context.Context { if s.MessageAttributes == nil { s.MessageAttributes = make(map[string]*sqs.MessageAttributeValue) } diff --git a/transport/awssqs/request_response_func.go b/transport/awssqs/request_response_func.go index 502ddf2d6..e074040b5 100644 --- a/transport/awssqs/request_response_func.go +++ b/transport/awssqs/request_response_func.go @@ -14,21 +14,21 @@ import ( // use cases eg. in Consumer : extract message into context, or filter received messages. type ConsumerRequestFunc func(context.Context, *[]*sqs.Message) context.Context -// PublisherRequestFunc may take information from a publisher request and put it into a -// request context, or add some informations to SendMessageInput. In Publishers, +// ProducerRequestFunc may take information from a producer request and put it into a +// request context, or add some informations to SendMessageInput. In Producers, // RequestFuncs are executed prior to publishing the message but after encoding. -// use cases eg. in Publisher : enforce some message attributes to SendMessageInput -type PublisherRequestFunc func(context.Context, *sqs.SendMessageInput, string) context.Context +// use cases eg. in Producer : enforce some message attributes to SendMessageInput +type ProducerRequestFunc func(context.Context, *sqs.SendMessageInput, string) context.Context // ConsumerResponseFunc may take information from a request context and use it to -// manipulate a Publisher. ConsumerResponseFunc are only executed in +// manipulate a Producer. ConsumerResponseFunc are only executed in // consumers, after invoking the endpoint but prior to publishing a reply. // use cases eg. : Pipe information from request message to response MessageInput, // delete msg from queue or update leftMsgs slice type ConsumerResponseFunc func(context.Context, *sqs.Message, *sqs.SendMessageInput, *[]*sqs.Message, *sync.Mutex) context.Context -// PublisherResponseFunc may take information from an sqs.SendMessageOutput and +// ProducerResponseFunc may take information from an sqs.SendMessageOutput and // fetch response using the Client. SQS is not req-reply out-of-the-box. Responses need to be fetched. -// PublisherResponseFunc are only executed in publishers, after a request has been made, +// ProducerResponseFunc are only executed in producers, after a request has been made, // but prior to its response being decoded. So this is the perfect place to fetch actual response. -type PublisherResponseFunc func(context.Context, sqsiface.SQSAPI, string, *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) +type ProducerResponseFunc func(context.Context, sqsiface.SQSAPI, string, *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) From 3abc4882e8b5b6adad05e9b4626d57bff4bb7fd4 Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Mon, 12 Oct 2020 17:58:41 +0200 Subject: [PATCH 12/18] Replace sqs by SQS in comments --- transport/awssqs/consumer.go | 8 ++++---- transport/awssqs/consumer_test.go | 10 +++++----- transport/awssqs/encode_decode.go | 8 ++++---- transport/awssqs/producer.go | 12 ++++++------ transport/awssqs/request_response_func.go | 4 ++-- 5 files changed, 21 insertions(+), 21 deletions(-) diff --git a/transport/awssqs/consumer.go b/transport/awssqs/consumer.go index 91e93dd21..0fb96ff5d 100644 --- a/transport/awssqs/consumer.go +++ b/transport/awssqs/consumer.go @@ -13,7 +13,7 @@ import ( "github.com/go-kit/kit/transport" ) -// Delete is a type to indicate when the consumed message should be deleted +// Delete is a type to indicate when the consumed message should be deleted. type Delete int const ( @@ -26,7 +26,7 @@ const ( Never ) -// Consumer wraps an endpoint and provides a handler for sqs messages. +// Consumer wraps an endpoint and provides a handler for SQS messages. type Consumer struct { sqsClient sqsiface.SQSAPI e endpoint.Endpoint @@ -208,7 +208,7 @@ func (c Consumer) handleMessages(ctx context.Context, msgs []*sqs.Message) error return nil } -// handleSingleMessage handles a single sqs message. +// handleSingleMessage handles a single SQS message. func (c Consumer) handleSingleMessage(ctx context.Context, msg *sqs.Message, leftMsgs *[]*sqs.Message) error { req, err := c.dec(ctx, msg) if err != nil { @@ -274,7 +274,7 @@ func DefaultErrorEncoder(context.Context, error, *sqs.Message, sqsiface.SQSAPI) } // DoNotExtendVisibilityTimeout is the default value for the consumer's visibilityTimeoutFunc. -// It returns no error and does nothing +// It returns no error and does nothing. func DoNotExtendVisibilityTimeout(context.Context, sqsiface.SQSAPI, string, int64, *[]*sqs.Message, *sync.Mutex) error { return nil } diff --git a/transport/awssqs/consumer_test.go b/transport/awssqs/consumer_test.go index 00bf1c4ef..2527e76f7 100644 --- a/transport/awssqs/consumer_test.go +++ b/transport/awssqs/consumer_test.go @@ -358,13 +358,13 @@ func TestConsumerSuccessNoReply(t *testing.T) { return case <-time.After(200 * time.Millisecond): - // As expected, we did not receive any response from consumer + // As expected, we did not receive any response from consumer. return } } // TestConsumerBeforeFilterMessages checks if consumer before is called as expected. -// Here before is used to filter messages before processing +// Here before is used to filter messages before processing. func TestConsumerBeforeFilterMessages(t *testing.T) { obj1 := testReq{ Squadron: 436, @@ -423,7 +423,7 @@ func TestConsumerBeforeFilterMessages(t *testing.T) { awssqs.EncodeJSONResponse, queueURL, awssqs.ConsumerBefore(func(ctx context.Context, msgs *[]*sqs.Message) context.Context { - // delete a message that is not destined to the consumer + // Filter a message that is not destined to the consumer. msgsCopy := *msgs for index, msg := range *msgs { if recipient, exists := msg.MessageAttributes["recipient"]; !exists || *recipient.StringValue != "me" { @@ -457,14 +457,14 @@ func TestConsumerBeforeFilterMessages(t *testing.T) { if have := res; want != have { t.Errorf("want %v, have %v", want, have) } - // Try fetching responses again + // Try fetching responses again. select { case receiveOutput = <-mock.receiveOuputChan: t.Errorf("received second output when only one was expected, have %v", receiveOutput) return case <-time.After(200 * time.Millisecond): - // As expected, we did not receive a second response from consumer + // As expected, we did not receive a second response from consumer. return } } diff --git a/transport/awssqs/encode_decode.go b/transport/awssqs/encode_decode.go index 1777d690a..0f6b0f3de 100644 --- a/transport/awssqs/encode_decode.go +++ b/transport/awssqs/encode_decode.go @@ -7,17 +7,17 @@ import ( ) // DecodeRequestFunc extracts a user-domain request object from -// an sqs message object. It is designed to be used in Consumers. +// an SQS message object. It is designed to be used in Consumers. type DecodeRequestFunc func(context.Context, *sqs.Message) (request interface{}, err error) // EncodeRequestFunc encodes the passed payload object into -// an sqs message object. It is designed to be used in Producers. +// an SQS message object. It is designed to be used in Producers. type EncodeRequestFunc func(context.Context, *sqs.SendMessageInput, interface{}) error // EncodeResponseFunc encodes the passed response object to -// an sqs message object. It is designed to be used in Consumers. +// an SQS message object. It is designed to be used in Consumers. type EncodeResponseFunc func(context.Context, *sqs.SendMessageInput, interface{}) error // DecodeResponseFunc extracts a user-domain response object from -// an sqs message object. It is designed to be used in Producers. +// an SQS message object. It is designed to be used in Producers. type DecodeResponseFunc func(context.Context, *sqs.Message) (response interface{}, err error) diff --git a/transport/awssqs/producer.go b/transport/awssqs/producer.go index 4bf12958e..89df3cd13 100644 --- a/transport/awssqs/producer.go +++ b/transport/awssqs/producer.go @@ -11,7 +11,7 @@ import ( "github.com/go-kit/kit/endpoint" ) -// Producer wraps an sqs client and queue, and provides a method that +// Producer wraps an SQS client and queue, and provides a method that // implements endpoint.Endpoint. type Producer struct { sqsClient sqsiface.SQSAPI @@ -50,20 +50,20 @@ func NewProducer( // ProducerOption sets an optional parameter for clients. type ProducerOption func(*Producer) -// ProducerBefore sets the RequestFuncs that are applied to the outgoing sqs +// ProducerBefore sets the RequestFuncs that are applied to the outgoing SQS // request before it's invoked. func ProducerBefore(before ...ProducerRequestFunc) ProducerOption { return func(p *Producer) { p.before = append(p.before, before...) } } -// ProducerAfter sets the ClientResponseFuncs applied to the incoming sqs -// request prior to it being decoded. This is useful for obtaining anything off -// of the response and adding onto the context prior to decoding. +// ProducerAfter sets the ClientResponseFuncs applied to the incoming SQS +// request prior to it being decoded. This is useful for obtaining the response +// and adding any information onto the context prior to decoding. func ProducerAfter(after ...ProducerResponseFunc) ProducerOption { return func(p *Producer) { p.after = append(p.after, after...) } } -// ProducerTimeout sets the available timeout for an sqs request. +// ProducerTimeout sets the available timeout for an SQS request. func ProducerTimeout(timeout time.Duration) ProducerOption { return func(p *Producer) { p.timeout = timeout } } diff --git a/transport/awssqs/request_response_func.go b/transport/awssqs/request_response_func.go index e074040b5..ea8852542 100644 --- a/transport/awssqs/request_response_func.go +++ b/transport/awssqs/request_response_func.go @@ -17,14 +17,14 @@ type ConsumerRequestFunc func(context.Context, *[]*sqs.Message) context.Context // ProducerRequestFunc may take information from a producer request and put it into a // request context, or add some informations to SendMessageInput. In Producers, // RequestFuncs are executed prior to publishing the message but after encoding. -// use cases eg. in Producer : enforce some message attributes to SendMessageInput +// use cases eg. in Producer : enforce some message attributes to SendMessageInput. type ProducerRequestFunc func(context.Context, *sqs.SendMessageInput, string) context.Context // ConsumerResponseFunc may take information from a request context and use it to // manipulate a Producer. ConsumerResponseFunc are only executed in // consumers, after invoking the endpoint but prior to publishing a reply. // use cases eg. : Pipe information from request message to response MessageInput, -// delete msg from queue or update leftMsgs slice +// delete msg from queue or update leftMsgs slice. type ConsumerResponseFunc func(context.Context, *sqs.Message, *sqs.SendMessageInput, *[]*sqs.Message, *sync.Mutex) context.Context // ProducerResponseFunc may take information from an sqs.SendMessageOutput and From a340fef1663cfb35ff9b357e350208764e5d262c Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Tue, 13 Oct 2020 09:46:42 +0200 Subject: [PATCH 13/18] Delete responseQueueURL from producer --- transport/awssqs/producer.go | 46 ++++++++++++++--------- transport/awssqs/producer_test.go | 25 +++++------- transport/awssqs/request_response_func.go | 4 +- 3 files changed, 40 insertions(+), 35 deletions(-) diff --git a/transport/awssqs/producer.go b/transport/awssqs/producer.go index 89df3cd13..d505b9231 100644 --- a/transport/awssqs/producer.go +++ b/transport/awssqs/producer.go @@ -11,35 +11,40 @@ import ( "github.com/go-kit/kit/endpoint" ) +type contextKey int + +const ( + // ContextKeyResponseQueueURL is the context key that allows fetching + // the response queue URL from context + ContextKeyResponseQueueURL contextKey = iota +) + // Producer wraps an SQS client and queue, and provides a method that // implements endpoint.Endpoint. type Producer struct { - sqsClient sqsiface.SQSAPI - queueURL string - responseQueueURL string - enc EncodeRequestFunc - dec DecodeResponseFunc - before []ProducerRequestFunc - after []ProducerResponseFunc - timeout time.Duration + sqsClient sqsiface.SQSAPI + queueURL string + enc EncodeRequestFunc + dec DecodeResponseFunc + before []ProducerRequestFunc + after []ProducerResponseFunc + timeout time.Duration } // NewProducer constructs a usable Producer for a single remote method. func NewProducer( sqsClient sqsiface.SQSAPI, queueURL string, - responseQueueURL string, enc EncodeRequestFunc, dec DecodeResponseFunc, options ...ProducerOption, ) *Producer { p := &Producer{ - sqsClient: sqsClient, - queueURL: queueURL, - responseQueueURL: responseQueueURL, - enc: enc, - dec: dec, - timeout: 20 * time.Second, + sqsClient: sqsClient, + queueURL: queueURL, + enc: enc, + dec: dec, + timeout: 20 * time.Second, } for _, option := range options { option(p) @@ -68,6 +73,13 @@ func ProducerTimeout(timeout time.Duration) ProducerOption { return func(p *Producer) { p.timeout = timeout } } +// SetProducerResponseQueueURL sets this as before or after function +func SetProducerResponseQueueURL(url string) ProducerRequestFunc { + return func(ctx context.Context, _ *sqs.SendMessageInput) context.Context { + return context.WithValue(ctx, ContextKeyResponseQueueURL, url) + } +} + // Endpoint returns a usable endpoint that invokes the remote endpoint. func (p Producer) Endpoint() endpoint.Endpoint { return func(ctx context.Context, request interface{}) (interface{}, error) { @@ -81,7 +93,7 @@ func (p Producer) Endpoint() endpoint.Endpoint { } for _, f := range p.before { - ctx = f(ctx, &msgInput, p.responseQueueURL) + ctx = f(ctx, &msgInput) } output, err := p.sqsClient.SendMessageWithContext(ctx, &msgInput) @@ -91,7 +103,7 @@ func (p Producer) Endpoint() endpoint.Endpoint { var responseMsg *sqs.Message for _, f := range p.after { - ctx, responseMsg, err = f(ctx, p.sqsClient, p.responseQueueURL, output) + ctx, responseMsg, err = f(ctx, p.sqsClient, output) if err != nil { return nil, err } diff --git a/transport/awssqs/producer_test.go b/transport/awssqs/producer_test.go index 3efbaab8f..e05f5655d 100644 --- a/transport/awssqs/producer_test.go +++ b/transport/awssqs/producer_test.go @@ -75,14 +75,12 @@ func (mock *mockClient) ChangeMessageVisibilityWithContext(ctx aws.Context, inpu // TestBadEncode tests if encode errors are handled properly. func TestBadEncode(t *testing.T) { queueURL := "someURL" - responseQueueURL := "someOtherURL" mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), } pub := awssqs.NewProducer( mock, queueURL, - responseQueueURL, func(context.Context, *sqs.SendMessageInput, interface{}) error { return errors.New("err!") }, func(context.Context, *sqs.Message) (response interface{}, err error) { return struct{}{}, nil }, ) @@ -120,16 +118,14 @@ func TestBadDecode(t *testing.T) { }() queueURL := "someURL" - responseQueueURL := "someOtherURL" pub := awssqs.NewProducer( mock, queueURL, - responseQueueURL, func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, func(context.Context, *sqs.Message) (response interface{}, err error) { return struct{}{}, errors.New("err!") }, - awssqs.ProducerAfter(func(ctx context.Context, _ sqsiface.SQSAPI, responseQueueURL string, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { + awssqs.ProducerAfter(func(ctx context.Context, _ sqsiface.SQSAPI, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { // Set the actual response for the request. return ctx, &sqs.Message{Body: aws.String("someMsgContent")}, nil }), @@ -165,11 +161,9 @@ func TestProducerTimeout(t *testing.T) { sendOutputChan: sendOutputChan, } queueURL := "someURL" - responseQueueURL := "someOtherURL" pub := awssqs.NewProducer( mock, queueURL, - responseQueueURL, func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, func(context.Context, *sqs.Message) (response interface{}, err error) { return struct{}{}, nil @@ -224,18 +218,16 @@ func TestSuccessfulProducer(t *testing.T) { }() queueURL := "someURL" - responseQueueURL := "someOtherURL" pub := awssqs.NewProducer( mock, queueURL, - responseQueueURL, awssqs.EncodeJSONRequest, func(_ context.Context, msg *sqs.Message) (interface{}, error) { response := testRes{} err := json.Unmarshal([]byte(*msg.Body), &response) return response, err }, - awssqs.ProducerAfter(func(ctx context.Context, _ sqsiface.SQSAPI, responseQueueURL string, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { + awssqs.ProducerAfter(func(ctx context.Context, _ sqsiface.SQSAPI, msg *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) { // Sets the actual response for the request. if *msg.MessageId == "someMsgID" { return ctx, &sqs.Message{Body: aws.String(string(b))}, nil @@ -288,11 +280,9 @@ func TestSuccessfulProducerNoResponse(t *testing.T) { } queueURL := "someURL" - responseQueueURL := "someOtherURL" pub := awssqs.NewProducer( mock, queueURL, - responseQueueURL, awssqs.EncodeJSONRequest, awssqs.NoResponseDecode, ) @@ -318,8 +308,10 @@ func TestSuccessfulProducerNoResponse(t *testing.T) { } } -// TestProducerWithBefore adds a ProducerBefore function that adds a message attribute. -// This test ensures that the the before functions work as expected. +// TestProducerWithBefore adds a ProducerBefore function that adds responseQueueURL to context, +// and another on that adds it as a message attribute to outgoing message. +// This test ensures that setting multiple before functions work as expected +// and that SetProducerResponseQueueURL works as expected. func TestProducerWithBefore(t *testing.T) { mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), @@ -332,10 +324,11 @@ func TestProducerWithBefore(t *testing.T) { pub := awssqs.NewProducer( mock, queueURL, - responseQueueURL, awssqs.EncodeJSONRequest, awssqs.NoResponseDecode, - awssqs.ProducerBefore(func(c context.Context, s *sqs.SendMessageInput, _ string) context.Context { + awssqs.ProducerBefore(awssqs.SetProducerResponseQueueURL(responseQueueURL)), + awssqs.ProducerBefore(func(c context.Context, s *sqs.SendMessageInput) context.Context { + responseQueueURL := c.Value(awssqs.ContextKeyResponseQueueURL).(string) if s.MessageAttributes == nil { s.MessageAttributes = make(map[string]*sqs.MessageAttributeValue) } diff --git a/transport/awssqs/request_response_func.go b/transport/awssqs/request_response_func.go index ea8852542..84914268c 100644 --- a/transport/awssqs/request_response_func.go +++ b/transport/awssqs/request_response_func.go @@ -18,7 +18,7 @@ type ConsumerRequestFunc func(context.Context, *[]*sqs.Message) context.Context // request context, or add some informations to SendMessageInput. In Producers, // RequestFuncs are executed prior to publishing the message but after encoding. // use cases eg. in Producer : enforce some message attributes to SendMessageInput. -type ProducerRequestFunc func(context.Context, *sqs.SendMessageInput, string) context.Context +type ProducerRequestFunc func(ctx context.Context, input *sqs.SendMessageInput) context.Context // ConsumerResponseFunc may take information from a request context and use it to // manipulate a Producer. ConsumerResponseFunc are only executed in @@ -31,4 +31,4 @@ type ConsumerResponseFunc func(context.Context, *sqs.Message, *sqs.SendMessageIn // fetch response using the Client. SQS is not req-reply out-of-the-box. Responses need to be fetched. // ProducerResponseFunc are only executed in producers, after a request has been made, // but prior to its response being decoded. So this is the perfect place to fetch actual response. -type ProducerResponseFunc func(context.Context, sqsiface.SQSAPI, string, *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) +type ProducerResponseFunc func(context.Context, sqsiface.SQSAPI, *sqs.SendMessageOutput) (context.Context, *sqs.Message, error) From 0223d131e18b893666bec475edc8814b58e8f2b5 Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Tue, 13 Oct 2020 09:47:50 +0200 Subject: [PATCH 14/18] Name VisibilityTimeoutFunc parameters to ease comprehension --- transport/awssqs/consumer.go | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transport/awssqs/consumer.go b/transport/awssqs/consumer.go index 0fb96ff5d..07e8d2225 100644 --- a/transport/awssqs/consumer.go +++ b/transport/awssqs/consumer.go @@ -263,7 +263,7 @@ type ConsumerFinalizerFunc func(ctx context.Context, msg *[]*sqs.Message) // this can be used to provide custom visibility timeout extension such as doubling it everytime // it gets close to being reached. // VisibilityTimeoutFunc will need to check that the provided context is not done and return once it is. -type VisibilityTimeoutFunc func(context.Context, sqsiface.SQSAPI, string, int64, *[]*sqs.Message, *sync.Mutex) error +type VisibilityTimeoutFunc func(ctx context.Context, client sqsiface.SQSAPI, queueURL string, visibilityTimeout int64, leftMsgs *[]*sqs.Message, leftMsgsMux *sync.Mutex) error // WantReplyFunc encapsulates logic to check whether message awaits response or not // for example check for a given message attribute value. From f68f079a54522b93f673bbae934d35957ae2c334 Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Wed, 14 Oct 2020 14:38:07 +0200 Subject: [PATCH 15/18] Simplify code: Consumer now only supports single message serving. receiving messages, handling multiple messages and updating visibilitytimeout must be done by the users. --- transport/awssqs/consumer.go | 178 ++++++------------ transport/awssqs/consumer_test.go | 215 +++++++--------------- transport/awssqs/request_response_func.go | 7 +- 3 files changed, 129 insertions(+), 271 deletions(-) diff --git a/transport/awssqs/consumer.go b/transport/awssqs/consumer.go index 07e8d2225..534846ff8 100644 --- a/transport/awssqs/consumer.go +++ b/transport/awssqs/consumer.go @@ -13,36 +13,19 @@ import ( "github.com/go-kit/kit/transport" ) -// Delete is a type to indicate when the consumed message should be deleted. -type Delete int - -const ( - // BeforeHandle deletes the message before starting to handle it. - BeforeHandle Delete = iota - // AfterHandle deletes the message once it has been fully processed. - // This is the consumer's default value. - AfterHandle - // Never does not delete the message. - Never -) - // Consumer wraps an endpoint and provides a handler for SQS messages. type Consumer struct { - sqsClient sqsiface.SQSAPI - e endpoint.Endpoint - dec DecodeRequestFunc - enc EncodeResponseFunc - wantRep WantReplyFunc - queueURL string - visibilityTimeout int64 - visibilityTimeoutFunc VisibilityTimeoutFunc - leftMsgsMux *sync.Mutex - before []ConsumerRequestFunc - after []ConsumerResponseFunc - errorEncoder ErrorEncoder - finalizer []ConsumerFinalizerFunc - errorHandler transport.ErrorHandler - deleteMessage Delete + sqsClient sqsiface.SQSAPI + e endpoint.Endpoint + dec DecodeRequestFunc + enc EncodeResponseFunc + wantRep WantReplyFunc + queueURL string + before []ConsumerRequestFunc + after []ConsumerResponseFunc + errorEncoder ErrorEncoder + finalizer []ConsumerFinalizerFunc + errorHandler transport.ErrorHandler } // NewConsumer constructs a new Consumer, which provides a Consume method @@ -56,18 +39,14 @@ func NewConsumer( options ...ConsumerOption, ) *Consumer { s := &Consumer{ - sqsClient: sqsClient, - e: e, - dec: dec, - enc: enc, - wantRep: DoNotRespond, - queueURL: queueURL, - visibilityTimeout: int64(30), - visibilityTimeoutFunc: DoNotExtendVisibilityTimeout, - leftMsgsMux: &sync.Mutex{}, - errorEncoder: DefaultErrorEncoder, - errorHandler: transport.NewLogErrorHandler(log.NewNopLogger()), - deleteMessage: AfterHandle, + sqsClient: sqsClient, + e: e, + dec: dec, + enc: enc, + wantRep: DoNotRespond, + queueURL: queueURL, + errorEncoder: DefaultErrorEncoder, + errorHandler: transport.NewLogErrorHandler(log.NewNopLogger()), } for _, option := range options { option(s) @@ -98,20 +77,6 @@ func ConsumerErrorEncoder(ee ErrorEncoder) ConsumerOption { return func(c *Consumer) { c.errorEncoder = ee } } -// ConsumerVisbilityTimeOutFunc is used to extend the visibility timeout -// for messages while the consumer processes them. -// VisibilityTimeoutFunc will need to check that the provided context is not done. -// By default, visibility timeout are not extended. -func ConsumerVisbilityTimeOutFunc(vtFunc VisibilityTimeoutFunc) ConsumerOption { - return func(c *Consumer) { c.visibilityTimeoutFunc = vtFunc } -} - -// ConsumerVisibilityTimeout overrides the default value for the consumer's -// visibilityTimeout field. -func ConsumerVisibilityTimeout(visibilityTimeout int64) ConsumerOption { - return func(c *Consumer) { c.visibilityTimeout = visibilityTimeout } -} - // ConsumerWantReplyFunc overrides the default value for the consumer's // wantRep field. func ConsumerWantReplyFunc(replyFunc WantReplyFunc) ConsumerOption { @@ -132,84 +97,55 @@ func ConsumerFinalizer(f ...ConsumerFinalizerFunc) ConsumerOption { return func(c *Consumer) { c.finalizer = f } } -// ConsumerDeleteMessage overrides the default value for the consumer's -// deleteMessage field to indicate when the consumed messages should be deleted. -func ConsumerDeleteMessage(delete Delete) ConsumerOption { - return func(c *Consumer) { c.deleteMessage = delete } +// ConsumerDeleteMessageBefore returns a ConsumerOption that appends a function +// that delete the message from queue to the list of consumer's before functions. +func ConsumerDeleteMessageBefore() ConsumerOption { + return func(c *Consumer) { + deleteBefore := func(ctx context.Context, cancel context.CancelFunc, msg *sqs.Message) context.Context { + if err := deleteMessage(ctx, c.sqsClient, c.queueURL, msg); err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + cancel() + } + return ctx + } + c.before = append(c.before, deleteBefore) + } } -// Consume calls ReceiveMessageWithContext and handles messages having an -// sqs.ReceiveMessageInput as parameter allows each user to have his own receive configuration. -// That said, this method overrides the queueURL for the provided ReceiveMessageInput to ensure -// the messages are retrieved from the consumer's configured queue. -func (c Consumer) Consume(ctx context.Context, receiveMsgInput *sqs.ReceiveMessageInput) error { - receiveMsgInput.QueueUrl = &c.queueURL - out, err := c.sqsClient.ReceiveMessageWithContext(ctx, receiveMsgInput) - if err != nil { - return err +// ConsumerDeleteMessageAfter returns a ConsumerOption that appends a function +// that delete a message from queue to the list of consumer's after functions. +func ConsumerDeleteMessageAfter() ConsumerOption { + return func(c *Consumer) { + deleteAfter := func(ctx context.Context, cancel context.CancelFunc, msg *sqs.Message, _ *sqs.SendMessageInput) context.Context { + if err := deleteMessage(ctx, c.sqsClient, c.queueURL, msg); err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + cancel() + } + return ctx + } + c.after = append(c.after, deleteAfter) } - return c.handleMessages(ctx, out.Messages) } -// handleMessages handles the consumed messages. -func (c Consumer) handleMessages(ctx context.Context, msgs []*sqs.Message) error { +// ServeMessage serves an SQS message. +func (c Consumer) ServeMessage(ctx context.Context, msg *sqs.Message) error { ctx, cancel := context.WithCancel(ctx) defer cancel() - // Copy received messages slice in leftMsgs slice - // leftMsgs will be used by the consumer's visibilityTimeoutFunc to extend the - // visibility timeout for the messages that have not been processed yet. - leftMsgs := []*sqs.Message{} - leftMsgs = append(leftMsgs, msgs...) - - visibilityTimeoutCtx, cancel := context.WithCancel(ctx) - defer cancel() - go c.visibilityTimeoutFunc(visibilityTimeoutCtx, c.sqsClient, c.queueURL, c.visibilityTimeout, &leftMsgs, c.leftMsgsMux) - if len(c.finalizer) > 0 { defer func() { for _, f := range c.finalizer { - f(ctx, &msgs) + f(ctx, msg) } }() } for _, f := range c.before { - ctx = f(ctx, &msgs) - } - - for _, msg := range msgs { - if c.deleteMessage == BeforeHandle { - if _, err := c.sqsClient.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ - QueueUrl: &c.queueURL, - ReceiptHandle: msg.ReceiptHandle, - }); err != nil { - c.errorHandler.Handle(ctx, err) - c.errorEncoder(ctx, err, msg, c.sqsClient) - return err - } - } - - if err := c.handleSingleMessage(ctx, msg, &leftMsgs); err != nil { - return err - } - - if c.deleteMessage == AfterHandle { - if _, err := c.sqsClient.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ - QueueUrl: &c.queueURL, - ReceiptHandle: msg.ReceiptHandle, - }); err != nil { - c.errorHandler.Handle(ctx, err) - c.errorEncoder(ctx, err, msg, c.sqsClient) - return err - } - } + ctx = f(ctx, cancel, msg) } - return nil -} -// handleSingleMessage handles a single SQS message. -func (c Consumer) handleSingleMessage(ctx context.Context, msg *sqs.Message, leftMsgs *[]*sqs.Message) error { req, err := c.dec(ctx, msg) if err != nil { c.errorHandler.Handle(ctx, err) @@ -226,7 +162,7 @@ func (c Consumer) handleSingleMessage(ctx context.Context, msg *sqs.Message, lef responseMsg := sqs.SendMessageInput{} for _, f := range c.after { - ctx = f(ctx, msg, &responseMsg, leftMsgs, c.leftMsgsMux) + ctx = f(ctx, cancel, msg, &responseMsg) } if !c.wantRep(ctx, msg) { @@ -257,7 +193,7 @@ type ErrorEncoder func(ctx context.Context, err error, req *sqs.Message, sqsClie // from a producer, after the response has been written to the producer. The // principal intended use is for request logging. // Can also be used to delete messages once fully proccessed. -type ConsumerFinalizerFunc func(ctx context.Context, msg *[]*sqs.Message) +type ConsumerFinalizerFunc func(ctx context.Context, msg *sqs.Message) // VisibilityTimeoutFunc encapsulates logic to extend messages visibility timeout. // this can be used to provide custom visibility timeout extension such as doubling it everytime @@ -273,10 +209,12 @@ type WantReplyFunc func(context.Context, *sqs.Message) bool func DefaultErrorEncoder(context.Context, error, *sqs.Message, sqsiface.SQSAPI) { } -// DoNotExtendVisibilityTimeout is the default value for the consumer's visibilityTimeoutFunc. -// It returns no error and does nothing. -func DoNotExtendVisibilityTimeout(context.Context, sqsiface.SQSAPI, string, int64, *[]*sqs.Message, *sync.Mutex) error { - return nil +func deleteMessage(ctx context.Context, sqsClient sqsiface.SQSAPI, queueURL string, msg *sqs.Message) error { + _, err := sqsClient.DeleteMessageWithContext(ctx, &sqs.DeleteMessageInput{ + QueueUrl: &queueURL, + ReceiptHandle: msg.ReceiptHandle, + }) + return err } // DoNotRespond is a WantReplyFunc and is the default value for consumer's wantRep field. diff --git a/transport/awssqs/consumer_test.go b/transport/awssqs/consumer_test.go index 2527e76f7..a9c844a68 100644 --- a/transport/awssqs/consumer_test.go +++ b/transport/awssqs/consumer_test.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "sync" "testing" "time" @@ -45,16 +44,6 @@ func TestConsumerDeleteBefore(t *testing.T) { receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), deleteError: fmt.Errorf("delete err!"), } - go func() { - mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ - Messages: []*sqs.Message{ - { - Body: aws.String("MessageBody"), - MessageId: aws.String("fakeMsgID"), - }, - }, - } - }() errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient sqsiface.SQSAPI) { publishError := sqsError{ Err: err.Error(), @@ -68,14 +57,17 @@ func TestConsumerDeleteBefore(t *testing.T) { }) consumer := awssqs.NewConsumer(mock, func(context.Context, interface{}) (interface{}, error) { return struct{}{}, nil }, - func(context.Context, *sqs.Message) (interface{}, error) { return nil, errors.New("decode err!") }, + func(context.Context, *sqs.Message) (interface{}, error) { return nil, nil }, func(context.Context, *sqs.SendMessageInput, interface{}) error { return nil }, queueURL, errEncoder, - awssqs.ConsumerDeleteMessage(awssqs.BeforeHandle), + awssqs.ConsumerDeleteMessageBefore(), ) - consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) + consumer.ServeMessage(context.Background(), &sqs.Message{ + Body: aws.String("MessageBody"), + MessageId: aws.String("fakeMsgID"), + }) var receiveOutput *sqs.ReceiveMessageOutput select { @@ -101,16 +93,6 @@ func TestConsumerBadDecode(t *testing.T) { sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), } - go func() { - mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ - Messages: []*sqs.Message{ - { - Body: aws.String("MessageBody"), - MessageId: aws.String("fakeMsgID"), - }, - }, - } - }() errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient sqsiface.SQSAPI) { publishError := sqsError{ Err: err.Error(), @@ -131,7 +113,10 @@ func TestConsumerBadDecode(t *testing.T) { awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) - consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) + consumer.ServeMessage(context.Background(), &sqs.Message{ + Body: aws.String("MessageBody"), + MessageId: aws.String("fakeMsgID"), + }) var receiveOutput *sqs.ReceiveMessageOutput select { @@ -157,16 +142,6 @@ func TestConsumerBadEndpoint(t *testing.T) { sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), } - go func() { - mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ - Messages: []*sqs.Message{ - { - Body: aws.String("MessageBody"), - MessageId: aws.String("fakeMsgID"), - }, - }, - } - }() errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient sqsiface.SQSAPI) { publishError := sqsError{ Err: err.Error(), @@ -187,7 +162,10 @@ func TestConsumerBadEndpoint(t *testing.T) { awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) - consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) + consumer.ServeMessage(context.Background(), &sqs.Message{ + Body: aws.String("MessageBody"), + MessageId: aws.String("fakeMsgID"), + }) var receiveOutput *sqs.ReceiveMessageOutput select { @@ -213,16 +191,6 @@ func TestConsumerBadEncoder(t *testing.T) { sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), } - go func() { - mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ - Messages: []*sqs.Message{ - { - Body: aws.String("MessageBody"), - MessageId: aws.String("fakeMsgID"), - }, - }, - } - }() errEncoder := awssqs.ConsumerErrorEncoder(func(ctx context.Context, err error, req *sqs.Message, sqsClient sqsiface.SQSAPI) { publishError := sqsError{ Err: err.Error(), @@ -243,7 +211,10 @@ func TestConsumerBadEncoder(t *testing.T) { awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) - consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) + consumer.ServeMessage(context.Background(), &sqs.Message{ + Body: aws.String("MessageBody"), + MessageId: aws.String("fakeMsgID"), + }) var receiveOutput *sqs.ReceiveMessageOutput select { @@ -276,16 +247,6 @@ func TestConsumerSuccess(t *testing.T) { sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), } - go func() { - mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ - Messages: []*sqs.Message{ - { - Body: aws.String(string(b)), - MessageId: aws.String("fakeMsgID"), - }, - }, - } - }() consumer := awssqs.NewConsumer(mock, testEndpoint, testReqDecoderfunc, @@ -294,7 +255,10 @@ func TestConsumerSuccess(t *testing.T) { awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) - consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) + consumer.ServeMessage(context.Background(), &sqs.Message{ + Body: aws.String(string(b)), + MessageId: aws.String("fakeMsgID"), + }) var receiveOutput *sqs.ReceiveMessageOutput select { @@ -332,16 +296,6 @@ func TestConsumerSuccessNoReply(t *testing.T) { sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), } - go func() { - mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ - Messages: []*sqs.Message{ - { - Body: aws.String(string(b)), - MessageId: aws.String("fakeMsgID"), - }, - }, - } - }() consumer := awssqs.NewConsumer(mock, testEndpoint, testReqDecoderfunc, @@ -349,7 +303,10 @@ func TestConsumerSuccessNoReply(t *testing.T) { queueURL, ) - consumer.Consume(context.Background(), &sqs.ReceiveMessageInput{}) + consumer.ServeMessage(context.Background(), &sqs.Message{ + Body: aws.String(string(b)), + MessageId: aws.String("fakeMsgID"), + }) var receiveOutput *sqs.ReceiveMessageOutput select { @@ -364,79 +321,55 @@ func TestConsumerSuccessNoReply(t *testing.T) { } // TestConsumerBeforeFilterMessages checks if consumer before is called as expected. -// Here before is used to filter messages before processing. -func TestConsumerBeforeFilterMessages(t *testing.T) { - obj1 := testReq{ - Squadron: 436, - } - b1, _ := json.Marshal(obj1) - obj2 := testReq{ - Squadron: 4, - } - b2, _ := json.Marshal(obj2) - obj3 := testReq{ - Squadron: 1, - } - b3, _ := json.Marshal(obj3) +// Here before is used to add a value in context. +func TestConsumerBeforeAddValueToContext(t *testing.T) { queueURL := "someURL" mock := &mockClient{ sendOutputChan: make(chan *sqs.SendMessageOutput), receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), } - expectedMsgs := []*sqs.Message{ - { - Body: aws.String(string(b1)), - MessageId: aws.String("fakeMsgID1"), - MessageAttributes: map[string]*sqs.MessageAttributeValue{ - "recipient": { - DataType: aws.String("String"), - StringValue: aws.String("me"), - }, - }, - }, - { - Body: aws.String(string(b2)), - MessageId: aws.String("fakeMsgID2"), - MessageAttributes: map[string]*sqs.MessageAttributeValue{ - "recipient": { - DataType: aws.String("String"), - StringValue: aws.String("not me"), - }, + msg := &sqs.Message{ + Body: aws.String("someBody"), + MessageId: aws.String("fakeMsgID1"), + MessageAttributes: map[string]*sqs.MessageAttributeValue{ + "recipient": { + DataType: aws.String("String"), + StringValue: aws.String("me"), }, }, - { - Body: aws.String(string(b3)), - MessageId: aws.String("fakeMsgID3"), - }, } - go func() { - mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ - Messages: expectedMsgs, - } - }() type ctxKey struct { key string } consumer := awssqs.NewConsumer(mock, - testEndpoint, - testReqDecoderfunc, - awssqs.EncodeJSONResponse, + // endpoint. + func(ctx context.Context, request interface{}) (interface{}, error) { + return ctx.Value(ctxKey{"recipient"}).(string), nil + }, + // request decoder + func(_ context.Context, msg *sqs.Message) (interface{}, error) { + return *msg.Body, nil + }, + // response encoder + func(_ context.Context, input *sqs.SendMessageInput, response interface{}) error { + input.MessageBody = aws.String(fmt.Sprintf("%v", response)) + return nil + }, queueURL, - awssqs.ConsumerBefore(func(ctx context.Context, msgs *[]*sqs.Message) context.Context { + awssqs.ConsumerBefore(func(ctx context.Context, cancel context.CancelFunc, msg *sqs.Message) context.Context { // Filter a message that is not destined to the consumer. - msgsCopy := *msgs - for index, msg := range *msgs { - if recipient, exists := msg.MessageAttributes["recipient"]; !exists || *recipient.StringValue != "me" { - msgsCopy = append(msgsCopy[:index], msgsCopy[index:]...) - } + if recipient, exists := msg.MessageAttributes["recipient"]; exists { + ctx = context.WithValue(ctx, ctxKey{"recipient"}, *recipient.StringValue) } - *msgs = msgsCopy return ctx }), awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) ctx := context.Background() - consumer.Consume(ctx, &sqs.ReceiveMessageInput{}) + err := consumer.ServeMessage(ctx, msg) + if err != nil { + t.Errorf("got err %s", err) + } var receiveOutput *sqs.ReceiveMessageOutput select { @@ -446,14 +379,11 @@ func TestConsumerBeforeFilterMessages(t *testing.T) { case <-time.After(200 * time.Millisecond): t.Fatal("Timed out waiting for publishing") } - res, err := decodeResponse(receiveOutput) - if err != nil { - t.Fatal(err) - } - want := testRes{ - Squadron: 436, - Name: "tusker", + if len(receiveOutput.Messages) != 1 { + t.Errorf("Error : received %d messages instead of 1", len(receiveOutput.Messages)) } + res := *receiveOutput.Messages[0].Body + want := "me" if have := res; want != have { t.Errorf("want %v, have %v", want, have) } @@ -482,23 +412,16 @@ func TestConsumerAfter(t *testing.T) { receiveOuputChan: make(chan *sqs.ReceiveMessageOutput), } correlationID := uuid.NewRandom().String() - expectedMsgs := []*sqs.Message{ - { - Body: aws.String(string(b1)), - MessageId: aws.String("fakeMsgID1"), - MessageAttributes: map[string]*sqs.MessageAttributeValue{ - "correlationID": { - DataType: aws.String("String"), - StringValue: &correlationID, - }, + msg := &sqs.Message{ + Body: aws.String(string(b1)), + MessageId: aws.String("fakeMsgID1"), + MessageAttributes: map[string]*sqs.MessageAttributeValue{ + "correlationID": { + DataType: aws.String("String"), + StringValue: &correlationID, }, }, } - go func() { - mock.receiveOuputChan <- &sqs.ReceiveMessageOutput{ - Messages: expectedMsgs, - } - }() type ctxKey struct { key string } @@ -507,9 +430,7 @@ func TestConsumerAfter(t *testing.T) { testReqDecoderfunc, awssqs.EncodeJSONResponse, queueURL, - awssqs.ConsumerAfter(func(ctx context.Context, msg *sqs.Message, resp *sqs.SendMessageInput, leftMsgs *[]*sqs.Message, mux *sync.Mutex) context.Context { - mux.Lock() - defer mux.Unlock() + awssqs.ConsumerAfter(func(ctx context.Context, cancel context.CancelFunc, msg *sqs.Message, resp *sqs.SendMessageInput) context.Context { if correlationIDAttribute, exists := msg.MessageAttributes["correlationID"]; exists { if resp.MessageAttributes == nil { resp.MessageAttributes = make(map[string]*sqs.MessageAttributeValue) @@ -524,7 +445,7 @@ func TestConsumerAfter(t *testing.T) { awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) ctx := context.Background() - consumer.Consume(ctx, &sqs.ReceiveMessageInput{}) + consumer.ServeMessage(ctx, msg) var receiveOutput *sqs.ReceiveMessageOutput select { diff --git a/transport/awssqs/request_response_func.go b/transport/awssqs/request_response_func.go index 84914268c..478f61f62 100644 --- a/transport/awssqs/request_response_func.go +++ b/transport/awssqs/request_response_func.go @@ -2,7 +2,6 @@ package awssqs import ( "context" - "sync" "github.com/aws/aws-sdk-go/service/sqs" "github.com/aws/aws-sdk-go/service/sqs/sqsiface" @@ -11,8 +10,8 @@ import ( // ConsumerRequestFunc may take information from a consumer request result and // put it into a request context. In Consumers, RequestFuncs are executed prior // to invoking the endpoint. -// use cases eg. in Consumer : extract message into context, or filter received messages. -type ConsumerRequestFunc func(context.Context, *[]*sqs.Message) context.Context +// use cases eg. in Consumer : extract message information into context. +type ConsumerRequestFunc func(ctx context.Context, cancel context.CancelFunc, req *sqs.Message) context.Context // ProducerRequestFunc may take information from a producer request and put it into a // request context, or add some informations to SendMessageInput. In Producers, @@ -25,7 +24,7 @@ type ProducerRequestFunc func(ctx context.Context, input *sqs.SendMessageInput) // consumers, after invoking the endpoint but prior to publishing a reply. // use cases eg. : Pipe information from request message to response MessageInput, // delete msg from queue or update leftMsgs slice. -type ConsumerResponseFunc func(context.Context, *sqs.Message, *sqs.SendMessageInput, *[]*sqs.Message, *sync.Mutex) context.Context +type ConsumerResponseFunc func(ctx context.Context, cancel context.CancelFunc, req *sqs.Message, resp *sqs.SendMessageInput) context.Context // ProducerResponseFunc may take information from an sqs.SendMessageOutput and // fetch response using the Client. SQS is not req-reply out-of-the-box. Responses need to be fetched. From 59c9df11d9fff3a5874c71b28bbc283443df96d2 Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Wed, 14 Oct 2020 14:54:39 +0200 Subject: [PATCH 16/18] Fix comments punctuation --- transport/awssqs/consumer_test.go | 4 ++-- transport/awssqs/producer.go | 5 +++-- 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/transport/awssqs/consumer_test.go b/transport/awssqs/consumer_test.go index a9c844a68..2e5e02499 100644 --- a/transport/awssqs/consumer_test.go +++ b/transport/awssqs/consumer_test.go @@ -21,7 +21,7 @@ var ( ) func (mock *mockClient) ReceiveMessageWithContext(ctx context.Context, input *sqs.ReceiveMessageInput, opts ...request.Option) (*sqs.ReceiveMessageOutput, error) { - // Add logic to allow context errors + // Add logic to allow context errors. for { select { case d := <-mock.receiveOuputChan: @@ -342,7 +342,7 @@ func TestConsumerBeforeAddValueToContext(t *testing.T) { key string } consumer := awssqs.NewConsumer(mock, - // endpoint. + // endpoint func(ctx context.Context, request interface{}) (interface{}, error) { return ctx.Value(ctxKey{"recipient"}).(string), nil }, diff --git a/transport/awssqs/producer.go b/transport/awssqs/producer.go index d505b9231..8192d8689 100644 --- a/transport/awssqs/producer.go +++ b/transport/awssqs/producer.go @@ -15,7 +15,7 @@ type contextKey int const ( // ContextKeyResponseQueueURL is the context key that allows fetching - // the response queue URL from context + // and setting the response queue URL from and into context. ContextKeyResponseQueueURL contextKey = iota ) @@ -73,7 +73,8 @@ func ProducerTimeout(timeout time.Duration) ProducerOption { return func(p *Producer) { p.timeout = timeout } } -// SetProducerResponseQueueURL sets this as before or after function +// SetProducerResponseQueueURL can be used as a before function to add +// provided url as responseQueueURL in context. func SetProducerResponseQueueURL(url string) ProducerRequestFunc { return func(ctx context.Context, _ *sqs.SendMessageInput) context.Context { return context.WithValue(ctx, ContextKeyResponseQueueURL, url) From 9d978f3371fa17858c9deca8e2d4631b8800b9c2 Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Wed, 14 Oct 2020 15:15:19 +0200 Subject: [PATCH 17/18] Delete VisibilityTimeoutFunc type --- transport/awssqs/consumer.go | 7 ------- 1 file changed, 7 deletions(-) diff --git a/transport/awssqs/consumer.go b/transport/awssqs/consumer.go index 534846ff8..02ec54124 100644 --- a/transport/awssqs/consumer.go +++ b/transport/awssqs/consumer.go @@ -3,7 +3,6 @@ package awssqs import ( "context" "encoding/json" - "sync" "github.com/aws/aws-sdk-go/aws" "github.com/aws/aws-sdk-go/service/sqs" @@ -195,12 +194,6 @@ type ErrorEncoder func(ctx context.Context, err error, req *sqs.Message, sqsClie // Can also be used to delete messages once fully proccessed. type ConsumerFinalizerFunc func(ctx context.Context, msg *sqs.Message) -// VisibilityTimeoutFunc encapsulates logic to extend messages visibility timeout. -// this can be used to provide custom visibility timeout extension such as doubling it everytime -// it gets close to being reached. -// VisibilityTimeoutFunc will need to check that the provided context is not done and return once it is. -type VisibilityTimeoutFunc func(ctx context.Context, client sqsiface.SQSAPI, queueURL string, visibilityTimeout int64, leftMsgs *[]*sqs.Message, leftMsgsMux *sync.Mutex) error - // WantReplyFunc encapsulates logic to check whether message awaits response or not // for example check for a given message attribute value. type WantReplyFunc func(context.Context, *sqs.Message) bool From b673cbefc4406a0cbfc7f29de081fc62042316f3 Mon Sep 17 00:00:00 2001 From: Omar Qurie Date: Mon, 26 Oct 2020 15:07:40 +0100 Subject: [PATCH 18/18] Make serveMessage return func --- transport/awssqs/consumer.go | 88 ++++++++++++++++--------------- transport/awssqs/consumer_test.go | 16 +++--- 2 files changed, 53 insertions(+), 51 deletions(-) diff --git a/transport/awssqs/consumer.go b/transport/awssqs/consumer.go index 02ec54124..5281151eb 100644 --- a/transport/awssqs/consumer.go +++ b/transport/awssqs/consumer.go @@ -129,57 +129,59 @@ func ConsumerDeleteMessageAfter() ConsumerOption { } // ServeMessage serves an SQS message. -func (c Consumer) ServeMessage(ctx context.Context, msg *sqs.Message) error { - ctx, cancel := context.WithCancel(ctx) - defer cancel() - - if len(c.finalizer) > 0 { - defer func() { - for _, f := range c.finalizer { - f(ctx, msg) - } - }() - } +func (c Consumer) ServeMessage(ctx context.Context) func(msg *sqs.Message) error { + return func(msg *sqs.Message) error { + ctx, cancel := context.WithCancel(ctx) + defer cancel() + + if len(c.finalizer) > 0 { + defer func() { + for _, f := range c.finalizer { + f(ctx, msg) + } + }() + } - for _, f := range c.before { - ctx = f(ctx, cancel, msg) - } + for _, f := range c.before { + ctx = f(ctx, cancel, msg) + } - req, err := c.dec(ctx, msg) - if err != nil { - c.errorHandler.Handle(ctx, err) - c.errorEncoder(ctx, err, msg, c.sqsClient) - return err - } + req, err := c.dec(ctx, msg) + if err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + return err + } - response, err := c.e(ctx, req) - if err != nil { - c.errorHandler.Handle(ctx, err) - c.errorEncoder(ctx, err, msg, c.sqsClient) - return err - } + response, err := c.e(ctx, req) + if err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + return err + } - responseMsg := sqs.SendMessageInput{} - for _, f := range c.after { - ctx = f(ctx, cancel, msg, &responseMsg) - } + responseMsg := sqs.SendMessageInput{} + for _, f := range c.after { + ctx = f(ctx, cancel, msg, &responseMsg) + } - if !c.wantRep(ctx, msg) { - return nil - } + if !c.wantRep(ctx, msg) { + return nil + } - if err := c.enc(ctx, &responseMsg, response); err != nil { - c.errorHandler.Handle(ctx, err) - c.errorEncoder(ctx, err, msg, c.sqsClient) - return err - } + if err := c.enc(ctx, &responseMsg, response); err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + return err + } - if _, err := c.sqsClient.SendMessageWithContext(ctx, &responseMsg); err != nil { - c.errorHandler.Handle(ctx, err) - c.errorEncoder(ctx, err, msg, c.sqsClient) - return err + if _, err := c.sqsClient.SendMessageWithContext(ctx, &responseMsg); err != nil { + c.errorHandler.Handle(ctx, err) + c.errorEncoder(ctx, err, msg, c.sqsClient) + return err + } + return nil } - return nil } // ErrorEncoder is responsible for encoding an error to the consumer's reply. diff --git a/transport/awssqs/consumer_test.go b/transport/awssqs/consumer_test.go index 2e5e02499..939eb7286 100644 --- a/transport/awssqs/consumer_test.go +++ b/transport/awssqs/consumer_test.go @@ -64,7 +64,7 @@ func TestConsumerDeleteBefore(t *testing.T) { awssqs.ConsumerDeleteMessageBefore(), ) - consumer.ServeMessage(context.Background(), &sqs.Message{ + consumer.ServeMessage(context.Background())(&sqs.Message{ Body: aws.String("MessageBody"), MessageId: aws.String("fakeMsgID"), }) @@ -113,7 +113,7 @@ func TestConsumerBadDecode(t *testing.T) { awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) - consumer.ServeMessage(context.Background(), &sqs.Message{ + consumer.ServeMessage(context.Background())(&sqs.Message{ Body: aws.String("MessageBody"), MessageId: aws.String("fakeMsgID"), }) @@ -162,7 +162,7 @@ func TestConsumerBadEndpoint(t *testing.T) { awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) - consumer.ServeMessage(context.Background(), &sqs.Message{ + consumer.ServeMessage(context.Background())(&sqs.Message{ Body: aws.String("MessageBody"), MessageId: aws.String("fakeMsgID"), }) @@ -211,7 +211,7 @@ func TestConsumerBadEncoder(t *testing.T) { awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) - consumer.ServeMessage(context.Background(), &sqs.Message{ + consumer.ServeMessage(context.Background())(&sqs.Message{ Body: aws.String("MessageBody"), MessageId: aws.String("fakeMsgID"), }) @@ -255,7 +255,7 @@ func TestConsumerSuccess(t *testing.T) { awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) - consumer.ServeMessage(context.Background(), &sqs.Message{ + consumer.ServeMessage(context.Background())(&sqs.Message{ Body: aws.String(string(b)), MessageId: aws.String("fakeMsgID"), }) @@ -303,7 +303,7 @@ func TestConsumerSuccessNoReply(t *testing.T) { queueURL, ) - consumer.ServeMessage(context.Background(), &sqs.Message{ + consumer.ServeMessage(context.Background())(&sqs.Message{ Body: aws.String(string(b)), MessageId: aws.String("fakeMsgID"), }) @@ -366,7 +366,7 @@ func TestConsumerBeforeAddValueToContext(t *testing.T) { awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) ctx := context.Background() - err := consumer.ServeMessage(ctx, msg) + err := consumer.ServeMessage(ctx)(msg) if err != nil { t.Errorf("got err %s", err) } @@ -445,7 +445,7 @@ func TestConsumerAfter(t *testing.T) { awssqs.ConsumerWantReplyFunc(func(context.Context, *sqs.Message) bool { return true }), ) ctx := context.Background() - consumer.ServeMessage(ctx, msg) + consumer.ServeMessage(ctx)(msg) var receiveOutput *sqs.ReceiveMessageOutput select {