Ensure CCR partial reads never overuse buffer (#58620)
When the documents are large, a follower can receive a partial response
because the requested range of operations is capped by
max_read_request_size instead of max_read_request_operation_count. In
this case, the follower will continue reading the subsequent ranges
without checking the remaining size of the buffer. The buffer can then
use more memory than max_write_buffer_size and even cause an OOM.

Backport of #58620
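
The mechanics of the fix, for readers skimming the diff: instead of immediately issuing a follow-up read when a response covers only part of the requested range, the remainder of the range is parked in a priority queue and resumed from coordinateReads() only while the write buffer still has room. Below is a minimal, self-contained sketch of that idea; PartialReadCoordinator, Range, and sendReadRequest are hypothetical stand-ins for ShardFollowNodeTask, Tuple, and sendShardChangesRequest, and the budget check is simplified (the real one also tracks buffer size in bytes).

import java.util.ArrayDeque;
import java.util.Comparator;
import java.util.PriorityQueue;
import java.util.Queue;

// Hypothetical, simplified follower-side read coordinator illustrating the fix.
class PartialReadCoordinator {

    // Inclusive range of sequence numbers that still has to be fetched.
    record Range(long fromSeqNo, long toSeqNo) {}

    private final Queue<Range> partialReadRequests =
        new PriorityQueue<>(Comparator.comparingLong(Range::fromSeqNo));
    private final Queue<Long> buffer = new ArrayDeque<>(); // buffered operations (seq_nos only)
    private final int maxWriteBufferCount;
    private final int maxOutstandingReads;
    private int outstandingReads = 0;

    PartialReadCoordinator(int maxWriteBufferCount, int maxOutstandingReads) {
        this.maxWriteBufferCount = maxWriteBufferCount;
        this.maxOutstandingReads = maxOutstandingReads;
    }

    private boolean hasReadBudget() {
        // No new reads once the buffer is full or too many reads are already in flight.
        return outstandingReads < maxOutstandingReads && buffer.size() < maxWriteBufferCount;
    }

    // Resume deferred partial ranges first, but only while there is budget left.
    synchronized void coordinateReads() {
        while (hasReadBudget() && partialReadRequests.isEmpty() == false) {
            Range range = partialReadRequests.remove();
            outstandingReads++;
            sendReadRequest(range.fromSeqNo(), range.toSeqNo());
        }
        // ...then start brand-new ranges up to the leader's global checkpoint (omitted here).
    }

    // Handle a response that may cover only a prefix of the requested range.
    synchronized void handleReadResponse(long from, long maxRequiredSeqNo, long[] receivedSeqNos) {
        for (long seqNo : receivedSeqNos) {
            buffer.add(seqNo);
        }
        long newFrom = receivedSeqNos.length == 0 ? from : receivedSeqNos[receivedSeqNos.length - 1] + 1;
        if (newFrom <= maxRequiredSeqNo) {
            // Park the remainder instead of re-reading immediately; it will be resumed
            // by coordinateReads() once the buffer has drained below its limits.
            partialReadRequests.add(new Range(newFrom, maxRequiredSeqNo));
        }
        outstandingReads--;
        coordinateReads();
    }

    void sendReadRequest(long fromSeqNo, long toSeqNo) {
        // In the real task this issues a ShardChangesAction request to the leader shard.
        System.out.printf("read [%d..%d]%n", fromSeqNo, toSeqNo);
    }
}

The key design point is that a partial response always releases its read slot and then competes for budget again like any other read, so a slowly draining write buffer throttles partial reads the same way it throttles new ranges. In the real change, the write side re-triggers coordinateReads() as the buffer empties, which is what the new testHandlePartialResponses exercises via pendingBulkShardRequests.
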
dnhatn committed Jul 1, 2020
1 parent f57743e commit 138e3c7
Showing 3 changed files with 83 additions and 11 deletions.
@@ -90,6 +90,7 @@ public abstract class ShardFollowNodeTask extends AllocatedPersistentTask {
private long failedWriteRequests = 0;
private long operationWritten = 0;
private long lastFetchTime = -1;
private final Queue<Tuple<Long, Long>> partialReadRequests = new PriorityQueue<>(Comparator.comparing(Tuple::v1));
private final Queue<Translog.Operation> buffer = new PriorityQueue<>(Comparator.comparing(Translog.Operation::seqNo));
private long bufferSizeInBytes = 0;
private final LinkedHashMap<Long, Tuple<AtomicInteger, ElasticsearchException>> fetchExceptions;
@@ -175,6 +176,20 @@ synchronized void coordinateReads() {

LOGGER.trace("{} coordinate reads, lastRequestedSeqNo={}, leaderGlobalCheckpoint={}",
params.getFollowShardId(), lastRequestedSeqNo, leaderGlobalCheckpoint);
assert partialReadRequests.size() <= params.getMaxOutstandingReadRequests() :
"too many partial read requests [" + partialReadRequests + "]";
while (hasReadBudget() && partialReadRequests.isEmpty() == false) {
final Tuple<Long, Long> range = partialReadRequests.remove();
assert range.v1() <= range.v2() && range.v2() <= lastRequestedSeqNo :
"invalid partial range [" + range.v1() + "," + range.v2() + "]; last requested seq_no [" + lastRequestedSeqNo + "]";
final long fromSeqNo = range.v1();
final long maxRequiredSeqNo = range.v2();
final int requestOpCount = Math.toIntExact(maxRequiredSeqNo - fromSeqNo + 1);
LOGGER.trace("{}[{} ongoing reads] continue partial read request from_seqno={} max_required_seqno={} batch_count={}",
params.getFollowShardId(), numOutstandingReads, fromSeqNo, maxRequiredSeqNo, requestOpCount);
numOutstandingReads++;
sendShardChangesRequest(fromSeqNo, requestOpCount, maxRequiredSeqNo);
}
final int maxReadRequestOperationCount = params.getMaxReadRequestOperationCount();
while (hasReadBudget() && lastRequestedSeqNo < leaderGlobalCheckpoint) {
final long from = lastRequestedSeqNo + 1;
@@ -190,8 +205,8 @@ synchronized void coordinateReads() {
LOGGER.trace("{}[{} ongoing reads] read from_seqno={} max_required_seqno={} batch_count={}",
params.getFollowShardId(), numOutstandingReads, from, maxRequiredSeqNo, requestOpCount);
numOutstandingReads++;
sendShardChangesRequest(from, requestOpCount, maxRequiredSeqNo);
lastRequestedSeqNo = maxRequiredSeqNo;
sendShardChangesRequest(from, requestOpCount, maxRequiredSeqNo);
}

if (numOutstandingReads == 0 && hasReadBudget()) {
@@ -207,6 +222,9 @@

private boolean hasReadBudget() {
assert Thread.holdsLock(this);
// TODO: To ensure that we never overuse the buffer, we need to
// - Overestimate the size and count of the responses of the outstanding request when calculating the budget
// - Limit the size and count of next read requests by the remaining size and count of the buffer
if (numOutstandingReads >= params.getMaxOutstandingReadRequests()) {
LOGGER.trace("{} no new reads, maximum number of concurrent reads have been reached [{}]",
params.getFollowShardId(), numOutstandingReads);
@@ -216,7 +234,7 @@ private boolean hasReadBudget() {
LOGGER.trace("{} no new reads, buffer size limit has been reached [{}]", params.getFollowShardId(), bufferSizeInBytes);
return false;
}
if (buffer.size() > params.getMaxWriteBufferCount()) {
if (buffer.size() >= params.getMaxWriteBufferCount()) {
LOGGER.trace("{} no new reads, buffer count limit has been reached [{}]", params.getFollowShardId(), buffer.size());
return false;
}
@@ -359,16 +377,13 @@ synchronized void innerHandleReadResponse(long from, long maxRequiredSeqNo, ShardChangesAction.Response response) {
"] is larger than the global checkpoint [" + leaderGlobalCheckpoint + "]";
coordinateWrites();
}
if (newFromSeqNo <= maxRequiredSeqNo && isStopped() == false) {
int newSize = Math.toIntExact(maxRequiredSeqNo - newFromSeqNo + 1);
LOGGER.trace("{} received [{}] ops, still missing [{}/{}], continuing to read...",
if (newFromSeqNo <= maxRequiredSeqNo) {
LOGGER.trace("{} received [{}] operations, enqueue partial read request [{}/{}]",
params.getFollowShardId(), response.getOperations().length, newFromSeqNo, maxRequiredSeqNo);
sendShardChangesRequest(newFromSeqNo, newSize, maxRequiredSeqNo);
} else {
// read is completed, decrement
numOutstandingReads--;
coordinateReads();
partialReadRequests.add(Tuple.tuple(newFromSeqNo, maxRequiredSeqNo));
}
numOutstandingReads--;
coordinateReads();
}

private void sendBulkShardOperationsRequest(List<Translog.Operation> operations, long leaderMaxSeqNoOfUpdatesOrDeletes,
@@ -42,6 +42,7 @@
import org.elasticsearch.common.collect.ImmutableOpenMap;
import org.elasticsearch.common.network.NetworkModule;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.unit.ByteSizeValue;
import org.elasticsearch.common.unit.TimeValue;
import org.elasticsearch.common.xcontent.XContentBuilder;
import org.elasticsearch.core.internal.io.IOUtils;
@@ -480,6 +481,8 @@ public static PutFollowAction.Request putFollow(String leaderIndex, String follo
request.setFollowerIndex(followerIndex);
request.getParameters().setMaxRetryDelay(TimeValue.timeValueMillis(10));
request.getParameters().setReadPollTimeout(TimeValue.timeValueMillis(10));
request.getParameters().setMaxReadRequestSize(new ByteSizeValue(between(1, 32 * 1024 * 1024)));
request.getParameters().setMaxReadRequestOperationCount(between(1, 10000));
request.waitForActiveShards(waitForActiveShards);
return request;
}
@@ -6,6 +6,7 @@
package org.elasticsearch.xpack.ccr.action;

import org.elasticsearch.ElasticsearchException;
import org.elasticsearch.action.ActionListener;
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.collect.Tuple;
import org.elasticsearch.common.settings.Settings;
@@ -77,6 +78,7 @@ public class ShardFollowNodeTaskTests extends ESTestCase {
private Queue<Long> followerGlobalCheckpoints;
private Queue<Long> maxSeqNos;
private Queue<Integer> responseSizes;
private Queue<ActionListener<BulkShardOperationsResponse>> pendingBulkShardRequests;

public void testCoordinateReads() {
ShardFollowTaskParams params = new ShardFollowTaskParams();
@@ -597,6 +599,55 @@ public void testReceiveNothingExpectedSomething() {
assertThat(status.leaderGlobalCheckpoint(), equalTo(63L));
}

public void testHandlePartialResponses() {
ShardFollowTaskParams params = new ShardFollowTaskParams();
params.maxReadRequestOperationCount = 10;
params.maxOutstandingReadRequests = 2;
params.maxOutstandingWriteRequests = 1;
params.maxWriteBufferCount = 3;

ShardFollowNodeTask task = createShardFollowTask(params);
startTask(task, 99, -1);

task.coordinateReads();
assertThat(shardChangesRequests.size(), equalTo(2));
assertThat(shardChangesRequests.get(0)[0], equalTo(0L));
assertThat(shardChangesRequests.get(0)[1], equalTo(10L));
assertThat(shardChangesRequests.get(1)[0], equalTo(10L));
assertThat(shardChangesRequests.get(1)[1], equalTo(10L));

task.innerHandleReadResponse(0L, 9L, generateShardChangesResponse(0L, 5L, 0L, 0L, 99L));
assertThat(pendingBulkShardRequests, hasSize(1));
assertThat("continue the partial request", shardChangesRequests, hasSize(3));
assertThat(shardChangesRequests.get(2)[0], equalTo(6L));
assertThat(shardChangesRequests.get(2)[1], equalTo(4L));
assertThat(pendingBulkShardRequests, hasSize(1));
task.innerHandleReadResponse(10L, 19L, generateShardChangesResponse(10L, 17L, 0L, 0L, 99L));
assertThat("do not continue partial reads as the buffer is full", shardChangesRequests, hasSize(3));
task.innerHandleReadResponse(6L, 9L, generateShardChangesResponse(6L, 8L, 0L, 0L, 99L));
assertThat("do not continue partial reads as the buffer is full", shardChangesRequests, hasSize(3));
pendingBulkShardRequests.remove().onResponse(new BulkShardOperationsResponse());
assertThat(pendingBulkShardRequests, hasSize(1));

assertThat("continue two partial requests as the buffer is empty after sending", shardChangesRequests, hasSize(5));
assertThat(shardChangesRequests.get(3)[0], equalTo(9L));
assertThat(shardChangesRequests.get(3)[1], equalTo(1L));
assertThat(shardChangesRequests.get(4)[0], equalTo(18L));
assertThat(shardChangesRequests.get(4)[1], equalTo(2L));

task.innerHandleReadResponse(18L, 19L, generateShardChangesResponse(18L, 19L, 0L, 0L, 99L));
assertThat("start new range as the buffer has empty slots", shardChangesRequests, hasSize(6));
assertThat(shardChangesRequests.get(5)[0], equalTo(20L));
assertThat(shardChangesRequests.get(5)[1], equalTo(10L));

task.innerHandleReadResponse(9L, 9L, generateShardChangesResponse(9L, 9L, 0L, 0L, 99L));
assertThat("do not start new range as the buffer is full", shardChangesRequests, hasSize(6));
pendingBulkShardRequests.remove().onResponse(new BulkShardOperationsResponse());
assertThat("start new range as the buffer is empty after sending", shardChangesRequests, hasSize(7));
assertThat(shardChangesRequests.get(6)[0], equalTo(30L));
assertThat(shardChangesRequests.get(6)[1], equalTo(10L));
}

public void testMappingUpdate() {
ShardFollowTaskParams params = new ShardFollowTaskParams();
params.maxReadRequestOperationCount = 64;
@@ -909,7 +960,7 @@ public void testMaxWriteRequestSize() {

ShardChangesAction.Response response = generateShardChangesResponse(0, 63, 0L, 0L, 64L);
// Also invokes coordinatesWrites()
task.innerHandleReadResponse(0L, 64L, response);
task.innerHandleReadResponse(0L, 63L, response);

assertThat(bulkShardOperationRequests.size(), equalTo(64));
}
@@ -1033,6 +1084,7 @@ private ShardFollowNodeTask createShardFollowTask(ShardFollowTaskParams params)
followerGlobalCheckpoints = new LinkedList<>();
maxSeqNos = new LinkedList<>();
responseSizes = new LinkedList<>();
pendingBulkShardRequests = new LinkedList<>();
return new ShardFollowNodeTask(
1L, "type", ShardFollowTask.NAME, "description", null, Collections.emptyMap(), followTask, scheduler, System::nanoTime) {

@@ -1082,6 +1134,8 @@ protected void innerSendBulkShardOperationsRequest(
response.setGlobalCheckpoint(followerGlobalCheckpoint);
response.setMaxSeqNo(followerGlobalCheckpoint);
handler.accept(response);
} else {
pendingBulkShardRequests.add(ActionListener.wrap(handler::accept, errorHandler));
}
}

