diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java index 14a461f..f442be0 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java @@ -344,7 +344,9 @@ public ReceiveMessageResult receiveMessage(ReceiveMessageRequest receiveMessageR return super.receiveMessage(receiveMessageRequest); } - receiveMessageRequest.getMessageAttributeNames().add(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME); + if (!receiveMessageRequest.getMessageAttributeNames().contains(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME)) { + receiveMessageRequest.getMessageAttributeNames().add(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME); + } ReceiveMessageResult receiveMessageResult = super.receiveMessage(receiveMessageRequest); diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java index a20b41c..eed1363 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java @@ -26,6 +26,8 @@ import com.amazonaws.services.sqs.AmazonSQS; import com.amazonaws.services.sqs.AmazonSQSClient; import com.amazonaws.services.sqs.model.MessageAttributeValue; +import com.amazonaws.services.sqs.model.ReceiveMessageRequest; +import com.amazonaws.services.sqs.model.ReceiveMessageResult; import com.amazonaws.services.sqs.model.SendMessageBatchRequest; import com.amazonaws.services.sqs.model.SendMessageBatchRequestEntry; import com.amazonaws.services.sqs.model.SendMessageRequest; @@ -155,6 +157,24 @@ public void testWhenSendMessageWithSetMessageSizeThresholdThenThresholdIsHonored verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class)); } + @Test + public void testReceiveMessageMultipleTimesDoesNotAdditionallyAlterReceiveMessageRequest() { + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withLargePayloadSupportEnabled(mockS3, S3_BUCKET_NAME); + AmazonSQS sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + when(mockSqsBackend.receiveMessage(isA(ReceiveMessageRequest.class))).thenReturn(new ReceiveMessageResult()); + + ReceiveMessageRequest messageRequest = new ReceiveMessageRequest(); + ReceiveMessageRequest expectedRequest = new ReceiveMessageRequest() + .withMessageAttributeNames(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME); + + sqsExtended.receiveMessage(messageRequest); + Assert.assertEquals(expectedRequest, messageRequest); + + sqsExtended.receiveMessage(messageRequest); + Assert.assertEquals(expectedRequest, messageRequest); + } + @Test public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStoredInS3() { // This creates 10 messages, out of which only two are below the threshold (100K and 200K),