Skip to content

Commit

Permalink
Avoid StackOverflowError if write circular reference exception (#54147)
Browse files Browse the repository at this point in the history
We should never write a circular reference exception as we will fail a
node with StackOverflowError. However, we have one in #53589.
I tried but failed to find its location. With this commit, we will avoid
StackOverflowError in production and detect circular exceptions in
tests.

Closes #53589
  • Loading branch information
dnhatn committed Apr 4, 2020
1 parent 91c8a23 commit 8cde744
Show file tree
Hide file tree
Showing 3 changed files with 46 additions and 5 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -276,7 +276,7 @@ public Throwable getRootCause() {
public void writeTo(StreamOutput out) throws IOException {
out.writeOptionalString(this.getMessage());
out.writeException(this.getCause());
writeStackTraces(this, out);
writeStackTraces(this, out, StreamOutput::writeException);
out.writeMapOfLists(headers, StreamOutput::writeString, StreamOutput::writeString);
out.writeMapOfLists(metadata, StreamOutput::writeString, StreamOutput::writeString);
}
Expand Down Expand Up @@ -715,7 +715,8 @@ public static <T extends Throwable> T readStackTrace(T throwable, StreamInput in
/**
* Serializes the given exceptions stacktrace elements as well as it's suppressed exceptions to the given output stream.
*/
public static <T extends Throwable> T writeStackTraces(T throwable, StreamOutput out) throws IOException {
public static <T extends Throwable> T writeStackTraces(T throwable, StreamOutput out,
Writer<Throwable> exceptionWriter) throws IOException {
StackTraceElement[] stackTrace = throwable.getStackTrace();
out.writeVInt(stackTrace.length);
for (StackTraceElement element : stackTrace) {
Expand All @@ -727,7 +728,7 @@ public static <T extends Throwable> T writeStackTraces(T throwable, StreamOutput
Throwable[] suppressed = throwable.getSuppressed();
out.writeVInt(suppressed.length);
for (Throwable t : suppressed) {
out.writeException(t);
exceptionWriter.write(out, t);
}
return throwable;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,7 @@
public abstract class StreamOutput extends OutputStream {

private static final Map<TimeUnit, Byte> TIME_UNIT_BYTE_MAP;
private static final int MAX_NESTED_EXCEPTION_LEVEL = 100;

static {
final Map<TimeUnit, Byte> timeUnitByteMap = new EnumMap<>(TimeUnit.class);
Expand Down Expand Up @@ -901,8 +902,15 @@ public void writeOptionalWriteable(@Nullable Writeable writeable) throws IOExcep
}

public void writeException(Throwable throwable) throws IOException {
writeException(throwable, throwable, 0);
}

private void writeException(Throwable rootException, Throwable throwable, int nestedLevel) throws IOException {
if (throwable == null) {
writeBoolean(false);
} else if (nestedLevel > MAX_NESTED_EXCEPTION_LEVEL) {
assert failOnTooManyNestedExceptions(rootException);
writeException(new IllegalStateException("too many nested exceptions"));
} else {
writeBoolean(true);
boolean writeCause = true;
Expand Down Expand Up @@ -1011,12 +1019,16 @@ public void writeException(Throwable throwable) throws IOException {
writeOptionalString(throwable.getMessage());
}
if (writeCause) {
writeException(throwable.getCause());
writeException(rootException, throwable.getCause(), nestedLevel + 1);
}
ElasticsearchException.writeStackTraces(throwable, this);
ElasticsearchException.writeStackTraces(throwable, this, (o, t) -> o.writeException(rootException, t, nestedLevel + 1));
}
}

boolean failOnTooManyNestedExceptions(Throwable throwable) {
throw new AssertionError("too many nested exceptions", throwable);
}

/**
* Writes a {@link NamedWriteable} to the current stream, by first writing its name and then the object itself
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@

package org.elasticsearch.common.io.stream;

import org.apache.lucene.store.AlreadyClosedException;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.Constants;
import org.elasticsearch.common.bytes.BytesArray;
Expand Down Expand Up @@ -46,10 +47,13 @@
import java.util.stream.IntStream;

import static org.hamcrest.Matchers.closeTo;
import static org.hamcrest.Matchers.containsString;
import static org.hamcrest.Matchers.endsWith;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.hasSize;
import static org.hamcrest.Matchers.instanceOf;
import static org.hamcrest.Matchers.is;
import static org.hamcrest.Matchers.sameInstance;

/**
* Tests for {@link BytesStreamOutput} paging behaviour.
Expand Down Expand Up @@ -835,4 +839,28 @@ public void testTimeValueSerialize() throws Exception {
assertEqualityAfterSerialize(timeValue, 1 + out.bytes().length());
}

public void testWriteCircularReferenceException() throws IOException {
IOException rootEx = new IOException("disk broken");
AlreadyClosedException ace = new AlreadyClosedException("closed", rootEx);
rootEx.addSuppressed(ace); // circular reference

BytesStreamOutput testOut = new BytesStreamOutput();
AssertionError error = expectThrows(AssertionError.class, () -> testOut.writeException(rootEx));
assertThat(error.getMessage(), containsString("too many nested exceptions"));
assertThat(error.getCause(), equalTo(rootEx));

BytesStreamOutput prodOut = new BytesStreamOutput() {
@Override
boolean failOnTooManyNestedExceptions(Throwable throwable) {
assertThat(throwable, sameInstance(rootEx));
return true;
}
};
prodOut.writeException(rootEx);
StreamInput in = prodOut.bytes().streamInput();
Exception newEx = in.readException();
assertThat(newEx, instanceOf(IOException.class));
assertThat(newEx.getMessage(), equalTo("disk broken"));
assertArrayEquals(newEx.getStackTrace(), rootEx.getStackTrace());
}
}

0 comments on commit 8cde744

Please sign in to comment.