S3 compare-and-exchange implementation (#94150)
Adds an implementation of `compareAndExchangeRegister` to
`S3BlobContainer`.
DaveCTurner committed Feb 28, 2023
1 parent 1e405db commit 49d5cd7
Showing 9 changed files with 901 additions and 67 deletions.
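
For context, a "register" here is a small blob holding a single long value that behaves as an atomic cell. A minimal caller-side sketch of the API this commit implements (hypothetical caller code, not part of the diff; assumes a BlobContainer `container` backed by this implementation and uses org.elasticsearch.action.support.PlainActionFuture):

// Try to advance the register "leader-term" from 1 to 2. The listener receives the witnessed
// value: equal to `expected` means the exchange succeeded, another present value means a rival
// won, and an empty OptionalLong means the attempt was abandoned because of concurrent activity.
PlainActionFuture<OptionalLong> future = PlainActionFuture.newFuture();
container.compareAndExchangeRegister("leader-term", 1L, 2L, future);
OptionalLong witness = future.actionGet();
boolean success = witness.isPresent() && witness.getAsLong() == 1L;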
@@ -9,12 +9,17 @@
package org.elasticsearch.repositories.s3;

import com.amazonaws.AmazonClientException;
import com.amazonaws.services.s3.AmazonS3;
import com.amazonaws.services.s3.model.AbortMultipartUploadRequest;
import com.amazonaws.services.s3.model.AmazonS3Exception;
import com.amazonaws.services.s3.model.CompleteMultipartUploadRequest;
import com.amazonaws.services.s3.model.DeleteObjectsRequest;
import com.amazonaws.services.s3.model.GetObjectRequest;
import com.amazonaws.services.s3.model.InitiateMultipartUploadRequest;
import com.amazonaws.services.s3.model.ListMultipartUploadsRequest;
import com.amazonaws.services.s3.model.ListObjectsRequest;
import com.amazonaws.services.s3.model.MultiObjectDeleteException;
import com.amazonaws.services.s3.model.MultipartUpload;
import com.amazonaws.services.s3.model.ObjectListing;
import com.amazonaws.services.s3.model.ObjectMetadata;
import com.amazonaws.services.s3.model.PartETag;
@@ -28,6 +33,9 @@
import org.apache.lucene.util.SetOnce;
import org.elasticsearch.ExceptionsHelper;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.action.ActionRunnable;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.common.Randomness;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.blobstore.BlobContainer;
import org.elasticsearch.common.blobstore.BlobPath;
@@ -41,8 +49,10 @@
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.core.CheckedConsumer;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.repositories.blobstore.ChunkedBlobOutputStream;
import org.elasticsearch.threadpool.ThreadPool;

import java.io.ByteArrayInputStream;
import java.io.IOException;
@@ -54,11 +64,14 @@
import java.util.List;
import java.util.Map;
import java.util.OptionalLong;
import java.util.concurrent.atomic.AtomicBoolean;
import java.util.concurrent.atomic.AtomicLong;
import java.util.concurrent.atomic.AtomicReference;
import java.util.function.Function;
import java.util.stream.Collectors;

import static org.elasticsearch.common.blobstore.support.BlobContainerUtils.getRegisterBlobContents;
import static org.elasticsearch.common.blobstore.support.BlobContainerUtils.getRegisterUsingConsistentRead;
import static org.elasticsearch.core.Strings.format;
import static org.elasticsearch.repositories.s3.S3Repository.MAX_FILE_SIZE;
import static org.elasticsearch.repositories.s3.S3Repository.MAX_FILE_SIZE_USING_MULTIPART;
@@ -610,18 +623,212 @@ static Tuple<Long, Long> numberOfMultiparts(final long totalSize, final long partSize)
}
}

- @Override
- public void compareAndExchangeRegister(String key, long expected, long updated, ActionListener<OptionalLong> listener) {
- listener.onFailure(new UnsupportedOperationException()); // TODO
private class CompareAndExchangeOperation {
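// Encapsulates a single compare-and-exchange attempt on a register blob, built out of
// S3 multipart-upload primitives; see run() below for the algorithm.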

private final AmazonS3 client;
private final String bucket;
private final String rawKey;
private final String blobKey;
private final ThreadPool threadPool;

CompareAndExchangeOperation(AmazonS3 client, String bucket, String key, ThreadPool threadPool) {
this.client = client;
this.bucket = bucket;
this.rawKey = key;
this.blobKey = buildKey(key);
this.threadPool = threadPool;
}

private List<MultipartUpload> listMultipartUploads() {
final var listRequest = new ListMultipartUploadsRequest(bucket);
listRequest.setPrefix(blobKey);
listRequest.setRequestMetricCollector(blobStore.listMetricCollector);
try {
return SocketAccess.doPrivileged(() -> client.listMultipartUploads(listRequest)).getMultipartUploads();
} catch (AmazonS3Exception e) {
if (e.getStatusCode() == 404) {
return List.of();
}
throw e;
}
}

private int getUploadIndex(String targetUploadId, List<MultipartUpload> multipartUploads) {
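// Rank our upload among the competing multipart uploads by comparing upload ids;
// -1 means our upload id is no longer listed, i.e. a competitor already aborted it.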
var uploadIndex = 0;
var found = false;
for (MultipartUpload multipartUpload : multipartUploads) {
final var observedUploadId = multipartUpload.getUploadId();
if (observedUploadId.equals(targetUploadId)) {
found = true;
} else if (observedUploadId.compareTo(targetUploadId) < 0) {
uploadIndex += 1;
}
}

return found ? uploadIndex : -1;
}

void run(long expected, long updated, ActionListener<OptionalLong> listener) throws Exception {

if (listMultipartUploads().isEmpty() == false) {
// TODO What if the previous writer crashed? We should consider the age of any ongoing uploads before bailing out like this.
listener.onResponse(OptionalLong.empty());
return;
}
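
// One possible shape for that TODO (an assumption, not part of this commit): only bail out
// if some listed upload is recent, e.g.
//   uploads.stream().anyMatch(u -> u.getInitiated().toInstant()
//       .isAfter(Instant.now().minus(Duration.ofHours(1))))
// using MultipartUpload#getInitiated to estimate the age of each competing upload.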

final var blobContents = getRegisterBlobContents(updated);

final var initiateRequest = new InitiateMultipartUploadRequest(bucket, blobKey);
initiateRequest.setRequestMetricCollector(blobStore.multiPartUploadMetricCollector);
final var uploadId = SocketAccess.doPrivileged(() -> client.initiateMultipartUpload(initiateRequest)).getUploadId();

final var uploadPartRequest = new UploadPartRequest();
uploadPartRequest.setBucketName(bucket);
uploadPartRequest.setKey(blobKey);
uploadPartRequest.setUploadId(uploadId);
uploadPartRequest.setPartNumber(1);
uploadPartRequest.setLastPart(true);
uploadPartRequest.setInputStream(blobContents.streamInput());
uploadPartRequest.setPartSize(blobContents.length());
uploadPartRequest.setRequestMetricCollector(blobStore.multiPartUploadMetricCollector);
final var partETag = SocketAccess.doPrivileged(() -> client.uploadPart(uploadPartRequest)).getPartETag();
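
// At this point the new value is staged in an uncompleted multipart upload, which is
// invisible to readers: nothing becomes observable at the blob key until
// completeMultipartUpload runs, which is what lets completion act as the commit point.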

final var currentUploads = listMultipartUploads();
final var uploadIndex = getUploadIndex(uploadId, currentUploads);

if (uploadIndex < 0) {
// already aborted by someone else
listener.onResponse(OptionalLong.empty());
return;
}

final var isComplete = new AtomicBoolean();
final Runnable doCleanup = () -> {
if (isComplete.compareAndSet(false, true)) {
try {
abortMultipartUploadIfExists(uploadId);
} catch (Exception e) {
// cleanup is best-effort; we can't do anything better than log and fall through here
logger.error("unexpected error cleaning up upload [" + uploadId + "] of [" + blobKey + "]", e);
assert false : e;
}
}
};
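
// The wiring below: re-read the register via getRegister and, only if it still holds the
// expected value, complete the multipart upload so the new value becomes visible; on any
// other outcome the upload is left for doCleanup, which aborts it unless marked complete.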

try (
var listeners = new RefCountingListener(
ActionListener.runAfter(
listener.delegateFailure(
(delegate1, ignored) -> getRegister(
rawKey,
delegate1.delegateFailure((delegate2, currentValue) -> ActionListener.completeWith(delegate2, () -> {
if (currentValue.isPresent() && currentValue.getAsLong() == expected) {
final var completeMultipartUploadRequest = new CompleteMultipartUploadRequest(
bucket,
blobKey,
uploadId,
List.of(partETag)
);
completeMultipartUploadRequest.setRequestMetricCollector(blobStore.multiPartUploadMetricCollector);
SocketAccess.doPrivilegedVoid(() -> client.completeMultipartUpload(completeMultipartUploadRequest));
isComplete.set(true);
}
return currentValue;
}))
)
),
doCleanup
)
)
) {
if (currentUploads.size() > 1) {
// This is a small optimization to improve the liveness properties of this algorithm.
//
// When there are multiple competing updates we order them by upload id, and the first one tries to cancel the
// competing updates so that it can make progress. To avoid liveness issues when that winner fails, the rest wait
// for a delay proportional to their upload-id-based position and then do the same.
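
// Concretely: the writer ranked k waits k seconds plus up to 50ms of jitter (see the
// scheduling below) before aborting its rivals, so a crashed front-runner only delays
// the remaining writers briefly rather than blocking them forever.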

var delayListener = listeners.acquire();
final Runnable cancelConcurrentUpdates = () -> {
try {
for (MultipartUpload currentUpload : currentUploads) {
final var currentUploadId = currentUpload.getUploadId();
if (uploadId.equals(currentUploadId) == false) {
threadPool.executor(ThreadPool.Names.SNAPSHOT)
.execute(
ActionRunnable.run(listeners.acquire(), () -> abortMultipartUploadIfExists(currentUploadId))
);
}
}
} finally {
delayListener.onResponse(null);
}
};

if (uploadIndex > 0) {
threadPool.scheduleUnlessShuttingDown(
TimeValue.timeValueMillis(TimeValue.timeValueSeconds(uploadIndex).millis() + Randomness.get().nextInt(50)),
ThreadPool.Names.SNAPSHOT,
cancelConcurrentUpdates
);
} else {
cancelConcurrentUpdates.run();
}
}
}
}

private void abortMultipartUploadIfExists(String uploadId) {
try {
final var request = new AbortMultipartUploadRequest(bucket, blobKey, uploadId);
SocketAccess.doPrivilegedVoid(() -> client.abortMultipartUpload(request));
} catch (AmazonS3Exception e) {
if (e.getStatusCode() != 404) {
throw e;
}
// else already aborted
}
}

}

@Override
- public void compareAndSetRegister(String key, long expected, long updated, ActionListener<Boolean> listener) {
- listener.onFailure(new UnsupportedOperationException()); // TODO
public void compareAndExchangeRegister(String key, long expected, long updated, ActionListener<OptionalLong> listener) {
final var clientReference = blobStore.clientReference();
ActionListener.run(ActionListener.releaseAfter(listener.delegateResponse((delegate, e) -> {
if (e instanceof AmazonS3Exception amazonS3Exception && amazonS3Exception.getStatusCode() == 404) {
// an uncaught 404 means that our multipart upload was aborted by a concurrent operation before we could complete it
delegate.onResponse(OptionalLong.empty());
} else {
delegate.onFailure(e);
}
}), clientReference),
l -> new CompareAndExchangeOperation(clientReference.client(), blobStore.bucket(), key, blobStore.getThreadPool()).run(
expected,
updated,
l
)
);
}

@Override
public void getRegister(String key, ActionListener<OptionalLong> listener) {
- listener.onFailure(new UnsupportedOperationException()); // TODO
ActionListener.completeWith(listener, () -> {
final var getObjectRequest = new GetObjectRequest(blobStore.bucket(), buildKey(key));
getObjectRequest.setRequestMetricCollector(blobStore.getMetricCollector);
try (
var clientReference = blobStore.clientReference();
var s3Object = SocketAccess.doPrivileged(() -> clientReference.client().getObject(getObjectRequest));
var stream = s3Object.getObjectContent()
) {
return OptionalLong.of(getRegisterUsingConsistentRead(stream, keyPath, key));
} catch (AmazonS3Exception e) {
if (e.getStatusCode() == 404) {
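// an absent blob means the register has never been written, so it reads as zero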
return OptionalLong.of(0L);
} else {
throw e;
}
}
});
}
}
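
Note that the old compareAndSetRegister stub is deleted above rather than reimplemented: a boolean compare-and-set follows directly from compare-and-exchange, along these lines (an illustrative sketch, not code from this commit):

@Override
public void compareAndSetRegister(String key, long expected, long updated, ActionListener<Boolean> listener) {
    // the set succeeded iff the value witnessed by the exchange is exactly the expected one
    compareAndExchangeRegister(key, expected, updated,
        listener.map(witness -> witness.isPresent() && witness.getAsLong() == expected));
}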
@@ -24,6 +24,7 @@
import org.elasticsearch.common.blobstore.BlobStoreException;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.util.BigArrays;
import org.elasticsearch.threadpool.ThreadPool;

import java.io.IOException;
import java.util.HashMap;
@@ -51,6 +52,8 @@ class S3BlobStore implements BlobStore {

private final RepositoryMetadata repositoryMetadata;

private final ThreadPool threadPool;

private final Stats stats = new Stats();

final RequestMetricCollector getMetricCollector;
@@ -66,7 +69,8 @@ class S3BlobStore implements BlobStore {
String cannedACL,
String storageClass,
RepositoryMetadata repositoryMetadata,
- BigArrays bigArrays
BigArrays bigArrays,
ThreadPool threadPool
) {
this.service = service;
this.bigArrays = bigArrays;
@@ -76,6 +80,7 @@
this.cannedACL = initCannedACL(cannedACL);
this.storageClass = initStorageClass(storageClass);
this.repositoryMetadata = repositoryMetadata;
this.threadPool = threadPool;
this.getMetricCollector = new IgnoreNoResponseMetricsCollector() {
@Override
public void collectMetrics(Request<?> request) {
@@ -215,6 +220,10 @@ public static CannedAccessControlList initCannedACL(String cannedACL) {
throw new BlobStoreException("cannedACL is not valid: [" + cannedACL + "]");
}

ThreadPool getThreadPool() {
return threadPool;
}

static class Stats {

final AtomicLong listCount = new AtomicLong();
@@ -408,7 +408,7 @@ private static BlobPath buildBasePath(RepositoryMetadata metadata) {

@Override
protected S3BlobStore createBlobStore() {
- return new S3BlobStore(service, bucket, serverSideEncryption, bufferSize, cannedACL, storageClass, metadata, bigArrays);
return new S3BlobStore(service, bucket, serverSideEncryption, bufferSize, cannedACL, storageClass, metadata, bigArrays, threadPool);
}

// only use for testing
@@ -153,7 +153,8 @@ protected BlobContainer createBlobContainer(
S3Repository.CANNED_ACL_SETTING.getDefault(Settings.EMPTY),
S3Repository.STORAGE_CLASS_SETTING.getDefault(Settings.EMPTY),
repositoryMetadata,
- BigArrays.NON_RECYCLING_INSTANCE
BigArrays.NON_RECYCLING_INSTANCE,
null
)
) {
@Override