Skip to content

Commit

Permalink
Make TransportChannel#sendResponse refcount-neutral (#104044)
Browse files Browse the repository at this point in the history
Refcount-neutral methods are much easier on the reader because they mean
we can make obvious pairings between `incRef()` and `decRef()` calls.
This commit removes the implicit `decRef()` invocation from
`TransportChannel#sendResponse` to make it refcount-neutral.
  • Loading branch information
DaveCTurner committed Jan 22, 2024
1 parent 48ff111 commit 72e0c1f
Show file tree
Hide file tree
Showing 14 changed files with 152 additions and 77 deletions.
Expand Up @@ -28,7 +28,6 @@ public ChannelActionListener(TransportChannel channel) {

@Override
public void onResponse(Response response) {
response.incRef(); // acquire reference that will be released by channel.sendResponse below
ActionListener.run(this, l -> l.channel.sendResponse(response));
}

Expand Down
Expand Up @@ -230,7 +230,12 @@ protected void resolveRequest(NodesRequest request, ClusterState clusterState) {
class NodeTransportHandler implements TransportRequestHandler<NodeRequest> {
@Override
public void messageReceived(NodeRequest request, TransportChannel channel, Task task) throws Exception {
channel.sendResponse(nodeOperation(request, task));
final var nodeResponse = nodeOperation(request, task);
try {
channel.sendResponse(nodeResponse);
} finally {
nodeResponse.decRef();
}
}
}

Expand Down
Expand Up @@ -136,6 +136,7 @@ void sendResponse(
isHandshake,
compressionScheme
);
response.mustIncRef();
sendMessage(channel, message, responseStatsConsumer, () -> {
try {
messageListener.onResponseSent(requestId, action, response);
Expand Down
Expand Up @@ -65,7 +65,6 @@ public Executor executor(ThreadPool threadPool) {
@Override
public void handleResponse(TransportResponse response) {
try {
response.mustIncRef();
channel.sendResponse(response);
} catch (IOException e) {
throw new UncheckedIOException(e);
Expand Down
Expand Up @@ -1482,43 +1482,39 @@ public String getProfileName() {

@Override
public void sendResponse(TransportResponse response) throws IOException {
try {
service.onResponseSent(requestId, action, response);
try (var shutdownBlock = service.pendingDirectHandlers.withRef()) {
if (shutdownBlock == null) {
// already shutting down, the handler will be completed by sendRequestInternal or doStop
return;
}
final TransportResponseHandler<?> handler = service.responseHandlers.onResponseReceived(requestId, service);
if (handler == null) {
// handler already completed, likely by a timeout which is logged elsewhere
return;
}
final var executor = handler.executor(threadPool);
if (executor == EsExecutors.DIRECT_EXECUTOR_SERVICE) {
processResponse(handler, response);
} else {
response.mustIncRef();
executor.execute(new ForkingResponseHandlerRunnable(handler, null, threadPool) {
@Override
protected void doRun() {
processResponse(handler, response);
}
service.onResponseSent(requestId, action, response);
try (var shutdownBlock = service.pendingDirectHandlers.withRef()) {
if (shutdownBlock == null) {
// already shutting down, the handler will be completed by sendRequestInternal or doStop
return;
}
final TransportResponseHandler<?> handler = service.responseHandlers.onResponseReceived(requestId, service);
if (handler == null) {
// handler already completed, likely by a timeout which is logged elsewhere
return;
}
final var executor = handler.executor(threadPool);
if (executor == EsExecutors.DIRECT_EXECUTOR_SERVICE) {
processResponse(handler, response);
} else {
response.mustIncRef();
executor.execute(new ForkingResponseHandlerRunnable(handler, null, threadPool) {
@Override
protected void doRun() {
processResponse(handler, response);
}

@Override
public void onAfter() {
response.decRef();
}
@Override
public void onAfter() {
response.decRef();
}

@Override
public String toString() {
return "delivery of response to [" + requestId + "][" + action + "]: " + response;
}
});
}
@Override
public String toString() {
return "delivery of response to [" + requestId + "][" + action + "]: " + response;
}
});
}
} finally {
response.decRef();
}
}

Expand Down
Expand Up @@ -139,6 +139,7 @@ public void testResponseAggregation() {
successfulNodes.add(capturedRequest.node());
final var response = new TestNodeResponse(capturedRequest.node());
transport.handleResponse(capturedRequest.requestId(), response);
response.decRef();
assertFalse(response.hasReferences()); // response is copied (via the wire protocol) so this instance is released
} else {
failedNodeIds.add(capturedRequest.node().getId());
Expand Down
Expand Up @@ -38,6 +38,7 @@
import org.elasticsearch.common.util.concurrent.EsExecutors;
import org.elasticsearch.common.xcontent.XContentHelper;
import org.elasticsearch.core.IOUtils;
import org.elasticsearch.core.ReleasableRef;
import org.elasticsearch.core.SuppressForbidden;
import org.elasticsearch.core.TimeValue;
import org.elasticsearch.index.IndexNotFoundException;
Expand Down Expand Up @@ -155,23 +156,28 @@ public static MockTransportService startTransport(
} else {
searchHits = SearchHits.empty(new TotalHits(0, TotalHits.Relation.EQUAL_TO), Float.NaN);
}
SearchResponse searchResponse = new SearchResponse(
searchHits,
InternalAggregations.EMPTY,
null,
false,
null,
null,
1,
null,
1,
1,
0,
100,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
);
channel.sendResponse(searchResponse);
try (
var searchResponseRef = ReleasableRef.of(
new SearchResponse(
searchHits,
InternalAggregations.EMPTY,
null,
false,
null,
null,
1,
null,
1,
1,
0,
100,
ShardSearchFailure.EMPTY_ARRAY,
SearchResponse.Clusters.EMPTY
)
)
) {
channel.sendResponse(searchResponseRef.get());
}
}
);
newService.registerRequestHandler(
Expand Down
Expand Up @@ -118,6 +118,7 @@ public void testSendMessage() throws InterruptedException {
assertEquals(request.sourceNode, "TS_A");
final SimpleTestResponse response = new SimpleTestResponse("TS_A");
channel.sendResponse(response);
response.decRef();
assertThat(response.hasReferences(), equalTo(false));
}
);
Expand All @@ -134,6 +135,7 @@ public void testSendMessage() throws InterruptedException {
assertEquals(request.sourceNode, "TS_A");
final SimpleTestResponse response = new SimpleTestResponse("TS_B");
channel.sendResponse(response);
response.decRef();
assertThat(response.hasReferences(), equalTo(false));
}
);
Expand All @@ -148,6 +150,7 @@ public void testSendMessage() throws InterruptedException {
assertEquals(request.sourceNode, "TS_A");
final SimpleTestResponse response = new SimpleTestResponse("TS_C");
channel.sendResponse(response);
response.decRef();
assertThat(response.hasReferences(), equalTo(false));
}
);
Expand Down Expand Up @@ -319,8 +322,10 @@ public void handleException(TransportException exp) {
);
latch.await();

assertThat(response.get(), notNullValue());
assertBusy(() -> assertThat(response.get().hasReferences(), equalTo(false)));
final var responseInstance = response.get();
assertThat(responseInstance, notNullValue());
responseInstance.decRef();
assertBusy(() -> assertThat(responseInstance.hasReferences(), equalTo(false)));
}

public void testException() throws InterruptedException {
Expand Down Expand Up @@ -439,10 +444,10 @@ public boolean shouldCancelChildrenOnCancellation() {
public static class SimpleTestResponse extends TransportResponse {

final String targetNode;
final RefCounted refCounted = new AbstractRefCounted() {
final RefCounted refCounted = LeakTracker.wrap(new AbstractRefCounted() {
@Override
protected void closeInternal() {}
};
});

SimpleTestResponse(String targetNode) {
this.targetNode = targetNode;
Expand Down
@@ -0,0 +1,47 @@
/*
* Copyright Elasticsearch B.V. and/or licensed to Elasticsearch B.V. under one
* or more contributor license agreements. Licensed under the Elastic License
* 2.0 and the Server Side Public License, v 1; you may not use this file except
* in compliance with, at your election, the Elastic License 2.0 or the Server
* Side Public License, v 1.
*/

package org.elasticsearch.core;

import org.elasticsearch.transport.LeakTracker;

import java.util.Objects;

/**
* Adapter to use a {@link RefCounted} in a try-with-resources block.
*/
public final class ReleasableRef<T extends RefCounted> implements Releasable {

private final Releasable closeResource;
private final T resource;

private ReleasableRef(T resource) {
this.resource = Objects.requireNonNull(resource);
this.closeResource = LeakTracker.wrap(Releasables.assertOnce(resource::decRef));
}

@Override
public void close() {
closeResource.close();
}

public static <T extends RefCounted> ReleasableRef<T> of(T resource) {
return new ReleasableRef<>(resource);
}

public T get() {
assert resource.hasReferences() : resource + " is closed";
return resource;
}

@Override
public String toString() {
return "ReleasableRef[" + resource + ']';
}

}
Expand Up @@ -101,8 +101,6 @@ public <Response extends TransportResponse> void handleResponse(final long reque
);
} catch (IOException | UnsupportedOperationException e) {
throw new AssertionError("failed to serialize/deserialize response " + response, e);
} finally {
response.decRef();
}
try {
transportResponseHandler.handleResponse(deliveredResponse);
Expand Down
Expand Up @@ -20,6 +20,7 @@
import org.elasticsearch.common.transport.TransportAddress;
import org.elasticsearch.common.util.concurrent.DeterministicTaskQueue;
import org.elasticsearch.core.Nullable;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.tasks.TaskManager;
import org.elasticsearch.telemetry.tracing.Tracer;
import org.elasticsearch.test.transport.MockTransport;
Expand Down Expand Up @@ -263,24 +264,28 @@ public String getProfileName() {

@Override
public void sendResponse(final TransportResponse response) {
response.mustIncRef();
final var releasable = Releasables.assertOnce(response::decRef);
execute(new RebootSensitiveRunnable() {
@Override
public void ifRebooted() {
response.decRef();
cleanupResponseHandler(requestId);
try (releasable) {
cleanupResponseHandler(requestId);
}
}

@Override
public void run() {
final ConnectionStatus connectionStatus = destinationTransport.getConnectionStatus(getLocalNode());
switch (connectionStatus) {
case CONNECTED, BLACK_HOLE_REQUESTS_ONLY -> handleResponse(requestId, response);
case BLACK_HOLE, DISCONNECTED -> {
response.decRef();
logger.trace("delaying response to {}: channel is {}", requestDescription, connectionStatus);
onBlackholedDuringSend(requestId, action, destinationTransport);
try (releasable) {
final ConnectionStatus connectionStatus = destinationTransport.getConnectionStatus(getLocalNode());
switch (connectionStatus) {
case CONNECTED, BLACK_HOLE_REQUESTS_ONLY -> handleResponse(requestId, response);
case BLACK_HOLE, DISCONNECTED -> {
logger.trace("delaying response to {}: channel is {}", requestDescription, connectionStatus);
onBlackholedDuringSend(requestId, action, destinationTransport);
}
default -> throw new AssertionError("unexpected status: " + connectionStatus);
}
default -> throw new AssertionError("unexpected status: " + connectionStatus);
}
}

Expand Down
Expand Up @@ -23,6 +23,7 @@
import org.elasticsearch.core.AbstractRefCounted;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.Releasable;
import org.elasticsearch.core.ReleasableRef;
import org.elasticsearch.core.Tuple;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.threadpool.ThreadPool;
Expand Down Expand Up @@ -229,7 +230,9 @@ private TransportRequestHandler<TestRequest> requestHandlerShouldNotBeCalled() {
private TransportRequestHandler<TestRequest> requestHandlerRepliesNormally() {
return (request, channel, task) -> {
logger.debug("got a dummy request, replying normally...");
channel.sendResponse(new TestResponse());
try (var responseRef = ReleasableRef.of(new TestResponse())) {
channel.sendResponse(responseRef.get());
}
};
}

Expand Down Expand Up @@ -415,7 +418,9 @@ public void testDisconnectedOnSuccessfulResponse() throws IOException {
assertNull(responseHandlerException.get());

disconnectedLinks.add(Tuple.tuple(node2, node1));
responseHandlerChannel.get().sendResponse(new TestResponse());
try (var responseRef = ReleasableRef.of(new TestResponse())) {
responseHandlerChannel.get().sendResponse(responseRef.get());
}
deterministicTaskQueue.runAllTasks();
deliverBlackholedRequests.run();
deterministicTaskQueue.runAllTasks();
Expand Down Expand Up @@ -453,7 +458,9 @@ public void testUnavailableOnSuccessfulResponse() throws IOException {
assertNotNull(responseHandlerChannel.get());

blackholedLinks.add(Tuple.tuple(node2, node1));
responseHandlerChannel.get().sendResponse(new TestResponse());
try (var responseRef = ReleasableRef.of(new TestResponse())) {
responseHandlerChannel.get().sendResponse(responseRef.get());
}
deterministicTaskQueue.runAllRunnableTasks();
}

Expand Down Expand Up @@ -485,7 +492,9 @@ public void testUnavailableOnRequestOnlyReceivesSuccessfulResponse() throws IOEx

blackholedRequestLinks.add(Tuple.tuple(node1, node2));
blackholedRequestLinks.add(Tuple.tuple(node2, node1));
responseHandlerChannel.get().sendResponse(new TestResponse());
try (var responseRef = ReleasableRef.of(new TestResponse())) {
responseHandlerChannel.get().sendResponse(responseRef.get());
}

deterministicTaskQueue.runAllRunnableTasks();
assertTrue(responseHandlerCalled.get());
Expand Down

0 comments on commit 72e0c1f

Please sign in to comment.