Skip to content

Commit

Permalink
HTTP/2 CompressorHttp2ConnectionEncoder bug
Browse files Browse the repository at this point in the history
Motivation:
The CompressorHttp2ConnectionEncoder is attempting to attach a property to streams before the exist.

Modifications:
- Allow the super class to create the streams before attempting to attach a property to the stream.

Result:
CompressorHttp2ConnectionEncoder is able to set the property and access the compressor.
  • Loading branch information
Scottmitch committed Jul 17, 2015
1 parent b958263 commit d7cdc46
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 85 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -147,17 +147,43 @@ public ChannelFuture writeData(final ChannelHandlerContext ctx, final int stream
@Override
public ChannelFuture writeHeaders(ChannelHandlerContext ctx, int streamId, Http2Headers headers, int padding,
boolean endStream, ChannelPromise promise) {
initCompressor(streamId, headers, endStream);
return super.writeHeaders(ctx, streamId, headers, padding, endStream, promise);
try {
// Determine if compression is required and sanitize the headers.
EmbeddedChannel compressor = newCompressor(headers, endStream);

// Write the headers and create the stream object.
ChannelFuture future = super.writeHeaders(ctx, streamId, headers, padding, endStream, promise);

// After the stream object has been created, then attach the compressor as a property for data compression.
bindCompressorToStream(compressor, streamId);

return future;
} catch (Throwable e) {
promise.tryFailure(e);
}
return promise;
}

@Override
public ChannelFuture writeHeaders(final ChannelHandlerContext ctx, final int streamId, final Http2Headers headers,
final int streamDependency, final short weight, final boolean exclusive, final int padding,
final boolean endOfStream, final ChannelPromise promise) {
initCompressor(streamId, headers, endOfStream);
return super.writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive, padding, endOfStream,
promise);
try {
// Determine if compression is required and sanitize the headers.
EmbeddedChannel compressor = newCompressor(headers, endOfStream);

// Write the headers and create the stream object.
ChannelFuture future = super.writeHeaders(ctx, streamId, headers, streamDependency, weight, exclusive,
padding, endOfStream, promise);

// After the stream object has been created, then attach the compressor as a property for data compression.
bindCompressorToStream(compressor, streamId);

return future;
} catch (Throwable e) {
promise.tryFailure(e);
}
return promise;
}

