Skip to content

Commit

Permalink
Respond with same compression scheme received (#76514)
Browse files Browse the repository at this point in the history
This is related to #73497. Currently, we only use the configured
transport.compression_scheme setting when compressing a request or a
response. Additionally, the cluster.remote.*.compression_scheme
setting is ignored. This commit fixes this behavior by respecting the
per-cluster setting. Additionally, it resolves confusion around inbound
and outbound connections by always responding with the same scheme that
was received. This allows remote connections to have different schemes
than local connections.
  • Loading branch information
Tim-Brooks committed Aug 14, 2021
1 parent d356a4b commit f52ca3c
Show file tree
Hide file tree
Showing 19 changed files with 125 additions and 75 deletions.
4 changes: 3 additions & 1 deletion docs/reference/modules/transport.asciidoc
Original file line number Diff line number Diff line change
Expand Up @@ -200,4 +200,6 @@ request compression, you can set it on a per-remote cluster basis using the
The compression settings do not configure compression for responses. {es} will
compress a response if the inbound request was compressed--even when compression
is not enabled. Similarly, {es} will not compress a response if the inbound
request was uncompressed--even when compression is enabled.
request was uncompressed--even when compression is enabled. The compression
scheme used to compress a response will be the same scheme the remote node used
to compress the request.
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,11 @@ public ReleasableBytesReference pollDecompressedPage(boolean isEOS) {
}
}

@Override
public Compression.Scheme getScheme() {
return Compression.Scheme.DEFLATE;
}

