Skip to content

Commit

Permalink
Harden exchange source handler (#104003)
Browse files Browse the repository at this point in the history
A bug was introduced when we added support for cross-cluster queries
in ESQL. If there is a long pause before the ComputeService sends
cluster requests to remote clusters, we risk finishing the exchange
source without linking remote sinks of the remote clusters. I can
reproduce this by inserting a long pause, but I couldn't write a useful
test unless we add a tracer and provide callbacks so that we can pause
in tests. This PR hardens and simplifies the exchange source handler. I
will also review and simplify the exchange service, because distributed
execution has evolved but the exchange service has not been simplified
to match.

Closes #103747
Closes #103749
  • Loading branch information
dnhatn committed Jan 7, 2024
1 parent a6c642d commit 9f7e1f5
Show file tree
Hide file tree
Showing 4 changed files with 40 additions and 78 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.common.util.concurrent.AbstractRunnable;
import org.elasticsearch.compute.data.Page;
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.tasks.TaskCancelledException;

import java.util.concurrent.Executor;
Expand All @@ -27,18 +27,19 @@
* @see #createExchangeSource()
* @see #addRemoteSink(RemoteSink, int)
*/
public final class ExchangeSourceHandler extends AbstractRefCounted {
public final class ExchangeSourceHandler {
private final ExchangeBuffer buffer;
private final Executor fetchExecutor;

private final PendingInstances outstandingSinks = new PendingInstances();
private final PendingInstances outstandingSources = new PendingInstances();
private final PendingInstances outstandingSinks;
private final PendingInstances outstandingSources;
private final AtomicReference<Exception> failure = new AtomicReference<>();
private final SubscribableListener<Void> completionFuture = new SubscribableListener<>();

/**
 * Creates a handler with a new {@link ExchangeBuffer} and two pending-instance trackers.
 * Each tracker is given a callback that finishes the buffer: the sink tracker calls
 * {@code buffer.finish(false)} and the source tracker calls {@code buffer.finish(true)}.
 *
 * @param maxBufferSize value passed to the {@link ExchangeBuffer} constructor
 * @param fetchExecutor executor stored on the handler; presumably used to run remote fetches (confirm against the fetch loop)
 */
public ExchangeSourceHandler(int maxBufferSize, Executor fetchExecutor) {
this.buffer = new ExchangeBuffer(maxBufferSize);
this.fetchExecutor = fetchExecutor;
// The trackers capture the buffer in their callbacks, so they are created after the buffer is assigned.
this.outstandingSinks = new PendingInstances(() -> buffer.finish(false));
this.outstandingSources = new PendingInstances(() -> buffer.finish(true));
}

private class LocalExchangeSource implements ExchangeSource {
Expand Down Expand Up @@ -76,9 +77,7 @@ public SubscribableListener<Void> waitForReading() {
public void finish() {
if (finished == false) {
finished = true;
if (outstandingSources.finishInstance()) {
buffer.finish(true);
}
outstandingSources.finishInstance();
}
}

Expand Down Expand Up @@ -205,9 +204,7 @@ void onSinkFailed(Exception e) {
void onSinkComplete() {
if (finished == false) {
finished = true;
if (outstandingSinks.finishInstance()) {
buffer.finish(false);
}
outstandingSinks.finishInstance();
}
}
}
Expand Down Expand Up @@ -237,35 +234,36 @@ protected void doRun() {
}
}

/**
 * Completes {@code completionFuture}: with the exception held in {@code failure} if one has been
 * recorded, otherwise with a {@code null} response.
 * NOTE(review): presumably invoked by the ref-counting base class once the last reference is released; confirm.
 */
@Override
protected void closeInternal() {
// Read the failure once so the check and the notification use the same value.
Exception error = failure.get();
if (error != null) {
completionFuture.onFailure(error);
} else {
completionFuture.onResponse(null);
}
}

/**
* Add a listener, which will be notified when this exchange source handler is completed. An exchange source
* handler is considered completed when all exchange factories and sinks are completed and detached.
* Links this exchange source with an empty/dummy remote sink. The purpose of this is to prevent this exchange source from finishing
* until we have performed other async actions, such as linking actual remote sinks.
*
* @return a Releasable that should be called when the caller no longer needs to prevent the exchange source from completing.
*/
public void addCompletionListener(ActionListener<Void> listener) {
completionFuture.addListener(listener);
public Releasable addEmptySink() {
// Register a placeholder in the pending-sink counter; no remote sink is attached and no fetch is started here.
outstandingSinks.trackNewInstance();
// Closing the returned Releasable decrements the counter again.
// NOTE(review): the method reference is not guarded, so each close() call decrements once; callers should close exactly once.
return outstandingSinks::finishInstance;
}

private final class PendingInstances {
private static class PendingInstances {
private final AtomicInteger instances = new AtomicInteger();
private final Releasable onComplete;

PendingInstances(Releasable onComplete) {
this.onComplete = onComplete;
}

void trackNewInstance() {
incRef();
instances.incrementAndGet();
int refs = instances.incrementAndGet();
assert refs > 0;
}

boolean finishInstance() {
decRef();
return instances.decrementAndGet() == 0;
void finishInstance() {
int refs = instances.decrementAndGet();
assert refs >= 0;
if (refs == 0) {
onComplete.close();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -54,8 +54,6 @@
import java.util.Collections;
import java.util.List;
import java.util.Set;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Function;
import java.util.function.Supplier;
Expand Down Expand Up @@ -95,11 +93,8 @@ public void testBasic() throws Exception {
ExchangeSink sink1 = sinkExchanger.createExchangeSink();
ExchangeSink sink2 = sinkExchanger.createExchangeSink();
ExchangeSourceHandler sourceExchanger = new ExchangeSourceHandler(3, threadPool.executor(ESQL_TEST_EXECUTOR));
assertThat(sourceExchanger.refCount(), equalTo(1));
ExchangeSource source = sourceExchanger.createExchangeSource();
assertThat(sourceExchanger.refCount(), equalTo(2));
sourceExchanger.addRemoteSink(sinkExchanger::fetchPageAsync, 1);
assertThat(sourceExchanger.refCount(), equalTo(3));
SubscribableListener<Void> waitForReading = source.waitForReading();
assertFalse(waitForReading.isDone());
assertNull(source.pollPage());
Expand Down Expand Up @@ -137,13 +132,7 @@ public void testBasic() throws Exception {
sink2.finish();
assertTrue(sink2.isFinished());
assertTrue(source.isFinished());
assertBusy(() -> assertThat(sourceExchanger.refCount(), equalTo(2)));
source.finish();
assertThat(sourceExchanger.refCount(), equalTo(1));
CountDownLatch latch = new CountDownLatch(1);
sourceExchanger.addCompletionListener(ActionListener.releasing(latch::countDown));
sourceExchanger.decRef();
assertTrue(latch.await(1, TimeUnit.SECONDS));
ESTestCase.terminate(threadPool);
for (Page page : pages) {
page.releaseBlocks();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,6 @@
import org.elasticsearch.action.support.ContextPreservingActionListener;
import org.elasticsearch.action.support.RefCountingListener;
import org.elasticsearch.action.support.RefCountingRunnable;
import org.elasticsearch.action.support.SubscribableListener;
import org.elasticsearch.cluster.node.DiscoveryNode;
import org.elasticsearch.cluster.service.ClusterService;
import org.elasticsearch.common.util.BigArrays;
Expand All @@ -31,11 +30,9 @@
import org.elasticsearch.compute.operator.DriverProfile;
import org.elasticsearch.compute.operator.DriverTaskRunner;
import org.elasticsearch.compute.operator.ResponseHeadersCollector;
import org.elasticsearch.compute.operator.exchange.ExchangeResponse;
import org.elasticsearch.compute.operator.exchange.ExchangeService;
import org.elasticsearch.compute.operator.exchange.ExchangeSinkHandler;
import org.elasticsearch.compute.operator.exchange.ExchangeSourceHandler;
import org.elasticsearch.compute.operator.exchange.RemoteSink;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.Releasables;
Expand Down Expand Up @@ -184,11 +181,9 @@ public void execute(
transportService.getThreadPool().executor(ESQL_THREAD_POOL_NAME)
);
try (
Releasable ignored = exchangeSource::decRef;
Releasable ignored = exchangeSource.addEmptySink();
RefCountingListener refs = new RefCountingListener(listener.map(unused -> new Result(collectedPages, collectedProfiles)))
) {
// wait until the source handler is completed
exchangeSource.addCompletionListener(refs.acquire());
// run compute on the coordinator
runCompute(
rootTask,
Expand All @@ -213,6 +208,7 @@ public void execute(
Set.of(localConcreteIndices.indices()),
localOriginalIndices.indices(),
exchangeSource,
ActionListener.releaseAfter(refs.acquire(), exchangeSource.addEmptySink()),
() -> cancelOnFailure(rootTask, cancelled, refs.acquire()).map(response -> {
responseHeadersCollector.collect();
if (configuration.profile()) {
Expand Down Expand Up @@ -263,19 +259,6 @@ private List<RemoteCluster> getRemoteClusters(
return remoteClusters;
}

/**
 * A {@link RemoteSink} that carries no pages. Every fetch is attached to {@code future} and is
 * answered with {@code new ExchangeResponse(null, true)} once that future has been completed
 * by {@link #finish()}.
 */
static final class EmptyRemoteSink implements RemoteSink {
// Completed by finish(); fetch listeners are registered on it.
final SubscribableListener<Void> future = new SubscribableListener<>();

@Override
public void fetchPageAsync(boolean allSourcesFinished, ActionListener<ExchangeResponse> listener) {
// allSourcesFinished is not consulted; the reply depends only on the future.
// NOTE(review): the second ExchangeResponse argument presumably marks the sink as finished; confirm.
future.addListener(listener.map(ignored -> new ExchangeResponse(null, true)));
}

// Completes the future with a null response, which notifies the listeners registered via fetchPageAsync.
void finish() {
future.onResponse(null);
}
}

private void startComputeOnDataNodes(
String sessionId,
String clusterAlias,
Expand All @@ -285,18 +268,16 @@ private void startComputeOnDataNodes(
Set<String> concreteIndices,
String[] originalIndices,
ExchangeSourceHandler exchangeSource,
Supplier<ActionListener<ComputeResponse>> listener
ActionListener<Void> parentListener,
Supplier<ActionListener<ComputeResponse>> dataNodeListenerSupplier
) {
// Do not complete the exchange sources until we have linked all remote sinks
final EmptyRemoteSink emptyRemoteSink = new EmptyRemoteSink();
exchangeSource.addRemoteSink(emptyRemoteSink, 1);
QueryBuilder requestFilter = PlannerUtils.requestFilter(dataNodePlan);
lookupDataNodes(parentTask, clusterAlias, requestFilter, concreteIndices, originalIndices, ActionListener.wrap(dataNodes -> {
try (RefCountingRunnable refs = new RefCountingRunnable(emptyRemoteSink::finish)) {
try (RefCountingRunnable refs = new RefCountingRunnable(() -> parentListener.onResponse(null))) {
// For each target node, first open a remote exchange on the remote node, then link the exchange source to
// the new remote exchange sink, and initialize the computation on the target node via data-node-request.
for (DataNode node : dataNodes) {
var dataNodeListener = ActionListener.releaseAfter(listener.get(), refs.acquire());
var dataNodeListener = ActionListener.releaseAfter(dataNodeListenerSupplier.get(), refs.acquire());
var queryPragmas = configuration.pragmas();
ExchangeService.openExchange(
transportService,
Expand All @@ -319,10 +300,7 @@ private void startComputeOnDataNodes(
);
}
}
}, e -> {
emptyRemoteSink.finish();
listener.get().onFailure(e);
}));
}, parentListener::onFailure));
}

private void startComputeOnRemoteClusters(
Expand All @@ -334,10 +312,7 @@ private void startComputeOnRemoteClusters(
List<RemoteCluster> clusters,
Supplier<ActionListener<ComputeResponse>> listener
) {
// Do not complete the exchange sources until we have linked all remote sinks
final EmptyRemoteSink emptyRemoteSink = new EmptyRemoteSink();
exchangeSource.addRemoteSink(emptyRemoteSink, 1);
try (RefCountingRunnable refs = new RefCountingRunnable(emptyRemoteSink::finish)) {
try (RefCountingRunnable refs = new RefCountingRunnable(exchangeSource.addEmptySink()::close)) {
for (RemoteCluster cluster : clusters) {
var targetNodeListener = ActionListener.releaseAfter(listener.get(), refs.acquire());
var queryPragmas = configuration.pragmas();
Expand Down Expand Up @@ -667,10 +642,9 @@ void runComputeOnRemoteCluster(
transportService.getThreadPool().executor(ESQL_THREAD_POOL_NAME)
);
try (
Releasable ignored = exchangeSource::decRef;
Releasable ignored = exchangeSource.addEmptySink();
RefCountingListener refs = new RefCountingListener(listener.map(unused -> new ComputeResponse(collectedProfiles)))
) {
exchangeSource.addCompletionListener(refs.acquire());
exchangeSink.addCompletionListener(refs.acquire());
PhysicalPlan coordinatorPlan = new ExchangeSinkExec(
plan.source(),
Expand Down Expand Up @@ -699,6 +673,7 @@ void runComputeOnRemoteCluster(
concreteIndices,
originalIndices,
exchangeSource,
ActionListener.releaseAfter(refs.acquire(), exchangeSource.addEmptySink()),
() -> cancelOnFailure(parentTask, cancelled, refs.acquire()).map(r -> {
responseHeadersCollector.collect();
if (configuration.profile()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,7 +412,7 @@ protected void start(Driver driver, ActionListener<Void> driverListener) {
}));
return future.actionGet(TimeValue.timeValueSeconds(30));
} finally {
Releasables.close(() -> Releasables.close(drivers), exchangeSource::decRef);
Releasables.close(() -> Releasables.close(drivers));
}
}

Expand Down

0 comments on commit 9f7e1f5

Please sign in to comment.