/**
Expand Down Expand Up @@ -205,48 +231,50 @@ private EmbeddedChannel newCompressionChannel(ZlibWrapper wrapper) {
* Checks if a new compressor object is needed for the stream identified by {@code streamId}. This method will
* modify the {@code content-encoding} header contained in {@code headers}.
*
* @param streamId The identifier for the headers inside {@code headers}
* @param headers Object representing headers which are to be written
* @param endOfStream Indicates if the stream has ended
* @return The channel used to compress data.
* @throws Http2Exception if any problems occur during initialization.
*/
private void initCompressor(int streamId, Http2Headers headers, boolean endOfStream) {
final Http2Stream stream = connection().stream(streamId);
if (stream == null) {
return;
private EmbeddedChannel newCompressor(Http2Headers headers, boolean endOfStream) throws Http2Exception {
if (endOfStream) {
return null;
}

EmbeddedChannel compressor = stream.getProperty(propertyKey);
if (compressor == null) {
if (!endOfStream) {
ByteString encoding = headers.get(CONTENT_ENCODING);
if (encoding == null) {
encoding = IDENTITY;
}
try {
compressor = newContentCompressor(encoding);
if (compressor != null) {
stream.setProperty(propertyKey, compressor);
ByteString targetContentEncoding = getTargetContentEncoding(encoding);
if (IDENTITY.equals(targetContentEncoding)) {
headers.remove(CONTENT_ENCODING);
} else {
headers.set(CONTENT_ENCODING, targetContentEncoding);
}
}
} catch (Throwable ignored) {
// Ignore
}
}
} else if (endOfStream) {
cleanup(stream, compressor);
ByteString encoding = headers.get(CONTENT_ENCODING);
if (encoding == null) {
encoding = IDENTITY;
}

final EmbeddedChannel compressor = newContentCompressor(encoding);
if (compressor != null) {
ByteString targetContentEncoding = getTargetContentEncoding(encoding);
if (IDENTITY.equals(targetContentEncoding)) {
headers.remove(CONTENT_ENCODING);
} else {
headers.set(CONTENT_ENCODING, targetContentEncoding);
}

// The content length will be for the decompressed data. Since we will compress the data
// this content-length will not be correct. Instead of queuing messages or delaying sending
// header frames...just remove the content-length header
headers.remove(CONTENT_LENGTH);
}

return compressor;
}

/**
* Called after the super class has written the headers and created any associated stream objects.
* @param compressor The compressor associated with the stream identified by {@code streamId}.
* @param streamId The stream id for which the headers were written.
*/
private void bindCompressorToStream(EmbeddedChannel compressor, int streamId) {
if (compressor != null) {
Http2Stream stream = connection().stream(streamId);
if (stream != null) {
stream.setProperty(propertyKey, compressor);
}
}
}

/**
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -43,8 +43,6 @@
import io.netty.handler.codec.http.HttpHeaderNames;
import io.netty.handler.codec.http.HttpHeaderValues;
import io.netty.handler.codec.http2.Http2Stream.State;
import io.netty.handler.codec.http2.Http2TestUtil.FrameAdapter;
import io.netty.handler.codec.http2.Http2TestUtil.FrameCountDown;
import io.netty.handler.codec.http2.Http2TestUtil.Http2Runnable;
import io.netty.util.AsciiString;
import io.netty.util.CharsetUtil;
Expand Down Expand Up @@ -83,9 +81,7 @@ public class DataCompressionHttp2Test {
private Bootstrap cb;
private Channel serverChannel;
private Channel clientChannel;
private CountDownLatch serverLatch;
private CountDownLatch clientLatch;
private CountDownLatch clientSettingsAckLatch;
private CountDownLatch serverCloseLatch;
private Http2Connection serverConnection;
private Http2Connection clientConnection;
private Http2ConnectionHandler clientHandler;
Expand Down Expand Up @@ -114,14 +110,10 @@ public void teardown() throws InterruptedException {

@Test
public void justHeadersNoData() throws Exception {
bootstrapEnv(1, 1, 0, 1);
bootstrapEnv(0);
final Http2Headers headers = new DefaultHttp2Headers().method(GET).path(PATH)
.set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP);

// Required because the decompressor intercepts the onXXXRead events before
// our {@link Http2TestUtil$FrameAdapter} does.
FrameAdapter.getOrCreateStream(serverConnection, 3, false);
FrameAdapter.getOrCreateStream(clientConnection, 3, false);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
Expand All @@ -138,15 +130,11 @@ public void run() throws Http2Exception {
public void gzipEncodingSingleEmptyMessage() throws Exception {
final String text = "";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes());
bootstrapEnv(1, 1, data.readableBytes(), 1);
bootstrapEnv(data.readableBytes());
try {
final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH)
.set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP);

// Required because the decompressor intercepts the onXXXRead events before
// our {@link Http2TestUtil$FrameAdapter} does.
Http2Stream stream = FrameAdapter.getOrCreateStream(serverConnection, 3, false);
FrameAdapter.getOrCreateStream(clientConnection, 3, false);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
Expand All @@ -156,7 +144,6 @@ public void run() throws Http2Exception {
}
});
awaitServer();
assertEquals(0, serverConnection.local().flowController().unconsumedBytes(stream));
assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name()));
} finally {
data.release();
Expand All @@ -167,15 +154,11 @@ public void run() throws Http2Exception {
public void gzipEncodingSingleMessage() throws Exception {
final String text = "aaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaaabbbbbbbbbbbbbbbbbbbbbbbbbbbbbccccccccccccccccccccccc";
final ByteBuf data = Unpooled.copiedBuffer(text.getBytes());
bootstrapEnv(1, 1, data.readableBytes(), 1);
bootstrapEnv(data.readableBytes());
try {
final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH)
.set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP);

// Required because the decompressor intercepts the onXXXRead events before
// our {@link Http2TestUtil$FrameAdapter} does.
Http2Stream stream = FrameAdapter.getOrCreateStream(serverConnection, 3, false);
FrameAdapter.getOrCreateStream(clientConnection, 3, false);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
Expand All @@ -185,7 +168,6 @@ public void run() throws Http2Exception {
}
});
awaitServer();
assertEquals(0, serverConnection.local().flowController().unconsumedBytes(stream));
assertEquals(text, serverOut.toString(CharsetUtil.UTF_8.name()));
} finally {
data.release();
Expand All @@ -198,16 +180,12 @@ public void gzipEncodingMultipleMessages() throws Exception {
final String text2 = "dddddddddddddddddddeeeeeeeeeeeeeeeeeeeffffffffffffffffffff";
final ByteBuf data1 = Unpooled.copiedBuffer(text1.getBytes());
final ByteBuf data2 = Unpooled.copiedBuffer(text2.getBytes());
bootstrapEnv(1, 1, data1.readableBytes() + data2.readableBytes(), 1);
bootstrapEnv(data1.readableBytes() + data2.readableBytes());
try {
final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH)
.set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.GZIP);

// Required because the decompressor intercepts the onXXXRead events before
// our {@link Http2TestUtil$FrameAdapter} does.
Http2Stream stream = FrameAdapter.getOrCreateStream(serverConnection, 3, false);
FrameAdapter.getOrCreateStream(clientConnection, 3, false);
runInChannel(clientChannel, new Http2Runnable() {
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
clientEncoder.writeHeaders(ctxClient(), 3, headers, 0, false, newPromiseClient());
Expand All @@ -217,9 +195,7 @@ public void run() throws Http2Exception {
}
});
awaitServer();
assertEquals(0, serverConnection.local().flowController().unconsumedBytes(stream));
assertEquals(text1 + text2,
serverOut.toString(CharsetUtil.UTF_8.name()));
assertEquals(text1 + text2, serverOut.toString(CharsetUtil.UTF_8.name()));
} finally {
data1.release();
data2.release();
Expand All @@ -231,16 +207,12 @@ public void deflateEncodingWriteLargeMessage() throws Exception {
final int BUFFER_SIZE = 1 << 12;
final byte[] bytes = new byte[BUFFER_SIZE];
new Random().nextBytes(bytes);
bootstrapEnv(1, 1, BUFFER_SIZE, 1);
bootstrapEnv(BUFFER_SIZE);
final ByteBuf data = Unpooled.wrappedBuffer(bytes);
try {
final Http2Headers headers = new DefaultHttp2Headers().method(POST).path(PATH)
.set(HttpHeaderNames.CONTENT_ENCODING, HttpHeaderValues.DEFLATE);

// Required because the decompressor intercepts the onXXXRead events before
// our {@link Http2TestUtil$FrameAdapter} does.
Http2Stream stream = FrameAdapter.getOrCreateStream(serverConnection, 3, false);
FrameAdapter.getOrCreateStream(clientConnection, 3, false);
runInChannel(clientChannel, new Http2Runnable() {
@Override
public void run() throws Http2Exception {
Expand All @@ -250,20 +222,16 @@ public void run() throws Http2Exception {
}
});
awaitServer();
assertEquals(0, serverConnection.local().flowController().unconsumedBytes(stream));
assertEquals(data.resetReaderIndex().toString(CharsetUtil.UTF_8),
serverOut.toString(CharsetUtil.UTF_8.name()));
} finally {
data.release();
}
}

private void bootstrapEnv(int serverHalfClosedCount, int clientSettingsAckLatchCount,
int serverOutSize, int clientCount) throws Exception {
private void bootstrapEnv(int serverOutSize) throws Exception {
serverOut = new ByteArrayOutputStream(serverOutSize);
serverLatch = new CountDownLatch(serverHalfClosedCount);
clientLatch = new CountDownLatch(clientCount);
clientSettingsAckLatch = new CountDownLatch(clientSettingsAckLatchCount);
serverCloseLatch = new CountDownLatch(1);
sb = new ServerBootstrap();
cb = new Bootstrap();

Expand All @@ -275,12 +243,12 @@ private void bootstrapEnv(int serverHalfClosedCount, int clientSettingsAckLatchC
@Override
public void onStreamActive(Http2Stream stream) {
if (stream.state() == State.HALF_CLOSED_LOCAL || stream.state() == State.HALF_CLOSED_REMOTE) {
serverLatch.countDown();
serverCloseLatch.countDown();
}
}
@Override
public void onStreamHalfClosed(Http2Stream stream) {
serverLatch.countDown();
serverCloseLatch.countDown();
}
});

Expand Down Expand Up @@ -322,15 +290,13 @@ protected void initChannel(Channel ch) throws Exception {
@Override
protected void initChannel(Channel ch) throws Exception {
ChannelPipeline p = ch.pipeline();
FrameCountDown clientFrameCountDown = new FrameCountDown(clientListener,
clientSettingsAckLatch, clientLatch);
clientEncoder = new CompressorHttp2ConnectionEncoder(
new DefaultHttp2ConnectionEncoder(clientConnection, new DefaultHttp2FrameWriter()));
Http2ConnectionDecoder decoder =
new DefaultHttp2ConnectionDecoder(clientConnection, clientEncoder,
new DefaultHttp2FrameReader(),
new DelegatingDecompressorFrameListener(clientConnection,
clientFrameCountDown));
clientListener));
clientHandler = new Http2ConnectionHandler(decoder, clientEncoder);
p.addLast(clientHandler);
}
Expand All @@ -346,8 +312,7 @@ protected void initChannel(Channel ch) throws Exception {
}

private void awaitServer() throws Exception {
assertTrue(clientSettingsAckLatch.await(5, SECONDS));
assertTrue(serverLatch.await(5, SECONDS));
assertTrue(serverCloseLatch.await(5, SECONDS));
serverOut.flush();
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -114,7 +114,7 @@ static class FrameAdapter extends ByteToMessageDecoder {
this.latch = latch;
}

public Http2Stream getOrCreateStream(int streamId, boolean halfClosed) throws Http2Exception {
private Http2Stream getOrCreateStream(int streamId, boolean halfClosed) throws Http2Exception {
return getOrCreateStream(connection, streamId, halfClosed);
}

Expand Down

0 comments on commit d7cdc46

Please sign in to comment.