diff --git a/http-server-netty/src/main/java/io/micronaut/http/server/netty/HttpPipelineBuilder.java b/http-server-netty/src/main/java/io/micronaut/http/server/netty/HttpPipelineBuilder.java index 799f8eb7022..84a491b9e33 100644 --- a/http-server-netty/src/main/java/io/micronaut/http/server/netty/HttpPipelineBuilder.java +++ b/http-server-netty/src/main/java/io/micronaut/http/server/netty/HttpPipelineBuilder.java @@ -17,8 +17,8 @@ import io.micronaut.core.annotation.NonNull; import io.micronaut.core.naming.Named; -import io.micronaut.http.netty.AbstractNettyHttpRequest; import io.micronaut.http.context.event.HttpRequestReceivedEvent; +import io.micronaut.http.netty.AbstractNettyHttpRequest; import io.micronaut.http.netty.channel.ChannelPipelineCustomizer; import io.micronaut.http.netty.stream.HttpStreamsServerHandler; import io.micronaut.http.netty.stream.StreamingInboundHttp2ToHttpAdapter; @@ -32,6 +32,7 @@ import io.micronaut.http.ssl.ServerSslConfiguration; import io.netty.channel.Channel; import io.netty.channel.ChannelHandlerContext; +import io.netty.channel.ChannelInitializer; import io.netty.channel.ChannelOutboundHandler; import io.netty.channel.ChannelPipeline; import io.netty.channel.SimpleChannelInboundHandler; @@ -442,7 +443,14 @@ public void upgradeTo(ChannelHandlerContext ctx, FullHttpRequest upgradeRequest) server.getServerConfiguration().getMaxH2cUpgradeRequestSize() ); final CleartextHttp2ServerUpgradeHandler cleartextHttp2ServerUpgradeHandler = - new CleartextHttp2ServerUpgradeHandler(sourceCodec, upgradeHandler, connectionHandler); + new CleartextHttp2ServerUpgradeHandler(sourceCodec, upgradeHandler, new ChannelInitializer() { + @Override + protected void initChannel(@NonNull Channel ch) throws Exception { + ch.pipeline().addLast(connectionHandler); + insertHttp2DownstreamHandlers(); + onRequestPipelineBuilt(); + } + }); pipeline.addLast(cleartextHttp2ServerUpgradeHandler); pipeline.addLast(fallbackHandlerName, new SimpleChannelInboundHandler() { diff --git a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/http2/H2cSpec.groovy b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/http2/H2cSpec.groovy index 92df66cf4dc..6cb6f77ef02 100644 --- a/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/http2/H2cSpec.groovy +++ b/http-server-netty/src/test/groovy/io/micronaut/http/server/netty/http2/H2cSpec.groovy @@ -235,6 +235,72 @@ class H2cSpec extends Specification { content.release() } + def 'prior knowledge'() { + given: + def responseFuture = new CompletableFuture() + + def group = new NioEventLoopGroup(1) + def bootstrap = new Bootstrap() + .remoteAddress(embeddedServer.host, embeddedServer.port) + .group(group) + .channel(NioSocketChannel.class) + .handler(new ChannelInitializer() { + @Override + protected void initChannel(@NonNull SocketChannel ch) throws Exception { + def http2Connection = new DefaultHttp2Connection(false) + def inboundAdapter = new InboundHttp2ToHttpAdapterBuilder(http2Connection) + .maxContentLength(1000000) + .validateHttpHeaders(true) + .propagateSettings(true) + .build() + def connectionHandler = new HttpToHttp2ConnectionHandlerBuilder() + .connection(http2Connection) + .frameListener(new DelegatingDecompressorFrameListener(http2Connection, inboundAdapter)) + .build() + + ch.pipeline() + .addLast(connectionHandler) + .addLast(new ChannelInboundHandlerAdapter() { + @Override + void channelRead(@NonNull ChannelHandlerContext ctx, @NonNull Object msg) throws Exception { + ctx.read() + if (msg instanceof HttpMessage) { + if (msg.headers().getInt(HttpConversionUtil.ExtensionHeaderNames.STREAM_ID.text(), -1) != 3) { + responseFuture.completeExceptionally(new AssertionError("Response must be on stream 3")); + } + responseFuture.complete(ReferenceCountUtil.retain(msg)) + } + super.channelRead(ctx, msg) + } + + @Override + void exceptionCaught(ChannelHandlerContext ctx, Throwable cause) throws Exception { + super.exceptionCaught(ctx, cause) + cause.printStackTrace() + responseFuture.completeExceptionally(cause) + } + }) + + } + }) + + def channel = (SocketChannel) bootstrap.connect().await().channel() + + def request = new DefaultFullHttpRequest(HttpVersion.HTTP_1_1, HttpMethod.GET, '/h2c/test') + request.headers().set(HttpConversionUtil.ExtensionHeaderNames.SCHEME.text(), "http") + channel.writeAndFlush(request) + channel.read() + + expect: + def resp = responseFuture.get(10, TimeUnit.SECONDS) + resp != null + + cleanup: + channel.close() + resp.release() + group.shutdownGracefully() + } + @Controller("/h2c") static class TestController { @Get("/test")