diff --git a/docs/changelog/133030.yaml b/docs/changelog/133030.yaml new file mode 100644 index 0000000000000..aafb4576ed2ad --- /dev/null +++ b/docs/changelog/133030.yaml @@ -0,0 +1,6 @@ +pr: 133030 +summary: Implement `failIfAlreadyExists` in S3 repositories +area: Snapshot/Restore +type: enhancement +issues: + - 128565 diff --git a/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java b/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java index e431c7e25d250..18ab055ca3374 100644 --- a/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java +++ b/modules/repository-s3/src/internalClusterTest/java/org/elasticsearch/repositories/s3/S3BlobStoreRepositoryTests.java @@ -83,8 +83,11 @@ import java.util.Objects; import java.util.Set; import java.util.concurrent.CountDownLatch; +import java.util.concurrent.Executors; import java.util.concurrent.Semaphore; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicBoolean; +import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicLong; import java.util.stream.Collectors; import java.util.stream.StreamSupport; @@ -105,6 +108,7 @@ import static org.hamcrest.Matchers.hasSize; import static org.hamcrest.Matchers.lessThan; import static org.hamcrest.Matchers.not; +import static org.hamcrest.Matchers.startsWith; @SuppressForbidden(reason = "this test uses a HttpServer to emulate an S3 endpoint") // Need to set up a new cluster for each test because cluster settings use randomized authentication settings @@ -425,7 +429,7 @@ public void testEnforcedCooldownPeriod() throws IOException { if (randomBoolean()) { repository.blobStore() .blobContainer(repository.basePath()) - .writeBlobAtomic(randomNonDataPurpose(), getRepositoryDataBlobName(modifiedRepositoryData.getGenId()), serialized, true); + .writeBlobAtomic(randomNonDataPurpose(), getRepositoryDataBlobName(modifiedRepositoryData.getGenId()), serialized, false); } else { repository.blobStore() .blobContainer(repository.basePath()) @@ -434,7 +438,7 @@ public void testEnforcedCooldownPeriod() throws IOException { getRepositoryDataBlobName(modifiedRepositoryData.getGenId()), serialized.streamInput(), serialized.length(), - true + false ); } @@ -568,6 +572,52 @@ public void match(LogEvent event) { } } + public void testFailIfAlreadyExists() throws IOException, InterruptedException { + try (BlobStore store = newBlobStore()) { + final BlobContainer container = store.blobContainer(BlobPath.EMPTY); + final String blobName = randomAlphaOfLengthBetween(8, 12); + + final byte[] data; + if (randomBoolean()) { + // single upload + data = randomBytes(randomIntBetween(10, scaledRandomIntBetween(1024, 1 << 16))); + } else { + // multipart upload + int thresholdInBytes = Math.toIntExact(((S3BlobContainer) container).getLargeBlobThresholdInBytes()); + data = randomBytes(randomIntBetween(thresholdInBytes, thresholdInBytes + scaledRandomIntBetween(1024, 1 << 16))); + } + + // initial write blob + AtomicInteger exceptionCount = new AtomicInteger(0); + try (var executor = Executors.newFixedThreadPool(2)) { + for (int i = 0; i < 2; i++) { + executor.submit(() -> { + try { + writeBlob(container, blobName, new BytesArray(data), true); + } catch (IOException e) { + exceptionCount.incrementAndGet(); + } + }); + } + executor.shutdown(); + var done = executor.awaitTermination(1, TimeUnit.SECONDS); + assertTrue(done); + } + + assertEquals(1, exceptionCount.get()); + + // overwrite if failIfAlreadyExists is set to false + writeBlob(container, blobName, new BytesArray(data), false); + + // throw exception if failIfAlreadyExists is set to true + var exception = expectThrows(IOException.class, () -> writeBlob(container, blobName, new BytesArray(data), true)); + + assertThat(exception.getMessage(), startsWith("Unable to upload")); + + container.delete(randomPurpose()); + } + } + /** * S3RepositoryPlugin that allows to disable chunked encoding and to set a low threshold between single upload and multipart upload. */ diff --git a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java index f7b910bfb2a32..dcd3a7dbe6533 100644 --- a/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java +++ b/modules/repository-s3/src/main/java/org/elasticsearch/repositories/s3/S3BlobContainer.java @@ -137,18 +137,15 @@ public long readBlobPreferredLength() { return ByteSizeValue.of(32, ByteSizeUnit.MB).getBytes(); } - /** - * This implementation ignores the failIfAlreadyExists flag as the S3 API has no way to enforce this due to its weak consistency model. - */ @Override public void writeBlob(OperationPurpose purpose, String blobName, InputStream inputStream, long blobSize, boolean failIfAlreadyExists) throws IOException { assert BlobContainer.assertPurposeConsistency(purpose, blobName); assert inputStream.markSupported() : "No mark support on inputStream breaks the S3 SDK's ability to retry requests"; if (blobSize <= getLargeBlobThresholdInBytes()) { - executeSingleUpload(purpose, blobStore, buildKey(blobName), inputStream, blobSize); + executeSingleUpload(purpose, blobStore, buildKey(blobName), inputStream, blobSize, failIfAlreadyExists); } else { - executeMultipartUpload(purpose, blobStore, buildKey(blobName), inputStream, blobSize); + executeMultipartUpload(purpose, blobStore, buildKey(blobName), inputStream, blobSize, failIfAlreadyExists); } } @@ -545,7 +542,8 @@ void executeSingleUpload( final S3BlobStore s3BlobStore, final String blobName, final InputStream input, - final long blobSize + final long blobSize, + final boolean failIfAlreadyExists ) throws IOException { try (var clientReference = s3BlobStore.clientReference()) { // Extra safety checks @@ -565,6 +563,9 @@ void executeSingleUpload( if (s3BlobStore.serverSideEncryption()) { putRequestBuilder.serverSideEncryption(ServerSideEncryption.AES256); } + if (failIfAlreadyExists) { + putRequestBuilder.ifNoneMatch("*"); + } S3BlobStore.configureRequestForMetrics(putRequestBuilder, blobStore, Operation.PUT_OBJECT, purpose); final var putRequest = putRequestBuilder.build(); @@ -586,7 +587,8 @@ private void executeMultipart( final String blobName, final long partSize, final long blobSize, - final PartOperation partOperation + final PartOperation partOperation, + final boolean failIfAlreadyExists ) throws IOException { ensureMultiPartUploadSize(blobSize); @@ -639,6 +641,11 @@ private void executeMultipart( .key(blobName) .uploadId(uploadId) .multipartUpload(b -> b.parts(parts)); + + if (failIfAlreadyExists) { + completeMultipartUploadRequestBuilder.ifNoneMatch("*"); + } + S3BlobStore.configureRequestForMetrics(completeMultipartUploadRequestBuilder, blobStore, operation, purpose); final var completeMultipartUploadRequest = completeMultipartUploadRequestBuilder.build(); try (var clientReference = s3BlobStore.clientReference()) { @@ -663,7 +670,8 @@ void executeMultipartUpload( final S3BlobStore s3BlobStore, final String blobName, final InputStream input, - final long blobSize + final long blobSize, + final boolean failIfAlreadyExists ) throws IOException { executeMultipart( purpose, @@ -680,7 +688,8 @@ void executeMultipartUpload( .uploadPart(uploadRequest, RequestBody.fromInputStream(input, partSize)); return CompletedPart.builder().partNumber(partNum).eTag(uploadResponse.eTag()).build(); } - } + }, + failIfAlreadyExists ); } @@ -727,7 +736,8 @@ void executeMultipartCopy( final var uploadPartCopyResponse = clientReference.client().uploadPartCopy(uploadPartCopyRequest); return CompletedPart.builder().partNumber(partNum).eTag(uploadPartCopyResponse.copyPartResult().eTag()).build(); } - }) + }), + false ); } diff --git a/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3BlobStoreContainerTests.java b/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3BlobStoreContainerTests.java index 7a0e9fc2de855..90c9793921f95 100644 --- a/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3BlobStoreContainerTests.java +++ b/modules/repository-s3/src/test/java/org/elasticsearch/repositories/s3/S3BlobStoreContainerTests.java @@ -69,7 +69,14 @@ public void testExecuteSingleUploadBlobSizeTooLarge() { final IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> blobContainer.executeSingleUpload(randomPurpose(), blobStore, randomAlphaOfLengthBetween(1, 10), null, blobSize) + () -> blobContainer.executeSingleUpload( + randomPurpose(), + blobStore, + randomAlphaOfLengthBetween(1, 10), + null, + blobSize, + randomBoolean() + ) ); assertEquals("Upload request size [" + blobSize + "] can't be larger than 5gb", e.getMessage()); } @@ -88,7 +95,8 @@ public void testExecuteSingleUploadBlobSizeLargerThanBufferSize() { blobStore, blobName, new ByteArrayInputStream(new byte[0]), - ByteSizeUnit.MB.toBytes(2) + ByteSizeUnit.MB.toBytes(2), + randomBoolean() ) ); assertEquals("Upload request size [2097152] can't be larger than buffer size", e.getMessage()); @@ -123,6 +131,8 @@ public void testExecuteSingleUpload() throws IOException { when(blobStore.getCannedACL()).thenReturn(cannedAccessControlList); } + final boolean failIfAlreadyExists = randomBoolean(); + final S3Client client = configureMockClient(blobStore); final ArgumentCaptor requestCaptor = ArgumentCaptor.forClass(PutObjectRequest.class); @@ -131,7 +141,7 @@ public void testExecuteSingleUpload() throws IOException { when(client.putObject(requestCaptor.capture(), bodyCaptor.capture())).thenReturn(PutObjectResponse.builder().build()); final ByteArrayInputStream inputStream = new ByteArrayInputStream(new byte[blobSize]); - blobContainer.executeSingleUpload(randomPurpose(), blobStore, blobName, inputStream, blobSize); + blobContainer.executeSingleUpload(randomPurpose(), blobStore, blobName, inputStream, blobSize, failIfAlreadyExists); final PutObjectRequest request = requestCaptor.getValue(); assertEquals(bucketName, request.bucket()); @@ -147,6 +157,10 @@ public void testExecuteSingleUpload() throws IOException { ); } + if (failIfAlreadyExists) { + assertEquals("*", request.ifNoneMatch()); + } + final RequestBody requestBody = bodyCaptor.getValue(); try (var contentStream = requestBody.contentStreamProvider().newStream()) { assertEquals(inputStream.available(), blobSize); @@ -164,7 +178,14 @@ public void testExecuteMultipartUploadBlobSizeTooLarge() { final IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> blobContainer.executeMultipartUpload(randomPurpose(), blobStore, randomAlphaOfLengthBetween(1, 10), null, blobSize) + () -> blobContainer.executeMultipartUpload( + randomPurpose(), + blobStore, + randomAlphaOfLengthBetween(1, 10), + null, + blobSize, + randomBoolean() + ) ); assertEquals("Multipart upload request size [" + blobSize + "] can't be larger than 5tb", e.getMessage()); } @@ -176,7 +197,14 @@ public void testExecuteMultipartUploadBlobSizeTooSmall() { final IllegalArgumentException e = expectThrows( IllegalArgumentException.class, - () -> blobContainer.executeMultipartUpload(randomPurpose(), blobStore, randomAlphaOfLengthBetween(1, 10), null, blobSize) + () -> blobContainer.executeMultipartUpload( + randomPurpose(), + blobStore, + randomAlphaOfLengthBetween(1, 10), + null, + blobSize, + randomBoolean() + ) ); assertEquals("Multipart upload request size [" + blobSize + "] can't be smaller than 5mb", e.getMessage()); } @@ -225,6 +253,8 @@ void testExecuteMultipart(boolean doCopy) throws IOException { when(blobStore.getCannedACL()).thenReturn(cannedAccessControlList); } + final boolean failIfAlreadyExists = doCopy ? false : randomBoolean(); + final S3Client client = configureMockClient(blobStore); final var uploadId = randomIdentifier(); @@ -273,7 +303,7 @@ void testExecuteMultipart(boolean doCopy) throws IOException { if (doCopy) { blobContainer.executeMultipartCopy(randomPurpose(), sourceContainer, sourceBlobName, blobName, blobSize); } else { - blobContainer.executeMultipartUpload(randomPurpose(), blobStore, blobName, inputStream, blobSize); + blobContainer.executeMultipartUpload(randomPurpose(), blobStore, blobName, inputStream, blobSize, failIfAlreadyExists); } final CreateMultipartUploadRequest initRequest = createMultipartUploadRequestCaptor.getValue(); @@ -340,6 +370,10 @@ void testExecuteMultipart(boolean doCopy) throws IOException { assertEquals(blobPath.buildAsString() + blobName, compRequest.key()); assertEquals(uploadId, compRequest.uploadId()); + if (failIfAlreadyExists) { + assertEquals("*", compRequest.ifNoneMatch()); + } + final List actualETags = compRequest.multipartUpload() .parts() .stream() @@ -419,7 +453,14 @@ public void close() {} final IOException e = expectThrows(IOException.class, () -> { final S3BlobContainer blobContainer = new S3BlobContainer(BlobPath.EMPTY, blobStore); - blobContainer.executeMultipartUpload(randomPurpose(), blobStore, blobName, new ByteArrayInputStream(new byte[0]), blobSize); + blobContainer.executeMultipartUpload( + randomPurpose(), + blobStore, + blobName, + new ByteArrayInputStream(new byte[0]), + blobSize, + randomBoolean() + ); }); assertEquals("Unable to upload or copy object [" + blobName + "] using multipart upload", e.getMessage()); diff --git a/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsRepositoryTests.java b/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsRepositoryTests.java index 7961ca0257be8..3d75d9915bf75 100644 --- a/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsRepositoryTests.java +++ b/plugins/repository-hdfs/src/test/java/org/elasticsearch/repositories/hdfs/HdfsRepositoryTests.java @@ -62,4 +62,9 @@ protected void assertCleanupResponse(CleanupRepositoryResponse response, long by assertThat(response.result().blobs(), equalTo(0L)); } } + + @Override + public void testFailIfAlreadyExists() { + // HDFS does not implement failIfAlreadyExists correctly + } } diff --git a/test/fixtures/s3-fixture/src/main/java/fixture/s3/S3HttpHandler.java b/test/fixtures/s3-fixture/src/main/java/fixture/s3/S3HttpHandler.java index 5b0a0b03d7af4..bf7c864fd1c36 100644 --- a/test/fixtures/s3-fixture/src/main/java/fixture/s3/S3HttpHandler.java +++ b/test/fixtures/s3-fixture/src/main/java/fixture/s3/S3HttpHandler.java @@ -188,8 +188,9 @@ public void handle(final HttpExchange exchange) throws IOException { } else if (request.isCompleteMultipartUploadRequest()) { final byte[] responseBody; + boolean preconditionFailed = false; synchronized (uploads) { - final var upload = removeUpload(request.getQueryParamOnce("uploadId")); + final var upload = getUpload(request.getQueryParamOnce("uploadId")); if (upload == null) { if (Randomness.get().nextBoolean()) { responseBody = null; @@ -205,19 +206,35 @@ public void handle(final HttpExchange exchange) throws IOException { } } else { final var blobContents = upload.complete(extractPartEtags(Streams.readFully(exchange.getRequestBody()))); - blobs.put(request.path(), blobContents); - responseBody = ("\n" - + "\n" - + "" - + bucket - + "\n" - + "" - + request.path() - + "\n" - + "").getBytes(StandardCharsets.UTF_8); + + if (isProtectOverwrite(exchange)) { + var previousValue = blobs.putIfAbsent(request.path(), blobContents); + if (previousValue != null) { + preconditionFailed = true; + } + } else { + blobs.put(request.path(), blobContents); + } + + if (preconditionFailed == false) { + responseBody = ("\n" + + "\n" + + "" + + bucket + + "\n" + + "" + + request.path() + + "\n" + + "").getBytes(StandardCharsets.UTF_8); + removeUpload(upload.getUploadId()); + } else { + responseBody = null; + } } } - if (responseBody == null) { + if (preconditionFailed) { + exchange.sendResponseHeaders(RestStatus.PRECONDITION_FAILED.getStatus(), -1); + } else if (responseBody == null) { exchange.sendResponseHeaders(RestStatus.NOT_FOUND.getStatus(), -1); } else { exchange.getResponseHeaders().add("Content-Type", "application/xml"); @@ -232,6 +249,10 @@ public void handle(final HttpExchange exchange) throws IOException { // a copy request is a put request with an X-amz-copy-source header final var copySource = copySourceName(exchange); if (copySource != null) { + if (isProtectOverwrite(exchange)) { + throw new AssertionError("If-None-Match: * header is not supported here"); + } + var sourceBlob = blobs.get(copySource); if (sourceBlob == null) { exchange.sendResponseHeaders(RestStatus.NOT_FOUND.getStatus(), -1); @@ -247,9 +268,22 @@ public void handle(final HttpExchange exchange) throws IOException { } } else { final Tuple blob = parseRequestBody(exchange); - blobs.put(request.path(), blob.v2()); - exchange.getResponseHeaders().add("ETag", blob.v1()); - exchange.sendResponseHeaders(RestStatus.OK.getStatus(), -1); + boolean preconditionFailed = false; + if (isProtectOverwrite(exchange)) { + var previousValue = blobs.putIfAbsent(request.path(), blob.v2()); + if (previousValue != null) { + preconditionFailed = true; + } + } else { + blobs.put(request.path(), blob.v2()); + } + + if (preconditionFailed) { + exchange.sendResponseHeaders(RestStatus.PRECONDITION_FAILED.getStatus(), -1); + } else { + exchange.getResponseHeaders().add("ETag", blob.v1()); + exchange.sendResponseHeaders(RestStatus.OK.getStatus(), -1); + } } } else if (request.isListObjectsRequest()) { @@ -539,6 +573,24 @@ private static HttpHeaderParser.Range parsePartRange(final HttpExchange exchange return parseRangeHeader(sourceRangeHeaders.getFirst()); } + private static boolean isProtectOverwrite(final HttpExchange exchange) { + final var ifNoneMatch = exchange.getRequestHeaders().get("If-None-Match"); + + if (ifNoneMatch == null) { + return false; + } + + if (ifNoneMatch.size() != 1) { + throw new AssertionError("multiple If-None-Match headers found: " + ifNoneMatch); + } + + if (ifNoneMatch.getFirst().equals("*")) { + return true; + } + + throw new AssertionError("invalid If-None-Match header: " + ifNoneMatch); + } + MultipartUpload putUpload(String path) { final var upload = new MultipartUpload(UUIDs.randomBase64UUID(), path); synchronized (uploads) { diff --git a/test/fixtures/s3-fixture/src/test/java/fixture/s3/S3HttpHandlerTests.java b/test/fixtures/s3-fixture/src/test/java/fixture/s3/S3HttpHandlerTests.java index f9e36eacf0a77..5da274f798333 100644 --- a/test/fixtures/s3-fixture/src/test/java/fixture/s3/S3HttpHandlerTests.java +++ b/test/fixtures/s3-fixture/src/test/java/fixture/s3/S3HttpHandlerTests.java @@ -32,10 +32,14 @@ import java.util.ArrayList; import java.util.List; import java.util.Objects; +import java.util.concurrent.Executors; +import java.util.concurrent.TimeUnit; +import java.util.function.Consumer; import static org.hamcrest.Matchers.allOf; import static org.hamcrest.Matchers.containsString; import static org.hamcrest.Matchers.greaterThan; +import static org.hamcrest.Matchers.hasSize; public class S3HttpHandlerTests extends ESTestCase { @@ -383,6 +387,91 @@ public void testExtractPartEtags() { } + public void testPreventObjectOverwrite() throws InterruptedException { + final var handler = new S3HttpHandler("bucket", "path"); + + var tasks = List.of( + createPutObjectTask(handler), + createPutObjectTask(handler), + createMultipartUploadTask(handler), + createMultipartUploadTask(handler) + ); + + try (var executor = Executors.newVirtualThreadPerTaskExecutor()) { + tasks.forEach(task -> executor.submit(task.consumer)); + executor.shutdown(); + var done = executor.awaitTermination(SAFE_AWAIT_TIMEOUT.seconds(), TimeUnit.SECONDS); + assertTrue(done); + } + + List successfulTasks = tasks.stream().filter(task -> task.status == RestStatus.OK).toList(); + assertThat(successfulTasks, hasSize(1)); + + tasks.stream().filter(task -> task.uploadId != null).forEach(task -> { + if (task.status == RestStatus.PRECONDITION_FAILED) { + assertNotNull(handler.getUpload(task.uploadId)); + } else { + assertNull(handler.getUpload(task.uploadId)); + } + }); + + assertEquals( + new TestHttpResponse(RestStatus.OK, successfulTasks.getFirst().body, TestHttpExchange.EMPTY_HEADERS), + handleRequest(handler, "GET", "/bucket/path/blob") + ); + } + + private static TestWriteTask createPutObjectTask(S3HttpHandler handler) { + return new TestWriteTask( + (task) -> task.status = handleRequest(handler, "PUT", "/bucket/path/blob", task.body, ifNoneMatchHeader()).status() + ); + } + + private static TestWriteTask createMultipartUploadTask(S3HttpHandler handler) { + final var multipartUploadTask = new TestWriteTask( + (task) -> task.status = handleRequest( + handler, + "POST", + "/bucket/path/blob?uploadId=" + task.uploadId, + new BytesArray(Strings.format(""" + + + + %s + 1 + + """, task.etag)), + ifNoneMatchHeader() + ).status() + ); + + final var createUploadResponse = handleRequest(handler, "POST", "/bucket/path/blob?uploads"); + multipartUploadTask.uploadId = getUploadId(createUploadResponse.body()); + + final var uploadPart1Response = handleRequest( + handler, + "PUT", + "/bucket/path/blob?uploadId=" + multipartUploadTask.uploadId + "&partNumber=1", + multipartUploadTask.body + ); + multipartUploadTask.etag = Objects.requireNonNull(uploadPart1Response.etag()); + + return multipartUploadTask; + } + + private static class TestWriteTask { + final BytesReference body; + final Runnable consumer; + String uploadId; + String etag; + RestStatus status; + + TestWriteTask(Consumer consumer) { + this.body = randomBytesReference(50); + this.consumer = () -> consumer.accept(this); + } + } + private void runExtractPartETagsTest(String body, String... expectedTags) { assertEquals(List.of(expectedTags), S3HttpHandler.extractPartEtags(new BytesArray(body.getBytes(StandardCharsets.UTF_8)))); } @@ -467,6 +556,12 @@ private static Headers contentRangeHeader(long start, long end, long length) { return headers; } + private static Headers ifNoneMatchHeader() { + var headers = new Headers(); + headers.put("If-None-Match", List.of("*")); + return headers; + } + private static class TestHttpExchange extends HttpExchange { private static final Headers EMPTY_HEADERS = new Headers(); diff --git a/test/framework/src/main/java/org/elasticsearch/repositories/AbstractThirdPartyRepositoryTestCase.java b/test/framework/src/main/java/org/elasticsearch/repositories/AbstractThirdPartyRepositoryTestCase.java index b9f2e797f71dd..8cd3aa1b15dfc 100644 --- a/test/framework/src/main/java/org/elasticsearch/repositories/AbstractThirdPartyRepositoryTestCase.java +++ b/test/framework/src/main/java/org/elasticsearch/repositories/AbstractThirdPartyRepositoryTestCase.java @@ -346,6 +346,55 @@ public void testSkipBeyondBlobLengthShouldThrowEOFException() throws IOException } } + public void testFailIfAlreadyExists() { + final var blobName = randomIdentifier(); + final int blobLength = randomIntBetween(100, 2_000); + final var initialBlobBytes = randomBytesReference(blobLength); + final var overwriteBlobBytes = randomBytesReference(blobLength); + + final var repository = getRepository(); + + CheckedFunction initialWrite = blobStore -> { + blobStore.writeBlobAtomic(randomPurpose(), blobName, initialBlobBytes, true); + return null; + }; + + // initial write blob + var initialWrite1 = submitOnBlobStore(repository, initialWrite); + var initialWrite2 = submitOnBlobStore(repository, initialWrite); + + Exception ex1 = null; + Exception ex2 = null; + + try { + initialWrite1.actionGet(); + } catch (Exception e) { + ex1 = e; + } + + try { + initialWrite2.actionGet(); + } catch (Exception e) { + ex2 = e; + } + + assertTrue("Exactly one of the writes must succeed", (ex1 == null) != (ex2 == null)); + + // override if failIfAlreadyExists is set to false + executeOnBlobStore(repository, blobStore -> { + blobStore.writeBlob(randomPurpose(), blobName, overwriteBlobBytes, false); + return null; + }); + + assertEquals(overwriteBlobBytes, readBlob(repository, blobName, 0, overwriteBlobBytes.length())); + + // throw exception if failIfAlreadyExists is set to true + executeOnBlobStore(repository, blobStore -> { + expectThrows(Exception.class, () -> blobStore.writeBlob(randomPurpose(), blobName, initialBlobBytes, true)); + return null; + }); + } + protected void testReadFromPositionLargerThanBlobLength(Predicate responseCodeChecker) { final var blobName = randomIdentifier(); final var blobBytes = randomBytesReference(randomIntBetween(100, 2_000)); @@ -381,12 +430,20 @@ protected void testReadFromPositionLargerThanBlobLength(Predicate T executeOnBlobStore(BlobStoreRepository repository, CheckedFunction fn) { + protected static PlainActionFuture submitOnBlobStore( + BlobStoreRepository repository, + CheckedFunction fn + ) { final var future = new PlainActionFuture(); repository.threadPool().generic().execute(ActionRunnable.supply(future, () -> { var blobContainer = repository.blobStore().blobContainer(repository.basePath()); return fn.apply(blobContainer); })); + return future; + } + + protected static T executeOnBlobStore(BlobStoreRepository repository, CheckedFunction fn) { + final var future = submitOnBlobStore(repository, fn); return future.actionGet(); }