From 11159bb913bac8004609855195febee28819cdd2 Mon Sep 17 00:00:00 2001 From: Sean McArthur Date: Tue, 10 Jun 2025 15:27:39 -0400 Subject: [PATCH] feat(pool): add a Negotiate pooling service --- Cargo.toml | 3 +- src/client/pool/mod.rs | 1 + src/client/pool/negotiate.rs | 574 +++++++++++++++++++++++++++++++++++ 3 files changed, 577 insertions(+), 1 deletion(-) create mode 100644 src/client/pool/negotiate.rs diff --git a/Cargo.toml b/Cargo.toml index 130e48f..f793258 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -33,6 +33,7 @@ pin-project-lite = "0.2.4" socket2 = { version = ">=0.5.9, <0.7", optional = true, features = ["all"] } tracing = { version = "0.1", default-features = false, features = ["std"], optional = true } tokio = { version = "1", optional = true, default-features = false } +tower-layer = { version = "0.3", optional = true } tower-service = { version = "0.3", optional = true } [dev-dependencies] @@ -76,7 +77,7 @@ full = [ client = ["hyper/client", "tokio/net", "dep:tracing", "dep:futures-channel", "dep:tower-service"] client-legacy = ["client", "dep:socket2", "tokio/sync", "dep:libc", "dep:futures-util"] -client-pool = ["dep:futures-util"] +client-pool = ["dep:futures-util", "dep:tower-layer"] client-proxy = ["client", "dep:base64", "dep:ipnet", "dep:percent-encoding"] client-proxy-system = ["dep:system-configuration", "dep:windows-registry"] diff --git a/src/client/pool/mod.rs b/src/client/pool/mod.rs index a17acd3..8b2c3bf 100644 --- a/src/client/pool/mod.rs +++ b/src/client/pool/mod.rs @@ -1,4 +1,5 @@ //! Composable pool services pub mod cache; +pub mod negotiate; pub mod singleton; diff --git a/src/client/pool/negotiate.rs b/src/client/pool/negotiate.rs new file mode 100644 index 0000000..8e28318 --- /dev/null +++ b/src/client/pool/negotiate.rs @@ -0,0 +1,574 @@ +//! Negotiate a pool of services +//! +//! The negotiate pool allows for a service that can decide between two service +//! types based on an intermediate return value. It differs from typical +//! routing since it doesn't depend on the request, but the response. +//! +//! The original use case is support ALPN upgrades to HTTP/2, with a fallback +//! to HTTP/1. +//! +//! # Example +//! +//! ```rust,ignore +//! # async fn run() -> Result<(), Box> { +//! # struct Conn; +//! # impl Conn { fn negotiated_protocol(&self) -> &[u8] { b"h2" } } +//! # let some_tls_connector = tower::service::service_fn(|_| async move { +//! # Ok::<_, std::convert::Infallible>(Conn) +//! # }); +//! # let http1_layer = tower::layer::layer_fn(|s| s); +//! # let http2_layer = tower::layer::layer_fn(|s| s); +//! let mut pool = hyper_util::client::pool::negotiate::builder() +//! .connect(some_tls_connector) +//! .inspect(|c| c.negotiated_protocol() == b"h2") +//! .fallback(http1_layer) +//! .upgrade(http2_layer) +//! .build(); +//! +//! // connect +//! let mut svc = pool.call(http::Uri::from_static("https://hyper.rs")).await?; +//! svc.ready().await; +//! +//! // http1 or http2 is now set up +//! # let some_http_req = http::Request::new(()); +//! let resp = svc.call(some_http_req).await?; +//! # Ok(()) +//! # } +//! ``` + +pub use self::internal::builder; + +#[cfg(docsrs)] +pub use self::internal::Builder; +#[cfg(docsrs)] +pub use self::internal::Negotiate; +#[cfg(docsrs)] +pub use self::internal::Negotiated; + +mod internal { + use std::future::Future; + use std::pin::Pin; + use std::sync::{Arc, Mutex}; + use std::task::{self, Poll}; + + use futures_core::ready; + use pin_project_lite::pin_project; + use tower_layer::Layer; + use tower_service::Service; + + type BoxError = Box; + + /// A negotiating pool over an inner make service. + /// + /// Created with [`builder()`]. + /// + /// # Unnameable + /// + /// This type is normally unnameable, forbidding naming of the type within + /// code. The type is exposed in the documentation to show which methods + /// can be publicly called. + pub struct Negotiate { + left: L, + right: R, + } + + /// A negotiated service returned by [`Negotiate`]. + /// + /// # Unnameable + /// + /// This type is normally unnameable, forbidding naming of the type within + /// code. The type is exposed in the documentation to show which methods + /// can be publicly called. + #[derive(Clone, Debug)] + pub enum Negotiated { + #[doc(hidden)] + Fallback(L), + #[doc(hidden)] + Upgraded(R), + } + + pin_project! { + pub struct Negotiating + where + L: Service, + R: Service<()>, + { + #[pin] + state: State, + left: L, + right: R, + } + } + + pin_project! { + #[project = StateProj] + enum State { + Eager { + #[pin] + future: FR, + dst: Option, + }, + Fallback { + #[pin] + future: FL, + }, + Upgrade { + #[pin] + future: FR, + } + } + } + + pin_project! { + #[project = NegotiatedProj] + pub enum NegotiatedFuture { + Fallback { + #[pin] + future: L + }, + Upgraded { + #[pin] + future: R + }, + } + } + + /// A builder to configure a `Negotiate`. + /// + /// # Unnameable + /// + /// This type is normally unnameable, forbidding naming of the type within + /// code. The type is exposed in the documentation to show which methods + /// can be publicly called. + #[derive(Debug)] + pub struct Builder { + connect: C, + inspect: I, + fallback: L, + upgrade: R, + } + + #[derive(Debug)] + pub struct WantsConnect; + #[derive(Debug)] + pub struct WantsInspect; + #[derive(Debug)] + pub struct WantsFallback; + #[derive(Debug)] + pub struct WantsUpgrade; + + /// Start a builder to construct a `Negotiate` pool. + pub fn builder() -> Builder { + Builder { + connect: WantsConnect, + inspect: WantsInspect, + fallback: WantsFallback, + upgrade: WantsUpgrade, + } + } + + impl Builder { + /// Provide the initial connector. + pub fn connect(self, connect: CC) -> Builder { + Builder { + connect, + inspect: self.inspect, + fallback: self.fallback, + upgrade: self.upgrade, + } + } + + /// Provide the inspector that determines the result of the negotiation. + pub fn inspect(self, inspect: II) -> Builder { + Builder { + connect: self.connect, + inspect, + fallback: self.fallback, + upgrade: self.upgrade, + } + } + + /// Provide the layer to fallback to if negotiation fails. + pub fn fallback(self, fallback: LL) -> Builder { + Builder { + connect: self.connect, + inspect: self.inspect, + fallback, + upgrade: self.upgrade, + } + } + + /// Provide the layer to upgrade to if negotiation succeeds. + pub fn upgrade(self, upgrade: RR) -> Builder { + Builder { + connect: self.connect, + inspect: self.inspect, + fallback: self.fallback, + upgrade, + } + } + + /// Build the `Negotiate` pool. + pub fn build(self) -> Negotiate + where + C: Service, + C::Error: Into, + L: Layer>, + L::Service: Service + Clone, + >::Error: Into, + R: Layer>, + R::Service: Service<()> + Clone, + >::Error: Into, + I: Fn(&C::Response) -> bool + Clone, + { + let Builder { + connect, + inspect, + fallback, + upgrade, + } = self; + + let slot = Arc::new(Mutex::new(None)); + let wrapped = Inspector { + svc: connect, + inspect, + slot: slot.clone(), + }; + let left = fallback.layer(wrapped); + + let right = upgrade.layer(Inspected { slot }); + + Negotiate { left, right } + } + } + + impl Service for Negotiate + where + L: Service + Clone, + L::Error: Into, + R: Service<()> + Clone, + R::Error: Into, + { + type Response = Negotiated; + type Error = BoxError; + type Future = Negotiating; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { + self.left.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, dst: Target) -> Self::Future { + let left = self.left.clone(); + Negotiating { + state: State::Eager { + future: self.right.call(()), + dst: Some(dst), + }, + // place clone, take original that we already polled-ready. + left: std::mem::replace(&mut self.left, left), + right: self.right.clone(), + } + } + } + + impl Future for Negotiating + where + L: Service, + L::Error: Into, + R: Service<()>, + R::Error: Into, + { + type Output = Result, BoxError>; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + // States: + // - `Eager`: try the "right" path first; on `UseOther` sentinel, fall back to left. + // - `Fallback`: try the left path; on `UseOther` sentinel, upgrade back to right. + // - `Upgrade`: retry the right path after a fallback. + // If all fail, give up. + let mut me = self.project(); + loop { + match me.state.as_mut().project() { + StateProj::Eager { future, dst } => match ready!(future.poll(cx)) { + Ok(out) => return Poll::Ready(Ok(Negotiated::Upgraded(out))), + Err(err) => { + let err = err.into(); + if err.is::() { + let dst = dst.take().unwrap(); + let f = me.left.call(dst); + me.state.set(State::Fallback { future: f }); + continue; + } else { + return Poll::Ready(Err(err)); + } + } + }, + StateProj::Fallback { future } => match ready!(future.poll(cx)) { + Ok(out) => return Poll::Ready(Ok(Negotiated::Fallback(out))), + Err(err) => { + let err = err.into(); + if err.is::() { + let f = me.right.call(()); + me.state.set(State::Upgrade { future: f }); + continue; + } else { + return Poll::Ready(Err(err)); + } + } + }, + StateProj::Upgrade { future } => match ready!(future.poll(cx)) { + Ok(out) => return Poll::Ready(Ok(Negotiated::Upgraded(out))), + Err(err) => return Poll::Ready(Err(err.into())), + }, + } + } + } + } + + #[cfg(test)] + impl Negotiated { + // Could be useful? + pub(super) fn is_fallback(&self) -> bool { + matches!(self, Negotiated::Fallback(_)) + } + + pub(super) fn is_upgraded(&self) -> bool { + matches!(self, Negotiated::Upgraded(_)) + } + } + + impl Service for Negotiated + where + L: Service, + R: Service, + { + type Response = Res; + type Error = E; + type Future = NegotiatedFuture; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { + match self { + Negotiated::Fallback(ref mut s) => s.poll_ready(cx), + Negotiated::Upgraded(ref mut s) => s.poll_ready(cx), + } + } + + fn call(&mut self, req: Req) -> Self::Future { + match self { + Negotiated::Fallback(ref mut s) => NegotiatedFuture::Fallback { + future: s.call(req), + }, + Negotiated::Upgraded(ref mut s) => NegotiatedFuture::Upgraded { + future: s.call(req), + }, + } + } + } + + impl Future for NegotiatedFuture + where + L: Future, + R: Future, + { + type Output = Out; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + match self.project() { + NegotiatedProj::Fallback { future } => future.poll(cx), + NegotiatedProj::Upgraded { future } => future.poll(cx), + } + } + } + + // ===== internal ===== + + pub struct Inspector { + svc: M, + inspect: I, + slot: Arc>>, + } + + pin_project! { + pub struct InspectFuture { + #[pin] + future: F, + inspect: I, + slot: Arc>>, + } + } + + impl Clone for Inspector { + fn clone(&self) -> Self { + Self { + svc: self.svc.clone(), + inspect: self.inspect.clone(), + slot: self.slot.clone(), + } + } + } + + impl Service for Inspector + where + M: Service, + M::Error: Into, + I: Clone + Fn(&S) -> bool, + { + type Response = M::Response; + type Error = BoxError; + type Future = InspectFuture; + + fn poll_ready(&mut self, cx: &mut task::Context<'_>) -> Poll> { + self.svc.poll_ready(cx).map_err(Into::into) + } + + fn call(&mut self, dst: Target) -> Self::Future { + InspectFuture { + future: self.svc.call(dst), + inspect: self.inspect.clone(), + slot: self.slot.clone(), + } + } + } + + impl Future for InspectFuture + where + F: Future>, + E: Into, + I: Fn(&S) -> bool, + { + type Output = Result; + + fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll { + let me = self.project(); + let s = ready!(me.future.poll(cx)).map_err(Into::into)?; + Poll::Ready(if (me.inspect)(&s) { + *me.slot.lock().unwrap() = Some(s); + Err(UseOther.into()) + } else { + Ok(s) + }) + } + } + + pub struct Inspected { + slot: Arc>>, + } + + impl Service for Inspected { + type Response = S; + type Error = BoxError; + type Future = std::future::Ready>; + + fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll> { + if self.slot.lock().unwrap().is_some() { + Poll::Ready(Ok(())) + } else { + Poll::Ready(Err(UseOther.into())) + } + } + + fn call(&mut self, _dst: Target) -> Self::Future { + let s = self + .slot + .lock() + .unwrap() + .take() + .ok_or_else(|| UseOther.into()); + std::future::ready(s) + } + } + + impl Clone for Inspected { + fn clone(&self) -> Inspected { + Inspected { + slot: self.slot.clone(), + } + } + } + + #[derive(Debug)] + struct UseOther; + + impl std::fmt::Display for UseOther { + fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { + f.write_str("sentinel error; using other") + } + } + + impl std::error::Error for UseOther {} +} + +#[cfg(test)] +mod tests { + use futures_util::future; + use tower_service::Service; + use tower_test::assert_request_eq; + + #[tokio::test] + async fn not_negotiated_falls_back_to_left() { + let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>(); + + let mut negotiate = super::builder() + .connect(mock_svc) + .inspect(|_: &&str| false) + .fallback(layer_fn(|s| s)) + .upgrade(layer_fn(|s| s)) + .build(); + + crate::common::future::poll_fn(|cx| negotiate.poll_ready(cx)) + .await + .unwrap(); + + let fut = negotiate.call(()); + let nsvc = future::join(fut, async move { + assert_request_eq!(handle, ()).send_response("one"); + }) + .await + .0 + .expect("call"); + assert!(nsvc.is_fallback()); + } + + #[tokio::test] + async fn negotiated_uses_right() { + let (mock_svc, mut handle) = tower_test::mock::pair::<(), &'static str>(); + + let mut negotiate = super::builder() + .connect(mock_svc) + .inspect(|_: &&str| true) + .fallback(layer_fn(|s| s)) + .upgrade(layer_fn(|s| s)) + .build(); + + crate::common::future::poll_fn(|cx| negotiate.poll_ready(cx)) + .await + .unwrap(); + + let fut = negotiate.call(()); + let nsvc = future::join(fut, async move { + assert_request_eq!(handle, ()).send_response("one"); + }) + .await + .0 + .expect("call"); + + assert!(nsvc.is_upgraded()); + } + + fn layer_fn(f: F) -> LayerFn { + LayerFn(f) + } + + #[derive(Clone)] + struct LayerFn(F); + + impl tower_layer::Layer for LayerFn + where + F: Fn(S) -> Out, + { + type Service = Out; + fn layer(&self, inner: S) -> Self::Service { + (self.0)(inner) + } + } +}