From 17edd33b3579e6d75e403fc6adccd2aa026b223a Mon Sep 17 00:00:00 2001 From: Nhat Nguyen Date: Wed, 8 May 2024 14:32:05 -0700 Subject: [PATCH] Exchange should wait for remote sinks (#108337) (#108430) Today, we do not wait for remote sinks to stop before completing the main request. While this doesn't affect correctness, it's important that we do not spawn child requests after the parent request is completed. Closes #105859 --- .../exchange/ExchangeSourceHandler.java | 24 +++++++++++++++---- .../exchange/ExchangeServiceTests.java | 16 ++++++++++++- .../xpack/esql/plugin/ComputeService.java | 3 +++ 3 files changed, 38 insertions(+), 5 deletions(-) diff --git a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java index f1698ea401d28..adce8d8a88407 100644 --- a/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java +++ b/x-pack/plugin/esql/compute/src/main/java/org/elasticsearch/compute/operator/exchange/ExchangeSourceHandler.java @@ -10,6 +10,7 @@ import org.elasticsearch.ElasticsearchException; import org.elasticsearch.ExceptionsHelper; import org.elasticsearch.action.ActionListener; +import org.elasticsearch.action.support.RefCountingListener; import org.elasticsearch.action.support.SubscribableListener; import org.elasticsearch.common.util.concurrent.AbstractRunnable; import org.elasticsearch.compute.data.Page; @@ -17,6 +18,7 @@ import org.elasticsearch.tasks.TaskCancelledException; import org.elasticsearch.transport.TransportException; +import java.util.List; import java.util.concurrent.Executor; import java.util.concurrent.atomic.AtomicInteger; import java.util.concurrent.atomic.AtomicReference; @@ -89,6 +91,20 @@ public int bufferSize() { } } + public void addCompletionListener(ActionListener listener) { + buffer.addCompletionListener(ActionListener.running(() -> { + try (RefCountingListener refs = new RefCountingListener(listener)) { + for (PendingInstances pending : List.of(outstandingSinks, outstandingSources)) { + // Create an outstanding instance and then finish to complete the completionListener + // if we haven't registered any instances of exchange sinks or exchange sources before. + pending.trackNewInstance(); + pending.completion.addListener(refs.acquire()); + pending.finishInstance(); + } + } + })); + } + /** * Create a new {@link ExchangeSource} for exchanging data * @@ -253,10 +269,10 @@ public Releasable addEmptySink() { private static class PendingInstances { private final AtomicInteger instances = new AtomicInteger(); - private final Releasable onComplete; + private final SubscribableListener completion = new SubscribableListener<>(); - PendingInstances(Releasable onComplete) { - this.onComplete = onComplete; + PendingInstances(Runnable onComplete) { + completion.addListener(ActionListener.running(onComplete)); } void trackNewInstance() { @@ -268,7 +284,7 @@ void finishInstance() { int refs = instances.decrementAndGet(); assert refs >= 0; if (refs == 0) { - onComplete.close(); + completion.onResponse(null); } } } diff --git a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java index bdaa045633dc0..51332b3c8997a 100644 --- a/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java +++ b/x-pack/plugin/esql/compute/src/test/java/org/elasticsearch/compute/operator/exchange/ExchangeServiceTests.java @@ -55,6 +55,7 @@ import java.util.Collections; import java.util.List; import java.util.Set; +import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicInteger; import java.util.function.Function; import java.util.function.Supplier; @@ -94,6 +95,8 @@ public void testBasic() throws Exception { ExchangeSink sink1 = sinkExchanger.createExchangeSink(); ExchangeSink sink2 = sinkExchanger.createExchangeSink(); ExchangeSourceHandler sourceExchanger = new ExchangeSourceHandler(3, threadPool.executor(ESQL_TEST_EXECUTOR)); + PlainActionFuture sourceCompletion = new PlainActionFuture<>(); + sourceExchanger.addCompletionListener(sourceCompletion); ExchangeSource source = sourceExchanger.createExchangeSource(); sourceExchanger.addRemoteSink(sinkExchanger::fetchPageAsync, 1); SubscribableListener waitForReading = source.waitForReading(); @@ -133,7 +136,9 @@ public void testBasic() throws Exception { sink2.finish(); assertTrue(sink2.isFinished()); assertTrue(source.isFinished()); + assertFalse(sourceCompletion.isDone()); source.finish(); + sourceCompletion.actionGet(10, TimeUnit.SECONDS); ESTestCase.terminate(threadPool); for (Page page : pages) { page.releaseBlocks(); @@ -320,7 +325,9 @@ protected void start(Driver driver, ActionListener listener) { public void testConcurrentWithHandlers() { BlockFactory blockFactory = blockFactory(); + PlainActionFuture sourceCompletionFuture = new PlainActionFuture<>(); var sourceExchanger = new ExchangeSourceHandler(randomExchangeBuffer(), threadPool.executor(ESQL_TEST_EXECUTOR)); + sourceExchanger.addCompletionListener(sourceCompletionFuture); List sinkHandlers = new ArrayList<>(); Supplier exchangeSink = () -> { final ExchangeSinkHandler sinkHandler; @@ -336,6 +343,7 @@ public void testConcurrentWithHandlers() { final int maxInputSeqNo = rarely() ? -1 : randomIntBetween(0, 50_000); final int maxOutputSeqNo = rarely() ? -1 : randomIntBetween(0, 50_000); runConcurrentTest(maxInputSeqNo, maxOutputSeqNo, sourceExchanger::createExchangeSource, exchangeSink); + sourceCompletionFuture.actionGet(10, TimeUnit.SECONDS); } public void testEarlyTerminate() { @@ -358,7 +366,7 @@ public void testEarlyTerminate() { assertTrue(sink.isFinished()); } - public void testConcurrentWithTransportActions() throws Exception { + public void testConcurrentWithTransportActions() { MockTransportService node0 = newTransportService(); ExchangeService exchange0 = new ExchangeService(Settings.EMPTY, threadPool, ESQL_TEST_EXECUTOR, blockFactory()); exchange0.registerTransportHandler(node0); @@ -371,12 +379,15 @@ public void testConcurrentWithTransportActions() throws Exception { String exchangeId = "exchange"; Task task = new Task(1, "", "", "", null, Collections.emptyMap()); var sourceHandler = new ExchangeSourceHandler(randomExchangeBuffer(), threadPool.executor(ESQL_TEST_EXECUTOR)); + PlainActionFuture sourceCompletionFuture = new PlainActionFuture<>(); + sourceHandler.addCompletionListener(sourceCompletionFuture); ExchangeSinkHandler sinkHandler = exchange1.createSinkHandler(exchangeId, randomExchangeBuffer()); Transport.Connection connection = node0.getConnection(node1.getLocalNode()); sourceHandler.addRemoteSink(exchange0.newRemoteSink(task, exchangeId, node0, connection), randomIntBetween(1, 5)); final int maxInputSeqNo = rarely() ? -1 : randomIntBetween(0, 50_000); final int maxOutputSeqNo = rarely() ? -1 : randomIntBetween(0, 50_000); runConcurrentTest(maxInputSeqNo, maxOutputSeqNo, sourceHandler::createExchangeSource, sinkHandler::createExchangeSink); + sourceCompletionFuture.actionGet(10, TimeUnit.SECONDS); } } @@ -427,6 +438,8 @@ public void sendResponse(TransportResponse transportResponse) { String exchangeId = "exchange"; Task task = new Task(1, "", "", "", null, Collections.emptyMap()); var sourceHandler = new ExchangeSourceHandler(randomIntBetween(1, 128), threadPool.executor(ESQL_TEST_EXECUTOR)); + PlainActionFuture sourceCompletionFuture = new PlainActionFuture<>(); + sourceHandler.addCompletionListener(sourceCompletionFuture); ExchangeSinkHandler sinkHandler = exchange1.createSinkHandler(exchangeId, randomIntBetween(1, 128)); Transport.Connection connection = node0.getConnection(node1.getLocalDiscoNode()); sourceHandler.addRemoteSink(exchange0.newRemoteSink(task, exchangeId, node0, connection), randomIntBetween(1, 5)); @@ -438,6 +451,7 @@ public void sendResponse(TransportResponse transportResponse) { assertNotNull(cause); assertThat(cause.getMessage(), equalTo("page is too large")); sinkHandler.onFailure(new RuntimeException(cause)); + sourceCompletionFuture.actionGet(10, TimeUnit.SECONDS); } } diff --git a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java index 7b38197dde95a..d9005d5997b34 100644 --- a/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java +++ b/x-pack/plugin/esql/src/main/java/org/elasticsearch/xpack/esql/plugin/ComputeService.java @@ -205,6 +205,7 @@ public void execute( RefCountingListener refs = new RefCountingListener(listener.map(unused -> new Result(collectedPages, collectedProfiles))) ) { // run compute on the coordinator + exchangeSource.addCompletionListener(refs.acquire()); runCompute( rootTask, new ComputeContext(sessionId, RemoteClusterAware.LOCAL_CLUSTER_GROUP_KEY, List.of(), configuration, exchangeSource, null), @@ -722,6 +723,7 @@ private void runComputeOnDataNode( var externalSink = exchangeService.getSinkHandler(externalId); task.addListener(() -> exchangeService.finishSinkHandler(externalId, new TaskCancelledException(task.getReasonCancelled()))); var exchangeSource = new ExchangeSourceHandler(1, esqlExecutor); + exchangeSource.addCompletionListener(refs.acquire()); exchangeSource.addRemoteSink(internalSink::fetchPageAsync, 1); ActionListener reductionListener = cancelOnFailure(task, cancelled, refs.acquire()); runCompute( @@ -854,6 +856,7 @@ void runComputeOnRemoteCluster( RefCountingListener refs = new RefCountingListener(listener.map(unused -> new ComputeResponse(collectedProfiles))) ) { exchangeSink.addCompletionListener(refs.acquire()); + exchangeSource.addCompletionListener(refs.acquire()); PhysicalPlan coordinatorPlan = new ExchangeSinkExec( plan.source(), plan.output(),