Skip to content

Commit

Permalink
Handle client-disabled server push (#486)
Browse files Browse the repository at this point in the history
  • Loading branch information
eggyal committed Sep 18, 2020
1 parent a193237 commit 2b19acf
Show file tree
Hide file tree
Showing 6 changed files with 70 additions and 3 deletions.
4 changes: 4 additions & 0 deletions src/codec/error.rs
Expand Up @@ -63,6 +63,9 @@ pub enum UserError {

/// Tries to update local SETTINGS while ACK has not been received.
SendSettingsWhilePending,

/// Tries to send push promise to peer who has disabled server push
PeerDisabledServerPush,
}

// ===== impl RecvError =====
Expand Down Expand Up @@ -136,6 +139,7 @@ impl fmt::Display for UserError {
PollResetAfterSendResponse => "poll_reset after send_response is illegal",
SendPingWhilePending => "send_ping before received previous pong",
SendSettingsWhilePending => "sending SETTINGS before received previous ACK",
PeerDisabledServerPush => "sending PUSH_PROMISE to peer who disabled server push",
})
}
}
4 changes: 2 additions & 2 deletions src/frame/settings.rs
Expand Up @@ -99,8 +99,8 @@ impl Settings {
self.max_header_list_size = size;
}

pub fn is_push_enabled(&self) -> bool {
self.enable_push.unwrap_or(1) != 0
pub fn is_push_enabled(&self) -> Option<bool> {
self.enable_push.map(|val| val != 0)
}

pub fn set_enable_push(&mut self, enable: bool) {
Expand Down
2 changes: 1 addition & 1 deletion src/proto/connection.rs
Expand Up @@ -86,7 +86,7 @@ where
.unwrap_or(DEFAULT_INITIAL_WINDOW_SIZE),
initial_max_send_streams: config.initial_max_send_streams,
local_next_stream_id: config.next_stream_id,
local_push_enabled: config.settings.is_push_enabled(),
local_push_enabled: config.settings.is_push_enabled().unwrap_or(true),
local_reset_duration: config.reset_stream_duration,
local_reset_max: config.reset_stream_max,
remote_init_window_sz: DEFAULT_INITIAL_WINDOW_SIZE,
Expand Down
11 changes: 11 additions & 0 deletions src/proto/streams/send.rs
Expand Up @@ -32,6 +32,8 @@ pub(super) struct Send {

/// Prioritization layer
prioritize: Prioritize,

is_push_enabled: bool,
}

/// A value to detect which public API has called `poll_reset`.
Expand All @@ -49,6 +51,7 @@ impl Send {
max_stream_id: StreamId::MAX,
next_stream_id: Ok(config.local_next_stream_id),
prioritize: Prioritize::new(config),
is_push_enabled: true,
}
}

Expand Down Expand Up @@ -95,6 +98,10 @@ impl Send {
stream: &mut store::Ptr,
task: &mut Option<Waker>,
) -> Result<(), UserError> {
if !self.is_push_enabled {
return Err(UserError::PeerDisabledServerPush);
}

tracing::trace!(
"send_push_promise; frame={:?}; init_window={:?}",
frame,
Expand Down Expand Up @@ -496,6 +503,10 @@ impl Send {
}
}

if let Some(val) = settings.is_push_enabled() {
self.is_push_enabled = val
}

Ok(())
}

Expand Down
5 changes: 5 additions & 0 deletions tests/h2-support/src/frames.rs
Expand Up @@ -339,6 +339,11 @@ impl Mock<frame::Settings> {
self.0.set_max_header_list_size(Some(val));
self
}

pub fn disable_push(mut self) -> Self {
self.0.set_enable_push(false);
self
}
}

impl From<Mock<frame::Settings>> for frame::Settings {
Expand Down
47 changes: 47 additions & 0 deletions tests/h2-tests/tests/server.rs
Expand Up @@ -220,6 +220,53 @@ async fn push_request() {
join(client, srv).await;
}

#[tokio::test]
async fn push_request_disabled() {
h2_support::trace_init!();
let (io, mut client) = mock::new();

let client = async move {
client
.assert_server_handshake_with_settings(frames::settings().disable_push())
.await;
client
.send_frame(
frames::headers(1)
.request("GET", "https://example.com/")
.eos(),
)
.await;
client
.recv_frame(frames::headers(1).response(200).eos())
.await;
};

let srv = async move {
let mut srv = server::handshake(io).await.expect("handshake");
let (req, mut stream) = srv.next().await.unwrap().unwrap();

assert_eq!(req.method(), &http::Method::GET);

// attempt to push - expect failure
let req = http::Request::builder()
.method("GET")
.uri("https://http2.akamai.com/style.css")
.body(())
.unwrap();
stream
.push_request(req)
.expect_err("push_request should error");

// send normal response
let rsp = http::Response::builder().status(200).body(()).unwrap();
stream.send_response(rsp, true).unwrap();

assert!(srv.next().await.is_none());
};

join(client, srv).await;
}

#[tokio::test]
async fn push_request_against_concurrency() {
h2_support::trace_init!();
Expand Down

0 comments on commit 2b19acf

Please sign in to comment.