Skip to content

Commit

Permalink
Misc cleanups around Netty4HttpPipeliningHandler (#104642)
Browse files Browse the repository at this point in the history
- Make `Netty4HttpResponse` sealed, we need to know all its impls
- Rename `Netty4FullHttpResponse` to contrast with chunked response
- Rename `doWrite` overloads
- Reorder `finishChunkedWrite()`, we're ready to handle the next
  response before completing the promise
- Add missing test for splitting large responses
- Add a few extra assertions
- Clean up IDE warnings
  • Loading branch information
DaveCTurner committed Jan 24, 2024
1 parent d6f900c commit cf67f5d
Show file tree
Hide file tree
Showing 9 changed files with 104 additions and 50 deletions.
Expand Up @@ -16,11 +16,11 @@
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.transport.netty4.Netty4Utils;

public class Netty4HttpResponse extends DefaultFullHttpResponse implements Netty4RestResponse {
public final class Netty4FullHttpResponse extends DefaultFullHttpResponse implements Netty4RestResponse {

private final int sequence;

Netty4HttpResponse(int sequence, HttpVersion version, RestStatus status, BytesReference content) {
Netty4FullHttpResponse(int sequence, HttpVersion version, RestStatus status, BytesReference content) {
super(version, HttpResponseStatus.valueOf(status.getStatus()), Netty4Utils.toByteBuf(content));
this.sequence = sequence;
}
Expand Down
Expand Up @@ -57,7 +57,7 @@ public class Netty4HttpPipeliningHandler extends ChannelDuplexHandler {
private final int maxEventsHeld;
private final PriorityQueue<Tuple<? extends Netty4RestResponse, ChannelPromise>> outboundHoldingQueue;

private record ChunkedWrite(PromiseCombiner combiner, ChannelPromise onDone, Netty4ChunkedHttpResponse response) {}
private record ChunkedWrite(PromiseCombiner combiner, ChannelPromise onDone, ChunkedRestResponseBody responseBody) {}

/**
* The current {@link ChunkedWrite} if a chunked write is executed at the moment.
Expand Down Expand Up @@ -150,6 +150,8 @@ public void write(final ChannelHandlerContext ctx, final Object msg, final Chann
);
}
// response is not at the current sequence number so we add it to the outbound queue and return
assert outboundHoldingQueue.stream().noneMatch(t -> t.v1().getSequence() == writeSequence)
: "duplicate outbound entries for seqno " + writeSequence;
outboundHoldingQueue.add(new Tuple<>(restResponse, promise));
success = true;
return;
Expand Down Expand Up @@ -191,17 +193,22 @@ private void doWriteQueued(ChannelHandlerContext ctx) throws IOException {

private void doWrite(ChannelHandlerContext ctx, Netty4RestResponse readyResponse, ChannelPromise promise) throws IOException {
assert currentChunkedWrite == null : "unexpected existing write [" + currentChunkedWrite + "]";
if (readyResponse instanceof Netty4HttpResponse) {
doWrite(ctx, (Netty4HttpResponse) readyResponse, promise);
assert readyResponse != null : "cannot write null response";
assert readyResponse.getSequence() == writeSequence;
if (readyResponse instanceof Netty4FullHttpResponse fullResponse) {
doWriteFullResponse(ctx, fullResponse, promise);
} else if (readyResponse instanceof Netty4ChunkedHttpResponse chunkedResponse) {
doWriteChunkedResponse(ctx, chunkedResponse, promise);
} else {
doWrite(ctx, (Netty4ChunkedHttpResponse) readyResponse, promise);
assert false : readyResponse.getClass().getCanonicalName();
throw new IllegalStateException("illegal message type: " + readyResponse.getClass().getCanonicalName());
}
}

/**
* Split up large responses to prevent batch compression {@link JdkZlibEncoder} down the pipeline.
*/
private void doWrite(ChannelHandlerContext ctx, Netty4HttpResponse readyResponse, ChannelPromise promise) {
private void doWriteFullResponse(ChannelHandlerContext ctx, Netty4FullHttpResponse readyResponse, ChannelPromise promise) {
if (DO_NOT_SPLIT_HTTP_RESPONSES || readyResponse.content().readableBytes() <= SPLIT_THRESHOLD) {
enqueueWrite(ctx, readyResponse, promise);
} else {
Expand All @@ -210,16 +217,19 @@ private void doWrite(ChannelHandlerContext ctx, Netty4HttpResponse readyResponse
writeSequence++;
}

private void doWrite(ChannelHandlerContext ctx, Netty4ChunkedHttpResponse readyResponse, ChannelPromise promise) throws IOException {
private void doWriteChunkedResponse(ChannelHandlerContext ctx, Netty4ChunkedHttpResponse readyResponse, ChannelPromise promise)
throws IOException {
final PromiseCombiner combiner = new PromiseCombiner(ctx.executor());
final ChannelPromise first = ctx.newPromise();
combiner.add((Future<Void>) first);
currentChunkedWrite = new ChunkedWrite(combiner, promise, readyResponse);
final var responseBody = readyResponse.body();
assert currentChunkedWrite == null;
currentChunkedWrite = new ChunkedWrite(combiner, promise, responseBody);
if (enqueueWrite(ctx, readyResponse, first)) {
// We were able to write out the first chunk directly, try writing out subsequent chunks until the channel becomes unwritable.
// NB "writable" means there's space in the downstream ChannelOutboundBuffer, we aren't trying to saturate the physical channel.
while (ctx.channel().isWritable()) {
if (writeChunk(ctx, combiner, readyResponse.body())) {
if (writeChunk(ctx, combiner, responseBody)) {
finishChunkedWrite();
return;
}
Expand All @@ -228,15 +238,15 @@ private void doWrite(ChannelHandlerContext ctx, Netty4ChunkedHttpResponse readyR
}

private void finishChunkedWrite() {
try {
currentChunkedWrite.combiner.finish(currentChunkedWrite.onDone);
} finally {
currentChunkedWrite = null;
writeSequence++;
}
assert currentChunkedWrite != null;
assert currentChunkedWrite.responseBody().isDone();
final var finishingWrite = currentChunkedWrite;
currentChunkedWrite = null;
writeSequence++;
finishingWrite.combiner.finish(finishingWrite.onDone());
}

private void splitAndWrite(ChannelHandlerContext ctx, Netty4HttpResponse msg, ChannelPromise promise) {
private void splitAndWrite(ChannelHandlerContext ctx, Netty4FullHttpResponse msg, ChannelPromise promise) {
final PromiseCombiner combiner = new PromiseCombiner(ctx.executor());
HttpResponse response = new DefaultHttpResponse(msg.protocolVersion(), msg.status(), msg.headers());
combiner.add(enqueueWrite(ctx, response));
Expand Down Expand Up @@ -293,7 +303,7 @@ private boolean doFlush(ChannelHandlerContext ctx) throws IOException {
if (currentWrite == null) {
// no bytes were found queued, check if a chunked message might have become writable
if (currentChunkedWrite != null) {
if (writeChunk(ctx, currentChunkedWrite.combiner, currentChunkedWrite.response.body())) {
if (writeChunk(ctx, currentChunkedWrite.combiner, currentChunkedWrite.responseBody())) {
finishChunkedWrite();
}
continue;
Expand Down
Expand Up @@ -171,8 +171,8 @@ public HttpRequest removeHeader(String header) {
}

@Override
public Netty4HttpResponse createResponse(RestStatus status, BytesReference contentRef) {
return new Netty4HttpResponse(sequence, request.protocolVersion(), status, contentRef);
public Netty4FullHttpResponse createResponse(RestStatus status, BytesReference contentRef) {
return new Netty4FullHttpResponse(sequence, request.protocolVersion(), status, contentRef);
}

@Override
Expand Down
Expand Up @@ -419,8 +419,8 @@ protected HttpMessage createMessage(String[] initialLine) throws Exception {
protected boolean isContentAlwaysEmpty(HttpResponse msg) {
// non-chunked responses (Netty4HttpResponse extends Netty's DefaultFullHttpResponse) with chunked transfer
// encoding are only sent by us in response to HEAD requests and must always have an empty body
if (msg instanceof Netty4HttpResponse netty4HttpResponse && HttpUtil.isTransferEncodingChunked(msg)) {
assert netty4HttpResponse.content().isReadable() == false;
if (msg instanceof Netty4FullHttpResponse netty4FullHttpResponse && HttpUtil.isTransferEncodingChunked(msg)) {
assert netty4FullHttpResponse.content().isReadable() == false;
return true;
}
return super.isContentAlwaysEmpty(msg);
Expand Down
Expand Up @@ -12,7 +12,7 @@

import org.elasticsearch.http.HttpResponse;

public interface Netty4RestResponse extends HttpResponse, HttpMessage {
public sealed interface Netty4RestResponse extends HttpResponse, HttpMessage permits Netty4FullHttpResponse, Netty4ChunkedHttpResponse {

int getSequence();

Expand Down
Expand Up @@ -18,6 +18,8 @@
import io.netty.handler.codec.DecoderResult;
import io.netty.handler.codec.http.DefaultFullHttpRequest;
import io.netty.handler.codec.http.DefaultHttpContent;
import io.netty.handler.codec.http.DefaultHttpResponse;
import io.netty.handler.codec.http.DefaultLastHttpContent;
import io.netty.handler.codec.http.FullHttpResponse;
import io.netty.handler.codec.http.HttpMethod;
import io.netty.handler.codec.http.HttpVersion;
Expand All @@ -30,12 +32,14 @@
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.bytes.ZeroBytesReference;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.http.HttpResponse;
import org.elasticsearch.rest.ChunkedRestResponseBody;
import org.elasticsearch.rest.RestStatus;
import org.elasticsearch.test.ESTestCase;
import org.elasticsearch.transport.netty4.Netty4Utils;
import org.elasticsearch.transport.netty4.NettyAllocator;
import org.junit.After;

import java.nio.channels.ClosedChannelException;
Expand All @@ -56,6 +60,7 @@
import static org.hamcrest.Matchers.greaterThan;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.not;
import static org.hamcrest.Matchers.sameInstance;
import static org.hamcrest.core.Is.is;
import static org.mockito.Mockito.mock;
Expand All @@ -70,18 +75,18 @@ public class Netty4HttpPipeliningHandlerTests extends ESTestCase {
@After
public void tearDown() throws Exception {
waitingRequests.keySet().forEach(this::finishRequest);
// shutdown the Executor Service
if (handlerService.isShutdown() == false) {
handlerService.shutdown();
handlerService.awaitTermination(10, TimeUnit.SECONDS);
}
if (eventLoopService.isShutdown() == false) {
eventLoopService.shutdown();
eventLoopService.awaitTermination(10, TimeUnit.SECONDS);
}
terminateExecutorService(handlerService);
terminateExecutorService(eventLoopService);
super.tearDown();
}

private void terminateExecutorService(ExecutorService executorService) throws InterruptedException {
if (executorService.isShutdown() == false) {
executorService.shutdown();
assertTrue(executorService.awaitTermination(10, TimeUnit.SECONDS));
}
}

private CountDownLatch finishRequest(String url) {
waitingRequests.get(url).countDown();
return finishingRequests.get(url);
Expand All @@ -92,7 +97,7 @@ public void testThatPipeliningWorksWithFastSerializedRequests() throws Interrupt
final EmbeddedChannel embeddedChannel = makeEmbeddedChannelWithSimulatedWork(numberOfRequests);

for (int i = 0; i < numberOfRequests; i++) {
embeddedChannel.writeInbound(createHttpRequest("/" + String.valueOf(i)));
embeddedChannel.writeInbound(createHttpRequest("/" + i));
}

final List<CountDownLatch> latches = new ArrayList<>();
Expand Down Expand Up @@ -127,7 +132,7 @@ public void testThatPipeliningWorksWhenSlowRequestsInDifferentOrder() throws Int
final EmbeddedChannel embeddedChannel = makeEmbeddedChannelWithSimulatedWork(numberOfRequests);

for (int i = 0; i < numberOfRequests; i++) {
embeddedChannel.writeInbound(createHttpRequest("/" + String.valueOf(i)));
embeddedChannel.writeInbound(createHttpRequest("/" + i));
}

// random order execution
Expand Down Expand Up @@ -156,7 +161,7 @@ public void testThatPipeliningClosesConnectionWithTooManyEvents() throws Interru
final EmbeddedChannel embeddedChannel = makeEmbeddedChannelWithSimulatedWork(numberOfRequests);

for (int i = 0; i < 1 + numberOfRequests + 1; i++) {
embeddedChannel.writeInbound(createHttpRequest("/" + Integer.toString(i)));
embeddedChannel.writeInbound(createHttpRequest("/" + i));
}

final List<CountDownLatch> latches = new ArrayList<>();
Expand All @@ -178,7 +183,7 @@ public void testThatPipeliningClosesConnectionWithTooManyEvents() throws Interru
assertFalse(embeddedChannel.isOpen());
}

public void testPipeliningRequestsAreReleased() throws InterruptedException {
public void testPipeliningRequestsAreReleased() {
final int numberOfRequests = 10;
final EmbeddedChannel embeddedChannel = new EmbeddedChannel(new Netty4HttpPipeliningHandler(logger, numberOfRequests + 1, null));

Expand All @@ -197,7 +202,7 @@ public void testPipeliningRequestsAreReleased() throws InterruptedException {
ChannelPromise promise = embeddedChannel.newPromise();
promises.add(promise);
Netty4HttpRequest pipelinedRequest = requests.get(i);
Netty4HttpResponse resp = pipelinedRequest.createResponse(RestStatus.OK, BytesArray.EMPTY);
Netty4FullHttpResponse resp = pipelinedRequest.createResponse(RestStatus.OK, BytesArray.EMPTY);
embeddedChannel.writeAndFlush(resp, promise);
}

Expand All @@ -211,6 +216,45 @@ public void testPipeliningRequestsAreReleased() throws InterruptedException {
}
}

public void testSmallFullResponsesAreSentDirectly() {
final List<Object> messagesSeen = new ArrayList<>();
final var embeddedChannel = new EmbeddedChannel(capturingHandler(messagesSeen), getTestHttpHandler());
embeddedChannel.writeInbound(createHttpRequest("/test"));
final Netty4HttpRequest request = embeddedChannel.readInbound();
final var maxSize = (int) NettyAllocator.suggestedMaxAllocationSize() / 2;
final var content = new ZeroBytesReference(between(0, maxSize));
final var response = request.createResponse(RestStatus.OK, content);
assertThat(response, instanceOf(FullHttpResponse.class));
final var promise = embeddedChannel.newPromise();
embeddedChannel.writeAndFlush(response, promise);
assertTrue(promise.isDone());
assertThat(messagesSeen, hasSize(1));
assertSame(response, messagesSeen.get(0));
}

public void testLargeFullResponsesAreSplit() {
final List<Object> messagesSeen = new ArrayList<>();
final var embeddedChannel = new EmbeddedChannel(capturingHandler(messagesSeen), getTestHttpHandler());
embeddedChannel.writeInbound(createHttpRequest("/test"));
final Netty4HttpRequest request = embeddedChannel.readInbound();
final var minSize = (int) NettyAllocator.suggestedMaxAllocationSize();
final var content = new ZeroBytesReference(between(minSize, minSize * 2));
final var response = request.createResponse(RestStatus.OK, content);
assertThat(response, instanceOf(FullHttpResponse.class));
final var promise = embeddedChannel.newPromise();
embeddedChannel.writeAndFlush(response, promise);
assertTrue(promise.isDone());
assertThat(messagesSeen, hasSize(3));
final var headersMessage = asInstanceOf(DefaultHttpResponse.class, messagesSeen.get(0));
assertEquals(RestStatus.OK.getStatus(), headersMessage.status().code());
assertThat(headersMessage, not(instanceOf(FullHttpResponse.class)));
final var chunk1 = asInstanceOf(DefaultHttpContent.class, messagesSeen.get(1));
final var chunk2 = asInstanceOf(DefaultLastHttpContent.class, messagesSeen.get(2));
assertEquals(content.length(), chunk1.content().readableBytes() + chunk2.content().readableBytes());
assertThat(chunk1, not(instanceOf(FullHttpResponse.class)));
assertThat(chunk2, not(instanceOf(FullHttpResponse.class)));
}

public void testDecoderErrorSurfacedAsNettyInboundError() {
final EmbeddedChannel embeddedChannel = new EmbeddedChannel(getTestHttpHandler());
// a request with a decoder error
Expand Down Expand Up @@ -304,7 +348,7 @@ public void testResumesSingleAfterChunkedMessage() {
assertTrue(promise1.isDone());
assertThat(messagesSeen, hasSize(chunks1 + 1 + 1));
assertChunkedMessageAtIndex(messagesSeen, 0, chunks1, chunk);
assertThat(messagesSeen.get(chunks1 + 1), instanceOf(Netty4HttpResponse.class));
assertThat(messagesSeen.get(chunks1 + 1), instanceOf(Netty4FullHttpResponse.class));
assertContentAtIndexEquals(messagesSeen, chunks1 + 1, single);
assertTrue(promise2.isDone());
}
Expand Down Expand Up @@ -339,7 +383,7 @@ public void testChunkedResumesAfterSingleMessage() {
embeddedChannel.flush();
assertTrue(promise1.isDone());
assertThat(messagesSeen, hasSize(chunks2 + 2));
assertThat(messagesSeen.get(0), instanceOf(Netty4HttpResponse.class));
assertThat(messagesSeen.get(0), instanceOf(Netty4FullHttpResponse.class));
assertChunkedMessageAtIndex(messagesSeen, 1, chunks2, chunk);
assertTrue(promise2.isDone());
}
Expand Down Expand Up @@ -377,7 +421,7 @@ public void testChunkedWithSmallChunksResumesAfterSingleMessage() {
embeddedChannel.flush();
assertTrue(promise1.isDone());
assertThat(messagesSeen, hasSize(chunks2 + 2));
assertThat(messagesSeen.get(0), instanceOf(Netty4HttpResponse.class));
assertThat(messagesSeen.get(0), instanceOf(Netty4FullHttpResponse.class));
assertChunkedMessageAtIndex(messagesSeen, 1, chunks2, chunk);
assertTrue(promise2.isDone());
}
Expand Down Expand Up @@ -410,7 +454,7 @@ public void testPipeliningRequestsAreReleasedAfterFailureOnChunked() {
for (Netty4HttpRequest request : requests) {
ChannelPromise promise = embeddedChannel.newPromise();
promises.add(promise);
Netty4HttpResponse resp = request.createResponse(RestStatus.OK, BytesArray.EMPTY);
Netty4FullHttpResponse resp = request.createResponse(RestStatus.OK, BytesArray.EMPTY);
embeddedChannel.write(resp, promise);
}
assertFalse(chunkedWritePromise.isDone());
Expand Down Expand Up @@ -525,7 +569,7 @@ protected void channelRead0(final ChannelHandlerContext ctx, Netty4HttpRequest r

handlerService.submit(() -> {
try {
waitingLatch.await(1000, TimeUnit.SECONDS);
assertTrue(waitingLatch.await(1000, TimeUnit.SECONDS));
final ChannelPromise promise = ctx.newPromise();
eventLoopService.submit(() -> {
ctx.write(httpResponse, promise);
Expand Down
Expand Up @@ -8,8 +8,6 @@

package org.elasticsearch.common.bytes;

import java.io.IOException;

import static org.hamcrest.Matchers.containsString;

public class ZeroBytesReferenceTests extends AbstractBytesReferenceTestCase {
Expand Down Expand Up @@ -39,9 +37,11 @@ public void testSliceToBytesRef() {
// ZeroBytesReference shifts offsets
}

public void testWriteWithIterator() throws IOException {
AssertionError error = expectThrows(AssertionError.class, () -> super.testWriteWithIterator());
assertThat(error.getMessage(), containsString("Internal pages from ZeroBytesReference must be zero"));
public void testWriteWithIterator() {
assertThat(
expectThrows(AssertionError.class, super::testWriteWithIterator).getMessage(),
containsString("Internal pages from ZeroBytesReference must be zero")
);
}

}

0 comments on commit cf67f5d

Please sign in to comment.