diff --git a/src/client/client.rs b/src/client/client.rs index bf4db79fde..8163f54a23 100644 --- a/src/client/client.rs +++ b/src/client/client.rs @@ -16,6 +16,8 @@ use crate::common::{ exec::BoxSendFuture, lazy as hyper_lazy, sync_wrapper::SyncWrapper, task, Future, Lazy, Pin, Poll, }; +#[cfg(feature = "http2")] +use crate::ext::Protocol; use crate::rt::Executor; use super::conn; @@ -278,7 +280,13 @@ where origin_form(req.uri_mut()); } } else if req.method() == Method::CONNECT { + #[cfg(not(feature = "http2"))] authority_form(req.uri_mut()); + + #[cfg(feature = "http2")] + if req.extensions().get::().is_none() { + authority_form(req.uri_mut()); + } } let mut res = match pooled.send_request_retryable(req).await { diff --git a/tests/integration.rs b/tests/integration.rs index 2deee443f8..9e094cc713 100644 --- a/tests/integration.rs +++ b/tests/integration.rs @@ -305,6 +305,48 @@ t! { ; } +t! { + h2_connect_authority_form, + client: + request: + method: "CONNECT", + // http2 should strip scheme and path from URI (authority-form) + uri: "/connect_normal", + ; + response: + ; + server: + request: + method: "CONNECT", + // path should be stripped + uri: "", + ; + response: + ; +} + +t! { + h2_only; + h2_extended_connect_full_uri, + client: + request: + method: "CONNECT", + // http2 should not strip scheme and path from URI for extended CONNECT requests + uri: "/connect_extended", + protocol: "the-bread-protocol", + ; + response: + ; + server: + request: + method: "CONNECT", + // path should not be stripped + uri: "/connect_extended", + ; + response: + ; +} + t! { get_2, client: diff --git a/tests/support/mod.rs b/tests/support/mod.rs index 6b3c8f4472..213f3e1562 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -13,14 +13,16 @@ use hyper::{Body, Client, Request, Response, Server, Version}; pub use futures_util::{ future, FutureExt as _, StreamExt as _, TryFutureExt as _, TryStreamExt as _, }; -pub use hyper::{HeaderMap, StatusCode}; +pub use hyper::{ext::Protocol, http::Extensions, HeaderMap, StatusCode}; pub use std::net::SocketAddr; #[allow(unused_macros)] macro_rules! t { ( + @impl $name:ident, - parallel: $range:expr + parallel: $range:expr, + $(h2_only: $_h2_only:expr)? ) => ( #[test] fn $name() { @@ -75,6 +77,7 @@ macro_rules! t { } ); ( + @impl $name:ident, client: $( request: $( @@ -91,7 +94,8 @@ macro_rules! t { response: $( $s_res_prop:ident: $s_res_val:tt, )*; - )* + )*, + h2_only: $h2_only:expr ) => ( #[test] fn $name() { @@ -116,15 +120,17 @@ macro_rules! t { } ),)*]; - __run_test(__TestConfig { - client_version: 1, - client_msgs: c.clone(), - server_version: 1, - server_msgs: s.clone(), - parallel: false, - connections: 1, - proxy: false, - }); + if !$h2_only { + __run_test(__TestConfig { + client_version: 1, + client_msgs: c.clone(), + server_version: 1, + server_msgs: s.clone(), + parallel: false, + connections: 1, + proxy: false, + }); + } __run_test(__TestConfig { client_version: 2, @@ -136,15 +142,17 @@ macro_rules! t { proxy: false, }); - __run_test(__TestConfig { - client_version: 1, - client_msgs: c.clone(), - server_version: 1, - server_msgs: s.clone(), - parallel: false, - connections: 1, - proxy: true, - }); + if !$h2_only { + __run_test(__TestConfig { + client_version: 1, + client_msgs: c.clone(), + server_version: 1, + server_msgs: s.clone(), + parallel: false, + connections: 1, + proxy: true, + }); + } __run_test(__TestConfig { client_version: 2, @@ -157,6 +165,12 @@ macro_rules! t { }); } ); + (h2_only; $($t:tt)*) => { + t!(@impl $($t)*, h2_only: true); + }; + ($($t:tt)*) => { + t!(@impl $($t)*, h2_only: false); + }; } macro_rules! __internal_map_prop { @@ -245,6 +259,7 @@ pub struct __CReq { pub uri: &'static str, pub headers: HeaderMap, pub body: Vec, + pub protocol: Option<&'static str>, } impl Default for __CReq { @@ -254,6 +269,7 @@ impl Default for __CReq { uri: "/", headers: HeaderMap::new(), body: Vec::new(), + protocol: None, } } } @@ -371,6 +387,7 @@ async fn async_test(cfg: __TestConfig) { let server = hyper::Server::bind(&SocketAddr::from(([127, 0, 0, 1], 0))) .http2_only(cfg.server_version == 2) + .http2_enable_connect_protocol() .serve(new_service); let mut addr = server.local_addr(); @@ -398,6 +415,9 @@ async fn async_test(cfg: __TestConfig) { //.headers(creq.headers) .body(creq.body.into()) .expect("Request::build"); + if let Some(protocol) = creq.protocol { + req.extensions_mut().insert(Protocol::from_static(protocol)); + } *req.headers_mut() = creq.headers; let cstatus = cres.status; let cheaders = cres.headers; @@ -458,18 +478,20 @@ fn naive_proxy(cfg: ProxyConfig) -> (SocketAddr, impl Future) { let max_connections = cfg.connections; let counter = AtomicUsize::new(0); - let srv = Server::bind(&([127, 0, 0, 1], 0).into()).serve(make_service_fn(move |_| { - let prev = counter.fetch_add(1, Ordering::Relaxed); - assert!(max_connections > prev, "proxy max connections"); - let client = client.clone(); - future::ok::<_, hyper::Error>(service_fn(move |mut req| { - let uri = format!("http://{}{}", dst_addr, req.uri().path()) - .parse() - .expect("proxy new uri parse"); - *req.uri_mut() = uri; - client.request(req) - })) - })); + let srv = Server::bind(&([127, 0, 0, 1], 0).into()) + .http2_enable_connect_protocol() + .serve(make_service_fn(move |_| { + let prev = counter.fetch_add(1, Ordering::Relaxed); + assert!(max_connections > prev, "proxy max connections"); + let client = client.clone(); + future::ok::<_, hyper::Error>(service_fn(move |mut req| { + let uri = format!("http://{}{}", dst_addr, req.uri().path()) + .parse() + .expect("proxy new uri parse"); + *req.uri_mut() = uri; + client.request(req) + })) + })); let proxy_addr = srv.local_addr(); (proxy_addr, srv.map(|res| res.expect("proxy error"))) }