Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prevent ByteToMessageDecoder from overreading when !isAutoRead #9252

Merged
merged 3 commits into from Jun 28, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Expand Up @@ -21,6 +21,7 @@
import io.netty.buffer.ByteBufAllocator;
import io.netty.buffer.CompositeByteBuf;
import io.netty.buffer.Unpooled;
import io.netty.channel.ChannelConfig;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.socket.ChannelInputShutdownEvent;
Expand Down Expand Up @@ -151,8 +152,14 @@ public ByteBuf cumulate(ByteBufAllocator alloc, ByteBuf cumulation, ByteBuf in)
ByteBuf cumulation;
private Cumulator cumulator = MERGE_CUMULATOR;
private boolean singleDecode;
private boolean decodeWasNull;
private boolean first;

/**
* This flag is used to determine if we need to call {@link ChannelHandlerContext#read()} to consume more data
* when {@link ChannelConfig#isAutoRead()} is {@code false}.
*/
private boolean firedChannelRead;

/**
* A bitmask where the bits are defined as
* <ul>
Expand Down Expand Up @@ -291,7 +298,7 @@ public void channelRead(ChannelHandlerContext ctx, Object msg) throws Exception
}

int size = out.size();
decodeWasNull = !out.insertSinceRecycled();
firedChannelRead |= out.insertSinceRecycled();
fireChannelRead(ctx, out, size);
out.recycle();
}
Expand Down Expand Up @@ -326,12 +333,10 @@ static void fireChannelRead(ChannelHandlerContext ctx, CodecOutputList msgs, int
public void channelReadComplete(ChannelHandlerContext ctx) throws Exception {
numReads = 0;
discardSomeReadBytes();
if (decodeWasNull) {
decodeWasNull = false;
if (!ctx.channel().config().isAutoRead()) {
ctx.read();
}
if (!firedChannelRead && !ctx.channel().config().isAutoRead()) {
ctx.read();
}
firedChannelRead = false;
ctx.fireChannelReadComplete();
}

Expand Down
Expand Up @@ -22,6 +22,7 @@
import io.netty.buffer.UnpooledHeapByteBuf;
import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter;
import io.netty.channel.ChannelOutboundHandlerAdapter;
import io.netty.channel.embedded.EmbeddedChannel;
import io.netty.util.internal.PlatformDependent;
import org.junit.Test;
Expand All @@ -30,6 +31,7 @@
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.LinkedBlockingDeque;

import static io.netty.buffer.Unpooled.wrappedBuffer;
import static org.junit.Assert.*;

public class ByteToMessageDecoderTest {
Expand Down Expand Up @@ -345,4 +347,56 @@ public CompositeByteBuf addComponent(boolean increaseWriterIndex, ByteBuf buffer
assertEquals(0, in.refCnt());
}
}

@Test
public void testDoesNotOverRead() {
class ReadInterceptingHandler extends ChannelOutboundHandlerAdapter {
private int readsTriggered;

@Override
public void read(ChannelHandlerContext ctx) throws Exception {
readsTriggered++;
super.read(ctx);
}
}
ReadInterceptingHandler interceptor = new ReadInterceptingHandler();

EmbeddedChannel channel = new EmbeddedChannel();
channel.config().setAutoRead(false);
channel.pipeline().addLast(interceptor, new FixedLengthFrameDecoder(3));
assertEquals(0, interceptor.readsTriggered);

// 0 complete frames, 1 partial frame: SHOULD trigger a read
channel.writeInbound(wrappedBuffer(new byte[] { 0, 1 }));
assertEquals(1, interceptor.readsTriggered);

// 2 complete frames, 0 partial frames: should NOT trigger a read
channel.writeInbound(wrappedBuffer(new byte[] { 2 }), wrappedBuffer(new byte[] { 3, 4, 5 }));
assertEquals(1, interceptor.readsTriggered);

// 1 complete frame, 1 partial frame: should NOT trigger a read
channel.writeInbound(wrappedBuffer(new byte[] { 6, 7, 8 }), wrappedBuffer(new byte[] { 9 }));
assertEquals(1, interceptor.readsTriggered);

// 1 complete frame, 1 partial frame: should NOT trigger a read
channel.writeInbound(wrappedBuffer(new byte[] { 10, 11 }), wrappedBuffer(new byte[] { 12 }));
assertEquals(1, interceptor.readsTriggered);

// 0 complete frames, 1 partial frame: SHOULD trigger a read
channel.writeInbound(wrappedBuffer(new byte[] { 13 }));
assertEquals(2, interceptor.readsTriggered);

// 1 complete frame, 0 partial frames: should NOT trigger a read
channel.writeInbound(wrappedBuffer(new byte[] { 14 }));
assertEquals(2, interceptor.readsTriggered);

for (int i = 0; i < 5; i++) {
ByteBuf read = channel.readInbound();
assertEquals(i * 3 + 0, read.getByte(0));
assertEquals(i * 3 + 1, read.getByte(1));
assertEquals(i * 3 + 2, read.getByte(2));
iamaleksey marked this conversation as resolved.
Show resolved Hide resolved
read.release();
}
assertFalse(channel.finish());
}
}