Skip to content

Commit

Permalink
Delegate Ref Counting to ByteBuf in Netty Transport (#81096)
Browse files Browse the repository at this point in the history
Tracking down recent memory leaks was made unnecessarily hard
by wrapping the `ByteBuf` ref couting with our own counter. This
way, we would not record the increments and decrements on the Netty
leak tracker, making it useless as far as identifying the concrete
source of a request with the logged leak only containing touch points
up until our inbound handler code.
  • Loading branch information
original-brownbear committed Nov 29, 2021
1 parent ca65718 commit 256521e
Show file tree
Hide file tree
Showing 7 changed files with 90 additions and 49 deletions.
4 changes: 2 additions & 2 deletions libs/nio/src/main/java/org/elasticsearch/nio/Page.java
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@ public Page(ByteBuffer byteBuffer, Releasable closeable) {
}

private Page(ByteBuffer byteBuffer, RefCountedCloseable refCountedCloseable) {
assert refCountedCloseable.refCount() > 0;
assert refCountedCloseable.hasReferences();
this.byteBuffer = byteBuffer;
this.refCountedCloseable = refCountedCloseable;
}
Expand All @@ -51,7 +51,7 @@ public Page duplicate() {
* @return the byte buffer
*/
public ByteBuffer byteBuffer() {
assert refCountedCloseable.refCount() > 0;
assert refCountedCloseable.hasReferences();
return byteBuffer;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
import org.elasticsearch.common.bytes.BytesReference;
import org.elasticsearch.common.bytes.ReleasableBytesReference;
import org.elasticsearch.common.recycler.Recycler;
import org.elasticsearch.core.RefCounted;
import org.elasticsearch.core.Releasables;
import org.elasticsearch.threadpool.ThreadPool;
import org.elasticsearch.transport.InboundPipeline;
Expand Down Expand Up @@ -68,7 +69,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
final ByteBuf buffer = (ByteBuf) msg;
Netty4TcpChannel channel = ctx.channel().attr(Netty4Transport.CHANNEL_KEY).get();
final BytesReference wrapped = Netty4Utils.toBytesReference(buffer);
try (ReleasableBytesReference reference = new ReleasableBytesReference(wrapped, buffer::release)) {
try (ReleasableBytesReference reference = new ReleasableBytesReference(wrapped, new ByteBufRefCounted(buffer))) {
pipeline.handleBytes(channel, reference);
}
}
Expand Down Expand Up @@ -211,4 +212,43 @@ void failAsClosedChannel() {
buf.release();
}
}

private static final class ByteBufRefCounted implements RefCounted {

private final ByteBuf buffer;

ByteBufRefCounted(ByteBuf buffer) {
this.buffer = buffer;
}

@Override
public void incRef() {
buffer.retain();
}

@Override
public boolean tryIncRef() {
if (hasReferences() == false) {
return false;
}
try {
buffer.retain();
} catch (RuntimeException e) {
assert hasReferences() == false;
return false;
}
return true;
}

@Override
public boolean decRef() {
return buffer.release();
}

@Override
public boolean hasReferences() {
return buffer.refCnt() > 0;
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ public final class ReleasableBytesReference implements RefCounted, Releasable, B
private static final ReleasableBytesReference EMPTY = new ReleasableBytesReference(BytesArray.EMPTY, NO_OP);

private final BytesReference delegate;
private final AbstractRefCounted refCounted;
private final RefCounted refCounted;

public static ReleasableBytesReference empty() {
EMPTY.incRef();
Expand All @@ -42,21 +42,17 @@ public ReleasableBytesReference(BytesReference delegate, Releasable releasable)
this(delegate, new RefCountedReleasable(releasable));
}

public ReleasableBytesReference(BytesReference delegate, AbstractRefCounted refCounted) {
public ReleasableBytesReference(BytesReference delegate, RefCounted refCounted) {
this.delegate = delegate;
this.refCounted = refCounted;
assert refCounted.refCount() > 0;
assert refCounted.hasReferences();
}

public static ReleasableBytesReference wrap(BytesReference reference) {
assert reference instanceof ReleasableBytesReference == false : "use #retain() instead of #wrap() on a " + reference.getClass();
return reference.length() == 0 ? empty() : new ReleasableBytesReference(reference, NO_OP);
}

public int refCount() {
return refCounted.refCount();
}

@Override
public void incRef() {
refCounted.incRef();
Expand Down Expand Up @@ -98,19 +94,19 @@ public void close() {

@Override
public byte get(int index) {
assert refCount() > 0;
assert hasReferences();
return delegate.get(index);
}

@Override
public int getInt(int index) {
assert refCount() > 0;
assert hasReferences();
return delegate.getInt(index);
}

@Override
public int indexOf(byte marker, int from) {
assert refCount() > 0;
assert hasReferences();
return delegate.indexOf(marker, from);
}

Expand All @@ -121,7 +117,7 @@ public int length() {

@Override
public BytesReference slice(int from, int length) {
assert refCount() > 0;
assert hasReferences();
return delegate.slice(from, length);
}

Expand All @@ -132,7 +128,7 @@ public long ramBytesUsed() {

@Override
public StreamInput streamInput() throws IOException {
assert refCount() > 0;
assert hasReferences();
return new BytesReferenceStreamInput(this) {
@Override
public ReleasableBytesReference readReleasableBytesReference() throws IOException {
Expand All @@ -148,37 +144,37 @@ public ReleasableBytesReference readReleasableBytesReference() throws IOExceptio

@Override
public void writeTo(OutputStream os) throws IOException {
assert refCount() > 0;
assert hasReferences();
delegate.writeTo(os);
}

@Override
public String utf8ToString() {
assert refCount() > 0;
assert hasReferences();
return delegate.utf8ToString();
}

@Override
public BytesRef toBytesRef() {
assert refCount() > 0;
assert hasReferences();
return delegate.toBytesRef();
}

@Override
public BytesRefIterator iterator() {
assert refCount() > 0;
assert hasReferences();
return delegate.iterator();
}

@Override
public int compareTo(BytesReference o) {
assert refCount() > 0;
assert hasReferences();
return delegate.compareTo(o);
}

@Override
public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException {
assert refCount() > 0;
assert hasReferences();
return delegate.toXContent(builder, params);
}

Expand All @@ -189,31 +185,31 @@ public boolean isFragment() {

@Override
public boolean equals(Object obj) {
assert refCount() > 0;
assert hasReferences();
return delegate.equals(obj);
}

@Override
public int hashCode() {
assert refCount() > 0;
assert hasReferences();
return delegate.hashCode();
}

@Override
public boolean hasArray() {
assert refCount() > 0;
assert hasReferences();
return delegate.hasArray();
}

@Override
public byte[] array() {
assert refCount() > 0;
assert hasReferences();
return delegate.array();
}

@Override
public int arrayOffset() {
assert refCount() > 0;
assert hasReferences();
return delegate.arrayOffset();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -91,11 +91,11 @@ public void testInboundAggregation() throws IOException {
assertThat(aggregated.getHeader().getRequestId(), equalTo(requestId));
assertThat(aggregated.getHeader().getVersion(), equalTo(Version.CURRENT));
for (ReleasableBytesReference reference : references) {
assertEquals(1, reference.refCount());
assertTrue(reference.hasReferences());
}
aggregated.close();
for (ReleasableBytesReference reference : references) {
assertEquals(0, reference.refCount());
assertFalse(reference.hasReferences());
}
}

Expand All @@ -111,7 +111,7 @@ public void testInboundUnknownAction() throws IOException {
final ReleasableBytesReference content = ReleasableBytesReference.wrap(bytes);
aggregator.aggregate(content);
content.close();
assertEquals(0, content.refCount());
assertFalse(content.hasReferences());

// Signal EOS
InboundMessage aggregated = aggregator.finishAggregation();
Expand Down Expand Up @@ -139,7 +139,7 @@ public void testCircuitBreak() throws IOException {
// Signal EOS
InboundMessage aggregated1 = aggregator.finishAggregation();

assertEquals(0, content1.refCount());
assertFalse(content1.hasReferences());
assertThat(aggregated1, notNullValue());
assertTrue(aggregated1.isShortCircuit());
assertThat(aggregated1.getException(), instanceOf(CircuitBreakingException.class));
Expand All @@ -158,7 +158,7 @@ public void testCircuitBreak() throws IOException {
// Signal EOS
InboundMessage aggregated2 = aggregator.finishAggregation();

assertEquals(1, content2.refCount());
assertTrue(content2.hasReferences());
assertThat(aggregated2, notNullValue());
assertFalse(aggregated2.isShortCircuit());

Expand All @@ -177,7 +177,7 @@ public void testCircuitBreak() throws IOException {
// Signal EOS
InboundMessage aggregated3 = aggregator.finishAggregation();

assertEquals(1, content3.refCount());
assertTrue(content3.hasReferences());
assertThat(aggregated3, notNullValue());
assertFalse(aggregated3.isShortCircuit());
}
Expand Down Expand Up @@ -211,7 +211,7 @@ public void testCloseWillCloseContent() {
aggregator.close();

for (ReleasableBytesReference reference : references) {
assertEquals(0, reference.refCount());
assertFalse(reference.hasReferences());
}
}

Expand Down Expand Up @@ -244,10 +244,10 @@ public void testFinishAggregationWillFinishHeader() throws IOException {
assertFalse(header.needsToReadVariableHeader());
assertEquals(actionName, header.getActionName());
if (unknownAction) {
assertEquals(0, content.refCount());
assertFalse(content.hasReferences());
assertTrue(aggregated.isShortCircuit());
} else {
assertEquals(1, content.refCount());
assertTrue(content.hasReferences());
assertFalse(aggregated.isShortCircuit());
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ public void testDecode() throws IOException {
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(totalBytes);
int bytesConsumed = decoder.decode(releasable1, fragments::add);
assertEquals(totalHeaderSize, bytesConsumed);
assertEquals(1, releasable1.refCount());
assertTrue(releasable1.hasReferences());

final Header header = (Header) fragments.get(0);
assertEquals(requestId, header.getRequestId());
Expand Down Expand Up @@ -108,7 +108,10 @@ public void testDecode() throws IOException {

assertEquals(messageBytes, content);
// Ref count is incremented since the bytes are forwarded as a fragment
assertEquals(2, releasable2.refCount());
assertTrue(releasable2.hasReferences());
releasable2.decRef();
assertTrue(releasable2.hasReferences());
assertTrue(releasable2.decRef());
assertEquals(InboundDecoder.END_CONTENT, endMarker);
}

Expand Down Expand Up @@ -141,7 +144,7 @@ public void testDecodePreHeaderSizeVariableInt() throws IOException {
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(totalBytes);
int bytesConsumed = decoder.decode(releasable1, fragments::add);
assertEquals(partialHeaderSize, bytesConsumed);
assertEquals(1, releasable1.refCount());
assertTrue(releasable1.hasReferences());

final Header header = (Header) fragments.get(0);
assertEquals(requestId, header.getRequestId());
Expand Down Expand Up @@ -198,7 +201,7 @@ public void testDecodeHandshakeCompatibility() throws IOException {
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(bytes);
int bytesConsumed = decoder.decode(releasable1, fragments::add);
assertEquals(totalHeaderSize, bytesConsumed);
assertEquals(1, releasable1.refCount());
assertTrue(releasable1.hasReferences());

final Header header = (Header) fragments.get(0);
assertEquals(requestId, header.getRequestId());
Expand Down Expand Up @@ -247,7 +250,7 @@ public void testCompressedDecode() throws IOException {
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(totalBytes);
int bytesConsumed = decoder.decode(releasable1, fragments::add);
assertEquals(totalHeaderSize, bytesConsumed);
assertEquals(1, releasable1.refCount());
assertTrue(releasable1.hasReferences());

final Header header = (Header) fragments.get(0);
assertEquals(requestId, header.getRequestId());
Expand Down Expand Up @@ -279,7 +282,7 @@ public void testCompressedDecode() throws IOException {
assertThat(content, instanceOf(ReleasableBytesReference.class));
((ReleasableBytesReference) content).close();
// Ref count is not incremented since the bytes are immediately consumed on decompression
assertEquals(1, releasable2.refCount());
assertTrue(releasable2.hasReferences());
assertEquals(InboundDecoder.END_CONTENT, endMarker);
}

Expand Down Expand Up @@ -311,7 +314,7 @@ public void testCompressedDecodeHandshakeCompatibility() throws IOException {
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(bytes);
int bytesConsumed = decoder.decode(releasable1, fragments::add);
assertEquals(totalHeaderSize, bytesConsumed);
assertEquals(1, releasable1.refCount());
assertTrue(releasable1.hasReferences());

final Header header = (Header) fragments.get(0);
assertEquals(requestId, header.getRequestId());
Expand Down Expand Up @@ -339,16 +342,19 @@ public void testVersionIncompatibilityDecodeException() throws IOException {
Compression.Scheme.DEFLATE
);

final ReleasableBytesReference releasable1;
try (RecyclerBytesStreamOutput os = new RecyclerBytesStreamOutput(recycler)) {
final BytesReference bytes = message.serialize(os);

InboundDecoder decoder = new InboundDecoder(Version.CURRENT, recycler);
final ArrayList<Object> fragments = new ArrayList<>();
final ReleasableBytesReference releasable1 = ReleasableBytesReference.wrap(bytes);
expectThrows(IllegalStateException.class, () -> decoder.decode(releasable1, fragments::add));
// No bytes are retained
assertEquals(1, releasable1.refCount());
try (ReleasableBytesReference r = ReleasableBytesReference.wrap(bytes)) {
releasable1 = r;
expectThrows(IllegalStateException.class, () -> decoder.decode(releasable1, fragments::add));
}
}
// No bytes are retained
assertFalse(releasable1.hasReferences());
}

public void testEnsureVersionCompatibility() throws IOException {
Expand Down
Loading

0 comments on commit 256521e

Please sign in to comment.