From b114244898828e9fb254bea1f0bbdd24850b2f3f Mon Sep 17 00:00:00 2001 From: Yukiteru Li Date: Fri, 26 Jan 2024 05:52:07 +0800 Subject: [PATCH] feat(http1): support configurable `max_headers(num)` to client and server (#3523) --- Cargo.toml | 5 +- src/client/conn/http1.rs | 23 +++++ src/proto/h1/conn.rs | 7 ++ src/proto/h1/io.rs | 2 + src/proto/h1/mod.rs | 1 + src/proto/h1/role.rs | 181 +++++++++++++++++++++++++++++++++++++-- src/server/conn/http1.rs | 23 +++++ 7 files changed, 235 insertions(+), 7 deletions(-) diff --git a/Cargo.toml b/Cargo.toml index 2ac92f3133..43909950f7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -36,6 +36,7 @@ httpdate = { version = "1.0", optional = true } itoa = { version = "1", optional = true } libc = { version = "0.2", optional = true } pin-project-lite = { version = "0.2.4", optional = true } +smallvec = { version = "1.12", features = ["const_generics", "const_new"], optional = true } tracing = { version = "0.1", default-features = false, features = ["std"], optional = true } want = { version = "0.3", optional = true } @@ -80,8 +81,8 @@ http1 = ["dep:futures-channel", "dep:futures-util", "dep:httparse", "dep:itoa"] http2 = ["dep:futures-channel", "dep:futures-util", "dep:h2"] # Client/Server -client = ["dep:want", "dep:pin-project-lite"] -server = ["dep:httpdate", "dep:pin-project-lite"] +client = ["dep:want", "dep:pin-project-lite", "dep:smallvec"] +server = ["dep:httpdate", "dep:pin-project-lite", "dep:smallvec"] # C-API support (currently unstable (no semver)) ffi = ["dep:libc", "dep:http-body-util"] diff --git a/src/client/conn/http1.rs b/src/client/conn/http1.rs index b85f5d4b9e..569c470c10 100644 --- a/src/client/conn/http1.rs +++ b/src/client/conn/http1.rs @@ -115,6 +115,7 @@ pub struct Builder { h1_writev: Option<bool>, h1_title_case_headers: bool, h1_preserve_header_case: bool, + h1_max_headers: Option<usize>, #[cfg(feature = "ffi")] h1_preserve_header_order: bool, h1_read_buf_exact_size: Option<usize>, @@ -309,6 +310,7 @@ impl Builder { 
h1_parser_config: Default::default(), h1_title_case_headers: false, h1_preserve_header_case: false, + h1_max_headers: None, #[cfg(feature = "ffi")] h1_preserve_header_order: false, h1_max_buf_size: None, @@ -439,6 +441,24 @@ impl Builder { self } + /// Set the maximum number of headers. + /// + /// When a response is received, the parser will reserve a buffer to store headers for optimal + /// performance. + /// + /// If the client receives more headers than the buffer size, the error "message header too large" + /// is returned. + /// + /// Note that headers are allocated on the stack by default, which has higher performance. After + /// setting this value, headers will be allocated in heap memory, that is, heap memory + /// allocation will occur for each response, and there will be a performance drop of about 5%. + /// + /// Default is 100. + pub fn max_headers(&mut self, val: usize) -> &mut Self { + self.h1_max_headers = Some(val); + self + } + /// Set whether to support preserving original header order. 
/// /// Currently, this will record the order in which headers are received, and store this @@ -519,6 +539,9 @@ impl Builder { if opts.h1_preserve_header_case { conn.set_preserve_header_case(); } + if let Some(max_headers) = opts.h1_max_headers { + conn.set_http1_max_headers(max_headers); + } #[cfg(feature = "ffi")] if opts.h1_preserve_header_order { conn.set_preserve_header_order(); diff --git a/src/proto/h1/conn.rs b/src/proto/h1/conn.rs index 80fcaa6021..f880f97dcb 100644 --- a/src/proto/h1/conn.rs +++ b/src/proto/h1/conn.rs @@ -54,6 +54,7 @@ where keep_alive: KA::Busy, method: None, h1_parser_config: ParserConfig::default(), + h1_max_headers: None, #[cfg(feature = "server")] h1_header_read_timeout: None, #[cfg(feature = "server")] @@ -132,6 +133,10 @@ where self.state.h09_responses = true; } + pub(crate) fn set_http1_max_headers(&mut self, val: usize) { + self.state.h1_max_headers = Some(val); + } + #[cfg(feature = "server")] pub(crate) fn set_http1_header_read_timeout(&mut self, val: Duration) { self.state.h1_header_read_timeout = Some(val); @@ -207,6 +212,7 @@ where cached_headers: &mut self.state.cached_headers, req_method: &mut self.state.method, h1_parser_config: self.state.h1_parser_config.clone(), + h1_max_headers: self.state.h1_max_headers, #[cfg(feature = "server")] h1_header_read_timeout: self.state.h1_header_read_timeout, #[cfg(feature = "server")] @@ -847,6 +853,7 @@ struct State { /// a body or not. 
method: Option<Method>, h1_parser_config: ParserConfig, + h1_max_headers: Option<usize>, #[cfg(feature = "server")] h1_header_read_timeout: Option<Duration>, #[cfg(feature = "server")] diff --git a/src/proto/h1/io.rs b/src/proto/h1/io.rs index 9f23ac16ff..5d009c1593 100644 --- a/src/proto/h1/io.rs +++ b/src/proto/h1/io.rs @@ -184,6 +184,7 @@ where cached_headers: parse_ctx.cached_headers, req_method: parse_ctx.req_method, h1_parser_config: parse_ctx.h1_parser_config.clone(), + h1_max_headers: parse_ctx.h1_max_headers, #[cfg(feature = "server")] h1_header_read_timeout: parse_ctx.h1_header_read_timeout, #[cfg(feature = "server")] @@ -725,6 +726,7 @@ mod tests { cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, diff --git a/src/proto/h1/mod.rs b/src/proto/h1/mod.rs index 86561c3764..5b9872cddb 100644 --- a/src/proto/h1/mod.rs +++ b/src/proto/h1/mod.rs @@ -78,6 +78,7 @@ pub(crate) struct ParseContext<'a> { cached_headers: &'a mut Option<HeaderMap>, req_method: &'a mut Option<Method>, h1_parser_config: ParserConfig, + h1_max_headers: Option<usize>, #[cfg(feature = "server")] h1_header_read_timeout: Option<Duration>, #[cfg(feature = "server")] diff --git a/src/proto/h1/role.rs b/src/proto/h1/role.rs index ede6fabc8f..62f65c76a2 100644 --- a/src/proto/h1/role.rs +++ b/src/proto/h1/role.rs @@ -13,6 +13,7 @@ use http::header::Entry; use http::header::ValueIter; use http::header::{self, HeaderMap, HeaderName, HeaderValue}; use http::{Method, StatusCode, Version}; +use smallvec::{smallvec, smallvec_inline, SmallVec}; use crate::body::DecodedLength; #[cfg(feature = "server")] @@ -29,7 +30,7 @@ use crate::proto::h1::{ use crate::proto::RequestHead; use crate::proto::{BodyLength, MessageHead, RequestLine}; -const MAX_HEADERS: usize = 100; +const DEFAULT_MAX_HEADERS: usize = 100; const AVERAGE_HEADER_SIZE: usize = 30; // totally scientific #[cfg(feature = "server")] 
const MAX_URI_LEN: usize = (u16::MAX - 1) as usize; @@ -139,9 +140,17 @@ impl Http1Transaction for Server { // but we *never* read any of it until after httparse has assigned // values into it. By not zeroing out the stack memory, this saves // a good ~5% on pipeline benchmarks. - let mut headers_indices = [MaybeUninit::<HeaderIndices>::uninit(); MAX_HEADERS]; + let mut headers_indices: SmallVec<[MaybeUninit<HeaderIndices>; DEFAULT_MAX_HEADERS]> = + match ctx.h1_max_headers { + Some(cap) => smallvec![MaybeUninit::uninit(); cap], + None => smallvec_inline![MaybeUninit::uninit(); DEFAULT_MAX_HEADERS], + }; { - let mut headers = [MaybeUninit::<httparse::Header<'_>>::uninit(); MAX_HEADERS]; + let mut headers: SmallVec<[MaybeUninit<httparse::Header<'_>>; DEFAULT_MAX_HEADERS]> = + match ctx.h1_max_headers { + Some(cap) => smallvec![MaybeUninit::uninit(); cap], + None => smallvec_inline![MaybeUninit::uninit(); DEFAULT_MAX_HEADERS], + }; trace!(bytes = buf.len(), "Request.parse"); let mut req = httparse::Request::new(&mut []); let bytes = buf.as_ref(); @@ -966,9 +975,18 @@ impl Http1Transaction for Client { // Loop to skip information status code headers (100 Continue, etc). 
loop { - let mut headers_indices = [MaybeUninit::<HeaderIndices>::uninit(); MAX_HEADERS]; + let mut headers_indices: SmallVec<[MaybeUninit<HeaderIndices>; DEFAULT_MAX_HEADERS]> = + match ctx.h1_max_headers { + Some(cap) => smallvec![MaybeUninit::uninit(); cap], + None => smallvec_inline![MaybeUninit::uninit(); DEFAULT_MAX_HEADERS], + }; let (len, status, reason, version, headers_len) = { - let mut headers = [MaybeUninit::<httparse::Header<'_>>::uninit(); MAX_HEADERS]; + let mut headers: SmallVec< + [MaybeUninit<httparse::Header<'_>>; DEFAULT_MAX_HEADERS], + > = match ctx.h1_max_headers { + Some(cap) => smallvec![MaybeUninit::uninit(); cap], + None => smallvec_inline![MaybeUninit::uninit(); DEFAULT_MAX_HEADERS], + }; trace!(bytes = buf.len(), "Response.parse"); let mut res = httparse::Response::new(&mut []); let bytes = buf.as_ref(); @@ -1610,6 +1628,7 @@ mod tests { cached_headers: &mut None, req_method: &mut method, h1_parser_config: Default::default(), + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, @@ -1641,6 +1660,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, @@ -1667,6 +1687,7 @@ mod tests { cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, @@ -1691,6 +1712,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, @@ -1717,6 +1739,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), + 
h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, @@ -1747,6 +1770,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), h1_parser_config, + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, @@ -1774,6 +1798,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(crate::Method::GET), h1_parser_config: Default::default(), + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, @@ -1796,6 +1821,7 @@ mod tests { cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, @@ -1839,6 +1865,7 @@ mod tests { cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, @@ -1863,6 +1890,7 @@ mod tests { cached_headers: &mut None, req_method: &mut None, h1_parser_config: Default::default(), + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, @@ -2096,6 +2124,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, @@ -2120,6 +2149,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(m), h1_parser_config: Default::default(), + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, @@ -2144,6 
+2174,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, @@ -2663,6 +2694,7 @@ mod tests { cached_headers: &mut None, req_method: &mut Some(Method::GET), h1_parser_config: Default::default(), + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, @@ -2681,6 +2713,143 @@ mod tests { assert_eq!(parsed.head.headers["server"], "hello\tworld"); } + #[test] + fn parse_too_large_headers() { + fn gen_req_with_headers(num: usize) -> String { + let mut req = String::from("GET / HTTP/1.1\r\n"); + for i in 0..num { + req.push_str(&format!("key{i}: val{i}\r\n")); + } + req.push_str("\r\n"); + req + } + fn gen_resp_with_headers(num: usize) -> String { + let mut req = String::from("HTTP/1.1 200 OK\r\n"); + for i in 0..num { + req.push_str(&format!("key{i}: val{i}\r\n")); + } + req.push_str("\r\n"); + req + } + fn parse(max_headers: Option<usize>, gen_size: usize, should_success: bool) { + { + // server side + let mut bytes = BytesMut::from(gen_req_with_headers(gen_size).as_str()); + let result = Server::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut None, + h1_parser_config: Default::default(), + h1_max_headers: max_headers, + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, + timer: Time::Empty, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + }, + ); + if should_success { + result.expect("parse ok").expect("parse complete"); + } else { + result.expect_err("parse should err"); + } + } + { + // client side + let mut bytes = BytesMut::from(gen_resp_with_headers(gen_size).as_str()); 
+ let result = Client::parse( + &mut bytes, + ParseContext { + cached_headers: &mut None, + req_method: &mut None, + h1_parser_config: Default::default(), + h1_max_headers: max_headers, + h1_header_read_timeout: None, + h1_header_read_timeout_fut: &mut None, + h1_header_read_timeout_running: &mut false, + timer: Time::Empty, + preserve_header_case: false, + #[cfg(feature = "ffi")] + preserve_header_order: false, + h09_responses: false, + #[cfg(feature = "ffi")] + on_informational: &mut None, + }, + ); + if should_success { + result.expect("parse ok").expect("parse complete"); + } else { + result.expect_err("parse should err"); + } + } + } + + // check generator + assert_eq!( + gen_req_with_headers(0), + String::from("GET / HTTP/1.1\r\n\r\n") + ); + assert_eq!( + gen_req_with_headers(1), + String::from("GET / HTTP/1.1\r\nkey0: val0\r\n\r\n") + ); + assert_eq!( + gen_req_with_headers(2), + String::from("GET / HTTP/1.1\r\nkey0: val0\r\nkey1: val1\r\n\r\n") + ); + assert_eq!( + gen_req_with_headers(3), + String::from("GET / HTTP/1.1\r\nkey0: val0\r\nkey1: val1\r\nkey2: val2\r\n\r\n") + ); + + // default max_headers is 100, so + // + // - less than or equal to 100, accepted + // + parse(None, 0, true); + parse(None, 1, true); + parse(None, 50, true); + parse(None, 99, true); + parse(None, 100, true); + // + // - more than 100, rejected + // + parse(None, 101, false); + parse(None, 102, false); + parse(None, 200, false); + + // max_headers is 0, parser will reject any headers + // + // - without header, accepted + // + parse(Some(0), 0, true); + // + // - with header(s), rejected + // + parse(Some(0), 1, false); + parse(Some(0), 100, false); + + // max_headers is 200 + // + // - less than or equal to 200, accepted + // + parse(Some(200), 0, true); + parse(Some(200), 1, true); + parse(Some(200), 100, true); + parse(Some(200), 200, true); + // + // - more than 200, rejected + // + parse(Some(200), 201, false); + parse(Some(200), 210, false); + } + #[test] fn 
test_write_headers_orig_case_empty_value() { let mut headers = HeaderMap::new(); @@ -2751,6 +2920,7 @@ mod tests { cached_headers: &mut headers, req_method: &mut None, h1_parser_config: Default::default(), + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, @@ -2795,6 +2965,7 @@ mod tests { cached_headers: &mut headers, req_method: &mut None, h1_parser_config: Default::default(), + h1_max_headers: None, h1_header_read_timeout: None, h1_header_read_timeout_fut: &mut None, h1_header_read_timeout_running: &mut false, diff --git a/src/server/conn/http1.rs b/src/server/conn/http1.rs index e3eff1a02d..412a76154c 100644 --- a/src/server/conn/http1.rs +++ b/src/server/conn/http1.rs @@ -75,6 +75,7 @@ pub struct Builder { h1_keep_alive: bool, h1_title_case_headers: bool, h1_preserve_header_case: bool, + h1_max_headers: Option<usize>, h1_header_read_timeout: Dur, h1_writev: Option<bool>, max_buf_size: Option<usize>, @@ -242,6 +243,7 @@ impl Builder { h1_keep_alive: true, h1_title_case_headers: false, h1_preserve_header_case: false, + h1_max_headers: None, h1_header_read_timeout: Dur::Default(Some(Duration::from_secs(30))), h1_writev: None, max_buf_size: None, @@ -294,6 +296,24 @@ impl Builder { self } + /// Set the maximum number of headers. + /// + /// When a request is received, the parser will reserve a buffer to store headers for optimal + /// performance. + /// + /// If the server receives more headers than the buffer size, it responds to the client with + /// "431 Request Header Fields Too Large". + /// + /// Note that headers are allocated on the stack by default, which has higher performance. After + /// setting this value, headers will be allocated in heap memory, that is, heap memory + /// allocation will occur for each request, and there will be a performance drop of about 5%. + /// + /// Default is 100. 
+ pub fn max_headers(&mut self, val: usize) -> &mut Self { + self.h1_max_headers = Some(val); + self + } + /// Set a timeout for reading client request headers. If a client does not /// transmit the entire header within this time, the connection is closed. /// @@ -412,6 +432,9 @@ impl Builder { if self.h1_preserve_header_case { conn.set_preserve_header_case(); } + if let Some(max_headers) = self.h1_max_headers { + conn.set_http1_max_headers(max_headers); + } if let Some(dur) = self .timer .check(self.h1_header_read_timeout, "header_read_timeout")