Skip to content

Commit

Permalink
Fix DefaultRestChannel Corrupting Shared Buffers on Serialialization …
Browse files Browse the repository at this point in the history
…Issues (#72274)

We must not reset the shared buffer after it has been used (can happen in error handling in `RestController#sendResponse`).
There is never a good reason to reset a pooled bytes output either and the behavior isn't clearly defined so this commit
disables the operation as it had unintended side effects.
  • Loading branch information
original-brownbear committed Apr 27, 2021
1 parent f992e47 commit 5edc0d3
Show file tree
Hide file tree
Showing 7 changed files with 48 additions and 39 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,11 @@ public ReleasableBytesStreamOutput(int expectedSize, BigArrays bigArrays) {
public void close() {
Releasables.close(bytes);
}

@Override
public void reset() {
assert false;
// not supported, close and create a new instance instead
throw new UnsupportedOperationException("must not reuse a pooled bytes backed stream");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@ public void sendResponse(RestResponse restResponse) {
if (content instanceof Releasable) {
toClose.add((Releasable) content);
}
toClose.add(this::releaseOutputBuffer);

BytesReference finalContent = content;
try {
Expand Down Expand Up @@ -122,11 +123,6 @@ public void sendResponse(RestResponse restResponse) {

addCookies(httpResponse);

BytesStreamOutput bytesStreamOutput = bytesOutputOrNull();
if (bytesStreamOutput instanceof ReleasableBytesStreamOutput) {
toClose.add((Releasable) bytesStreamOutput);
}

ActionListener<Void> listener = ActionListener.wrap(() -> Releasables.close(toClose));
httpChannel.sendResponse(httpResponse, listener);
success = true;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
import org.elasticsearch.common.UUIDs;
import org.elasticsearch.common.bytes.BytesArray;
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.io.stream.BytesStreamOutput;
import org.elasticsearch.common.io.stream.ReleasableBytesStreamOutput;
import org.elasticsearch.common.io.stream.StreamInput;
import org.elasticsearch.common.io.stream.StreamOutput;
Expand Down Expand Up @@ -1550,26 +1551,15 @@ static Translog.Operation readOperation(BufferedChecksumStreamInput in) throws I
* use {@link #readOperations(StreamInput, String)} to read it back.
*/
public static void writeOperations(StreamOutput outStream, List<Operation> toWrite) throws IOException {
final ReleasableBytesStreamOutput out = new ReleasableBytesStreamOutput(BigArrays.NON_RECYCLING_INSTANCE);
try {
outStream.writeInt(toWrite.size());
final BufferedChecksumStreamOutput checksumStreamOutput = new BufferedChecksumStreamOutput(out);
for (Operation op : toWrite) {
out.reset();
final long start = out.position();
out.skip(Integer.BYTES);
writeOperationNoSize(checksumStreamOutput, op);
long end = out.position();
int operationSize = (int) (out.position() - Integer.BYTES - start);
out.seek(start);
out.writeInt(operationSize);
out.seek(end);
out.bytes().writeTo(outStream);
}
} finally {
Releasables.close(out);
final BytesStreamOutput out = new BytesStreamOutput();
outStream.writeInt(toWrite.size());
final BufferedChecksumStreamOutput checksumStreamOutput = new BufferedChecksumStreamOutput(out);
for (Operation op : toWrite) {
out.reset();
writeOperationNoSize(checksumStreamOutput, op);
outStream.writeInt(Math.toIntExact(out.position()));
out.bytes().writeTo(outStream);
}

}

public static void writeOperationNoSize(BufferedChecksumStreamOutput out, Translog.Operation op) throws IOException {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
*/
package org.elasticsearch.rest;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.elasticsearch.common.Nullable;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.io.Streams;
Expand All @@ -27,6 +29,8 @@

public abstract class AbstractRestChannel implements RestChannel {

private static final Logger logger = LogManager.getLogger(AbstractRestChannel.class);

private static final Predicate<String> INCLUDE_FILTER = f -> f.charAt(0) != '-';
private static final Predicate<String> EXCLUDE_FILTER = INCLUDE_FILTER.negate();

Expand Down Expand Up @@ -137,25 +141,30 @@ public XContentBuilder newBuilder(@Nullable XContentType requestContentType, @Nu

/**
* A channel level bytes output that can be reused. The bytes output is lazily instantiated
* by a call to {@link #newBytesOutput()}. Once the stream is created, it gets reset on each
* call to this method.
* by a call to {@link #newBytesOutput()}. This method should only be called once per request.
*/
@Override
public final BytesStreamOutput bytesOutput() {
if (bytesOut == null) {
bytesOut = newBytesOutput();
} else {
bytesOut.reset();
if (bytesOut != null) {
// fallback in case of encountering a bug, release the existing buffer if any (to avoid leaking memory) and acquire a new one
// to send out an error response
assert false : "getting here is always a bug";
logger.error("channel handling [{}] reused", request.rawPath());
releaseOutputBuffer();
}
bytesOut = newBytesOutput();
return bytesOut;
}

/**
* An accessor to the raw value of the channel bytes output. This method will not instantiate
* a new stream if one does not exist and this method will not reset the stream.
* Releases the current output buffer for this channel. Must be called after the buffer derived from {@link #bytesOutput} is no longer
* needed.
*/
protected final BytesStreamOutput bytesOutputOrNull() {
return bytesOut;
protected final void releaseOutputBuffer() {
if (bytesOut != null) {
bytesOut.close();
bytesOut = null;
}
}

protected BytesStreamOutput newBytesOutput() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -111,16 +111,15 @@ public void testErrorTrace() throws Exception {

public void testGuessRootCause() throws IOException {
RestRequest request = new FakeRestRequest();
RestChannel channel = new DetailedExceptionRestChannel(request);
{
Exception e = new ElasticsearchException("an error occurred reading data", new FileNotFoundException("/foo/bar"));
BytesRestResponse response = new BytesRestResponse(channel, e);
BytesRestResponse response = new BytesRestResponse(new DetailedExceptionRestChannel(request), e);
String text = response.content().utf8ToString();
assertThat(text, containsString("{\"root_cause\":[{\"type\":\"exception\",\"reason\":\"an error occurred reading data\"}]"));
}
{
Exception e = new FileNotFoundException("/foo/bar");
BytesRestResponse response = new BytesRestResponse(channel, e);
BytesRestResponse response = new BytesRestResponse(new DetailedExceptionRestChannel(request), e);
String text = response.content().utf8ToString();
assertThat(text, containsString("{\"root_cause\":[{\"type\":\"file_not_found_exception\",\"reason\":\"/foo/bar\"}]"));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -780,7 +780,12 @@ protected ExceptionThrowingChannel(RestRequest request, boolean detailedErrorsEn

@Override
public void sendResponse(RestResponse response) {
throw new IllegalStateException("always throwing an exception for testing");
try {
throw new IllegalStateException("always throwing an exception for testing");
} finally {
// the production implementation in DefaultRestChannel always releases the output buffer, so we must too
releaseOutputBuffer();
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
package org.elasticsearch.xpack.security.rest.action.service;

import org.elasticsearch.client.node.NodeClient;
import org.elasticsearch.cluster.ClusterName;
import org.elasticsearch.common.Strings;
import org.elasticsearch.common.settings.Settings;
import org.elasticsearch.common.xcontent.NamedXContentRegistry;
Expand Down Expand Up @@ -50,7 +51,9 @@ public void init() {
verifyingClient.setExecuteVerifier(((actionType, actionRequest) -> {
assertThat(actionRequest, instanceOf(ClearSecurityCacheRequest.class));
requestHolder.set((ClearSecurityCacheRequest) actionRequest);
return mock(ClearSecurityCacheResponse.class);
final ClearSecurityCacheResponse response = mock(ClearSecurityCacheResponse.class);
when(response.getClusterName()).thenReturn(new ClusterName(""));
return response;
}));
}

Expand Down

0 comments on commit 5edc0d3

Please sign in to comment.