@Override
public void close() {
inflater.end();
Expand Down
10 changes: 10 additions & 0 deletions server/src/main/java/org/elasticsearch/transport/Header.java
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ public class Header {
String actionName;
Tuple<Map<String, String>, Map<String, Set<String>>> headers;
Set<String> features;
private Compression.Scheme compressionScheme = null;

Header(int networkMessageSize, long requestId, byte status, Version version) {
this.networkMessageSize = networkMessageSize;
Expand Down Expand Up @@ -80,6 +81,10 @@ public String getActionName() {
return actionName;
}

public Compression.Scheme getCompressionScheme() {
return compressionScheme;
}

boolean needsToReadVariableHeader() {
return headers == null;
}
Expand Down Expand Up @@ -112,6 +117,11 @@ void finishParsingHeader(StreamInput input) throws IOException {
}
}

void setCompressionScheme(Compression.Scheme compressionScheme) {
assert isCompressed();
this.compressionScheme = compressionScheme;
}

@Override
public String toString() {
return "Header{" + networkMessageSize + "}{" + version + "}{" + requestId + "}{" + isRequest() + "}{" + isError() + "}{"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,13 @@ public void headerReceived(Header header) {
}
}

public void updateCompressionScheme(Compression.Scheme compressionScheme) {
ensureOpen();
assert isAggregating();
assert firstContent == null && contentAggregation == null;
currentHeader.setCompressionScheme(compressionScheme);
}

public void aggregate(ReleasableBytesReference content) {
ensureOpen();
assert isAggregating();
Expand Down Expand Up @@ -112,6 +119,7 @@ public InboundMessage finishAggregation() throws IOException {
success = true;
return new InboundMessage(aggregated.getHeader(), aggregationException);
} else {
assert uncompressedOrSchemeDefined(aggregated.getHeader());
success = true;
return aggregated;
}
Expand Down Expand Up @@ -188,6 +196,10 @@ private void initializeRequestState() {
}
}

private static boolean uncompressedOrSchemeDefined(Header header) {
return header.isCompressed() == (header.getCompressionScheme() != null);
}

private void checkBreaker(final Header header, final int contentLength, final BreakerControl breakerControl) {
if (header.isRequest() == false) {
return;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,7 @@ public int internalDecode(ReleasableBytesReference reference, Consumer<Object> f
return 0;
} else {
this.decompressor = decompressor;
fragmentConsumer.accept(this.decompressor.getScheme());
}
}
int remainingToConsume = totalNetworkSize - bytesConsumed;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -160,7 +160,7 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Head
final StreamInput stream = namedWriteableStream(message.openOrGetStreamInput());
assertRemoteVersion(stream, header.getVersion());
final TransportChannel transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version,
header.getFeatures(), header.isCompressed(), header.isHandshake(), message.takeBreakerReleaseControl());
header.getFeatures(), header.getCompressionScheme(), header.isHandshake(), message.takeBreakerReleaseControl());
try {
handshaker.handleHandshake(transportChannel, requestId, stream);
} catch (Exception e) {
Expand All @@ -175,7 +175,7 @@ private <T extends TransportRequest> void handleRequest(TcpChannel channel, Head
}
} else {
final TransportChannel transportChannel = new TcpTransportChannel(outboundHandler, channel, action, requestId, version,
header.getFeatures(), header.isCompressed(), header.isHandshake(), message.takeBreakerReleaseControl());
header.getFeatures(), header.getCompressionScheme(), header.isHandshake(), message.takeBreakerReleaseControl());
try {
messageListener.onRequestReceived(requestId, action);
if (message.isShortCircuit()) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ private void forwardFragments(TcpChannel channel, ArrayList<Object> fragments) t
if (fragment instanceof Header) {
assert aggregator.isAggregating() == false;
aggregator.headerReceived((Header) fragment);
} else if (fragment instanceof Compression.Scheme) {
assert aggregator.isAggregating();
aggregator.updateCompressionScheme((Compression.Scheme) fragment);
} else if (fragment == InboundDecoder.PING) {
assert aggregator.isAggregating() == false;
messageHandler.accept(channel, PING_MESSAGE);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,6 +156,11 @@ public ReleasableBytesReference pollDecompressedPage(boolean isEOS) {
}
}

@Override
public Compression.Scheme getScheme() {
return Compression.Scheme.LZ4;
}

@Override
public void close() {
for (Recycler.V<byte[]> page : pages) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,21 +40,19 @@ final class OutboundHandler {
private final StatsTracker statsTracker;
private final ThreadPool threadPool;
private final BigArrays bigArrays;
private final Compression.Scheme configuredCompressionScheme;

private volatile long slowLogThresholdMs = Long.MAX_VALUE;

private volatile TransportMessageListener messageListener = TransportMessageListener.NOOP_LISTENER;

OutboundHandler(String nodeName, Version version, String[] features, StatsTracker statsTracker, ThreadPool threadPool,
BigArrays bigArrays, Compression.Scheme compressionScheme) {
BigArrays bigArrays) {
this.nodeName = nodeName;
this.version = version;
this.features = features;
this.statsTracker = statsTracker;
this.threadPool = threadPool;
this.bigArrays = bigArrays;
this.configuredCompressionScheme = compressionScheme;
}

void setSlowLogThreshold(TimeValue slowLogThreshold) {
Expand All @@ -71,14 +69,8 @@ void sendBytes(TcpChannel channel, BytesReference bytes, ActionListener<Void> li
*/
void sendRequest(final DiscoveryNode node, final TcpChannel channel, final long requestId, final String action,
final TransportRequest request, final TransportRequestOptions options, final Version channelVersion,
final boolean compressRequest, final boolean isHandshake) throws IOException, TransportException {
final Compression.Scheme compressionScheme, final boolean isHandshake) throws IOException, TransportException {
Version version = Version.min(this.version, channelVersion);
final Compression.Scheme compressionScheme;
if (compressRequest) {
compressionScheme = configuredCompressionScheme;
} else {
compressionScheme = null;
}
OutboundMessage.Request message =
new OutboundMessage.Request(threadPool.getThreadContext(), features, request, version, action, requestId, isHandshake,
compressionScheme);
Expand All @@ -103,15 +95,10 @@ void sendRequest(final DiscoveryNode node, final TcpChannel channel, final long
* @see #sendErrorResponse(Version, Set, TcpChannel, long, String, Exception) for sending error responses
*/
void sendResponse(final Version nodeVersion, final Set<String> features, final TcpChannel channel, final long requestId,
final String action, final TransportResponse response, final boolean compressResponse, final boolean isHandshake)
final String action, final TransportResponse response, final Compression.Scheme compressionScheme,
final boolean isHandshake)
throws IOException {
Version version = Version.min(this.version, nodeVersion);
final Compression.Scheme compressionScheme;
if (compressResponse) {
compressionScheme = configuredCompressionScheme;
} else {
compressionScheme = null;
}
OutboundMessage.Response message = new OutboundMessage.Response(threadPool.getThreadContext(), features, response, version,
requestId, isHandshake, compressionScheme);
ActionListener<Void> listener = ActionListener.wrap(() -> messageListener.onResponseSent(requestId, action, response));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -135,7 +135,6 @@ public TcpTransport(Settings settings, Version version, ThreadPool threadPool, P
this.pageCacheRecycler = pageCacheRecycler;
this.circuitBreakerService = circuitBreakerService;
this.networkService = networkService;
Compression.Scheme compressionScheme = TransportSettings.TRANSPORT_COMPRESSION_SCHEME.get(settings);
String nodeName = Node.NODE_NAME_SETTING.get(settings);
final Settings defaultFeatures = TransportSettings.DEFAULT_FEATURES_SETTING.get(settings);
String[] features;
Expand All @@ -152,11 +151,11 @@ public TcpTransport(Settings settings, Version version, ThreadPool threadPool, P
}
BigArrays bigArrays = new BigArrays(pageCacheRecycler, circuitBreakerService, CircuitBreaker.IN_FLIGHT_REQUESTS);

this.outboundHandler = new OutboundHandler(nodeName, version, features, statsTracker, threadPool, bigArrays, compressionScheme);
this.outboundHandler = new OutboundHandler(nodeName, version, features, statsTracker, threadPool, bigArrays);
this.handshaker = new TransportHandshaker(version, threadPool,
(node, channel, requestId, v) -> outboundHandler.sendRequest(node, channel, requestId,
TransportHandshaker.HANDSHAKE_ACTION_NAME, new TransportHandshaker.HandshakeRequest(version),
TransportRequestOptions.EMPTY, v, false, true));
TransportRequestOptions.EMPTY, v, null, true));
this.keepAlive = new TransportKeepAlive(threadPool, this.outboundHandler::sendBytes);
this.inboundHandler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, handshaker, keepAlive,
requestHandlers, responseHandlers);
Expand Down Expand Up @@ -200,6 +199,7 @@ public final class NodeChannels extends CloseableConnection {
private final DiscoveryNode node;
private final Version version;
private final Compression.Enabled compress;
private final Compression.Scheme compressionScheme;
private final AtomicBoolean isClosing = new AtomicBoolean(false);

NodeChannels(DiscoveryNode node, List<TcpChannel> channels, ConnectionProfile connectionProfile, Version handshakeVersion) {
Expand All @@ -214,6 +214,7 @@ public final class NodeChannels extends CloseableConnection {
}
version = handshakeVersion;
compress = connectionProfile.getCompressionEnabled();
compressionScheme = connectionProfile.getCompressionScheme();
}

@Override
Expand Down Expand Up @@ -261,11 +262,12 @@ public void sendRequest(long requestId, String action, TransportRequest request,
// We compress if total transport compression is enabled or if indexing_data transport compression
// is enabled and the request is a RawIndexingDataTransportRequest which indicates it should be
// compressed.
boolean shouldCompress = compress == Compression.Enabled.TRUE ||
final boolean shouldCompress = compress == Compression.Enabled.TRUE ||
(compress == Compression.Enabled.INDEXING_DATA
&& request instanceof RawIndexingDataTransportRequest
&& ((RawIndexingDataTransportRequest) request).isRawIndexingData());
outboundHandler.sendRequest(node, channel, requestId, action, request, options, getVersion(), shouldCompress, false);
final Compression.Scheme schemeToUse = shouldCompress ? compressionScheme : null;
outboundHandler.sendRequest(node, channel, requestId, action, request, options, getVersion(), schemeToUse, false);
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,19 +24,19 @@ public final class TcpTransportChannel implements TransportChannel {
private final long requestId;
private final Version version;
private final Set<String> features;
private final boolean compressResponse;
private final Compression.Scheme compressionScheme;
private final boolean isHandshake;
private final Releasable breakerRelease;

TcpTransportChannel(OutboundHandler outboundHandler, TcpChannel channel, String action, long requestId, Version version,
Set<String> features, boolean compressResponse, boolean isHandshake, Releasable breakerRelease) {
Set<String> features, Compression.Scheme compressionScheme, boolean isHandshake, Releasable breakerRelease) {
this.version = version;
this.features = features;
this.channel = channel;
this.outboundHandler = outboundHandler;
this.action = action;
this.requestId = requestId;
this.compressResponse = compressResponse;
this.compressionScheme = compressionScheme;
this.isHandshake = isHandshake;
this.breakerRelease = breakerRelease;
}
Expand All @@ -49,7 +49,7 @@ public String getProfileName() {
@Override
public void sendResponse(TransportResponse response) throws IOException {
try {
outboundHandler.sendResponse(version, features, channel, requestId, action, response, compressResponse, isHandshake);
outboundHandler.sendResponse(version, features, channel, requestId, action, response, compressionScheme, isHandshake);
} finally {
release(false);
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,8 @@ public interface TransportDecompressor extends Releasable {

ReleasableBytesReference pollDecompressedPage(boolean isEOS);

Compression.Scheme getScheme();

@Override
void close();

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,11 @@ public void testDecodePreHeaderSizeVariableInt() throws IOException {
final BytesReference bytes2 = totalBytes.slice(bytesConsumed, totalBytes.length() - bytesConsumed);
final ReleasableBytesReference releasable2 = ReleasableBytesReference.wrap(bytes2);
int bytesConsumed2 = decoder.decode(releasable2, fragments::add);
assertEquals(2, fragments.size());
if (compressionScheme == null) {
assertEquals(2, fragments.size());
} else {
assertEquals(3, fragments.size());
}
assertEquals(InboundDecoder.END_CONTENT, fragments.get(fragments.size() - 1));
assertEquals(totalBytes.length() - bytesConsumed, bytesConsumed2);
}
Expand Down Expand Up @@ -195,7 +199,7 @@ public void testCompressedDecode() throws IOException {
final BytesReference totalBytes = message.serialize(new BytesStreamOutput());
final BytesStreamOutput out = new BytesStreamOutput();
transportMessage.writeTo(out);
final BytesReference uncompressedBytes =out.bytes();
final BytesReference uncompressedBytes = out.bytes();
int totalHeaderSize = TcpHeader.headerSize(Version.CURRENT) + totalBytes.getInt(TcpHeader.VARIABLE_HEADER_SIZE_POSITION);

InboundDecoder decoder = new InboundDecoder(Version.CURRENT, PageCacheRecycler.NON_RECYCLING_INSTANCE);
Expand Down Expand Up @@ -226,9 +230,11 @@ public void testCompressedDecode() throws IOException {
int bytesConsumed2 = decoder.decode(releasable2, fragments::add);
assertEquals(totalBytes.length() - totalHeaderSize, bytesConsumed2);

final Object content = fragments.get(0);
final Object endMarker = fragments.get(1);
final Object compressionScheme = fragments.get(0);
final Object content = fragments.get(1);
final Object endMarker = fragments.get(2);

assertEquals(scheme, compressionScheme);
assertEquals(uncompressedBytes, content);
// Ref count is not incremented since the bytes are immediately consumed on decompression
assertEquals(1, releasable2.refCount());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ public void setUp() throws Exception {
TransportHandshaker handshaker = new TransportHandshaker(version, threadPool, (n, c, r, v) -> {});
TransportKeepAlive keepAlive = new TransportKeepAlive(threadPool, TcpChannel::sendMessage);
OutboundHandler outboundHandler = new OutboundHandler("node", version, new String[0], new StatsTracker(), threadPool,
BigArrays.NON_RECYCLING_INSTANCE, randomFrom(Compression.Scheme.DEFLATE, Compression.Scheme.LZ4));
BigArrays.NON_RECYCLING_INSTANCE);
requestHandlers = new Transport.RequestHandlers();
responseHandlers = new Transport.ResponseHandlers();
handler = new InboundHandler(threadPool, outboundHandler, namedWriteableRegistry, handshaker, keepAlive, requestHandlers,
Expand Down

0 comments on commit f52ca3c

Please sign in to comment.