diff --git a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java index 737888a..42dcd75 100644 --- a/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java +++ b/src/main/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClient.java @@ -344,7 +344,9 @@ public ReceiveMessageResult receiveMessage(ReceiveMessageRequest receiveMessageR return super.receiveMessage(receiveMessageRequest); } - receiveMessageRequest.getMessageAttributeNames().add(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME); + if (!receiveMessageRequest.getMessageAttributeNames().contains(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME)) { + receiveMessageRequest.getMessageAttributeNames().add(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME); + } ReceiveMessageResult receiveMessageResult = super.receiveMessage(receiveMessageRequest); diff --git a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java index 60df2b2..2b25930 100644 --- a/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java +++ b/src/test/java/com/amazon/sqs/javamessaging/AmazonSQSExtendedClientTest.java @@ -25,6 +25,8 @@ import com.amazonaws.services.sqs.AmazonSQS; import com.amazonaws.services.sqs.AmazonSQSClient; import com.amazonaws.services.sqs.model.MessageAttributeValue; +import com.amazonaws.services.sqs.model.ReceiveMessageRequest; +import com.amazonaws.services.sqs.model.ReceiveMessageResult; import com.amazonaws.services.sqs.model.SendMessageBatchRequest; import com.amazonaws.services.sqs.model.SendMessageBatchRequestEntry; import com.amazonaws.services.sqs.model.SendMessageRequest; @@ -138,6 +140,24 @@ public void testWhenSendMessageWithSetMessageSizeThresholdThenThresholdIsHonored verify(mockS3, times(1)).putObject(isA(PutObjectRequest.class)); } + @Test + public void testReceiveMessageMultipleTimesDoesNotAdditionallyAlterReceiveMessageRequest() { + ExtendedClientConfiguration extendedClientConfiguration = new ExtendedClientConfiguration() + .withLargePayloadSupportEnabled(mockS3, S3_BUCKET_NAME); + AmazonSQS sqsExtended = spy(new AmazonSQSExtendedClient(mockSqsBackend, extendedClientConfiguration)); + when(mockSqsBackend.receiveMessage(isA(ReceiveMessageRequest.class))).thenReturn(new ReceiveMessageResult()); + + ReceiveMessageRequest messageRequest = new ReceiveMessageRequest(); + ReceiveMessageRequest expectedRequest = new ReceiveMessageRequest() + .withMessageAttributeNames(SQSExtendedClientConstants.RESERVED_ATTRIBUTE_NAME); + + sqsExtended.receiveMessage(messageRequest); + Assert.assertEquals(expectedRequest, messageRequest); + + sqsExtended.receiveMessage(messageRequest); + Assert.assertEquals(expectedRequest, messageRequest); + } + @Test public void testWhenMessageBatchIsSentThenOnlyMessagesLargerThanThresholdAreStoredInS3() { // This creates 10 messages, out of which only two are below the threshold (100K and 200K),