From 2b19acf1323257e1d2a99ad9889e0619593561fd Mon Sep 17 00:00:00 2001 From: eggyal Date: Fri, 18 Sep 2020 01:25:31 +0100 Subject: [PATCH] Handle client-disabled server push (#486) --- src/codec/error.rs | 4 +++ src/frame/settings.rs | 4 +-- src/proto/connection.rs | 2 +- src/proto/streams/send.rs | 11 ++++++++ tests/h2-support/src/frames.rs | 5 ++++ tests/h2-tests/tests/server.rs | 47 ++++++++++++++++++++++++++++++++++ 6 files changed, 70 insertions(+), 3 deletions(-) diff --git a/src/codec/error.rs b/src/codec/error.rs index 2c6b2961d..5d6659223 100644 --- a/src/codec/error.rs +++ b/src/codec/error.rs @@ -63,6 +63,9 @@ pub enum UserError { /// Tries to update local SETTINGS while ACK has not been received. SendSettingsWhilePending, + + /// Tries to send push promise to peer who has disabled server push + PeerDisabledServerPush, } // ===== impl RecvError ===== @@ -136,6 +139,7 @@ impl fmt::Display for UserError { PollResetAfterSendResponse => "poll_reset after send_response is illegal", SendPingWhilePending => "send_ping before received previous pong", SendSettingsWhilePending => "sending SETTINGS before received previous ACK", + PeerDisabledServerPush => "sending PUSH_PROMISE to peer who disabled server push", }) } } diff --git a/src/frame/settings.rs b/src/frame/settings.rs index 06de9cf12..523f20b06 100644 --- a/src/frame/settings.rs +++ b/src/frame/settings.rs @@ -99,8 +99,8 @@ impl Settings { self.max_header_list_size = size; } - pub fn is_push_enabled(&self) -> bool { - self.enable_push.unwrap_or(1) != 0 + pub fn is_push_enabled(&self) -> Option { + self.enable_push.map(|val| val != 0) } pub fn set_enable_push(&mut self, enable: bool) { diff --git a/src/proto/connection.rs b/src/proto/connection.rs index ffa2945c6..1c1c8ce1b 100644 --- a/src/proto/connection.rs +++ b/src/proto/connection.rs @@ -86,7 +86,7 @@ where .unwrap_or(DEFAULT_INITIAL_WINDOW_SIZE), initial_max_send_streams: config.initial_max_send_streams, local_next_stream_id: config.next_stream_id, - local_push_enabled: config.settings.is_push_enabled(), + local_push_enabled: config.settings.is_push_enabled().unwrap_or(true), local_reset_duration: config.reset_stream_duration, local_reset_max: config.reset_stream_max, remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE, diff --git a/src/proto/streams/send.rs b/src/proto/streams/send.rs index 220a8b461..10934de48 100644 --- a/src/proto/streams/send.rs +++ b/src/proto/streams/send.rs @@ -32,6 +32,8 @@ pub(super) struct Send { /// Prioritization layer prioritize: Prioritize, + + is_push_enabled: bool, } /// A value to detect which public API has called `poll_reset`. @@ -49,6 +51,7 @@ impl Send { max_stream_id: StreamId::MAX, next_stream_id: Ok(config.local_next_stream_id), prioritize: Prioritize::new(config), + is_push_enabled: true, } } @@ -95,6 +98,10 @@ impl Send { stream: &mut store::Ptr, task: &mut Option, ) -> Result<(), UserError> { + if !self.is_push_enabled { + return Err(UserError::PeerDisabledServerPush); + } + tracing::trace!( "send_push_promise; frame={:?}; init_window={:?}", frame, @@ -496,6 +503,10 @@ impl Send { } } + if let Some(val) = settings.is_push_enabled() { + self.is_push_enabled = val + } + Ok(()) } diff --git a/tests/h2-support/src/frames.rs b/tests/h2-support/src/frames.rs index b9393b2b5..05fb3202f 100644 --- a/tests/h2-support/src/frames.rs +++ b/tests/h2-support/src/frames.rs @@ -339,6 +339,11 @@ impl Mock { self.0.set_max_header_list_size(Some(val)); self } + + pub fn disable_push(mut self) -> Self { + self.0.set_enable_push(false); + self + } } impl From> for frame::Settings { diff --git a/tests/h2-tests/tests/server.rs b/tests/h2-tests/tests/server.rs index 3a7649135..4be70902b 100644 --- a/tests/h2-tests/tests/server.rs +++ b/tests/h2-tests/tests/server.rs @@ -220,6 +220,53 @@ async fn push_request() { join(client, srv).await; } +#[tokio::test] +async fn push_request_disabled() { + h2_support::trace_init!(); + let (io, mut client) = mock::new(); + + let client = async move { + client + .assert_server_handshake_with_settings(frames::settings().disable_push()) + .await; + client + .send_frame( + frames::headers(1) + .request("GET", "https://example.com/") + .eos(), + ) + .await; + client + .recv_frame(frames::headers(1).response(200).eos()) + .await; + }; + + let srv = async move { + let mut srv = server::handshake(io).await.expect("handshake"); + let (req, mut stream) = srv.next().await.unwrap().unwrap(); + + assert_eq!(req.method(), &http::Method::GET); + + // attempt to push - expect failure + let req = http::Request::builder() + .method("GET") + .uri("https://http2.akamai.com/style.css") + .body(()) + .unwrap(); + stream + .push_request(req) + .expect_err("push_request should error"); + + // send normal response + let rsp = http::Response::builder().status(200).body(()).unwrap(); + stream.send_response(rsp, true).unwrap(); + + assert!(srv.next().await.is_none()); + }; + + join(client, srv).await; +} + #[tokio::test] async fn push_request_against_concurrency() { h2_support::trace_init!();