diff --git a/crates/twirp/src/client.rs b/crates/twirp/src/client.rs index 1aa6831..4ee5f5b 100644 --- a/crates/twirp/src/client.rs +++ b/crates/twirp/src/client.rs @@ -3,6 +3,10 @@ use std::sync::Arc; use std::vec; use async_trait::async_trait; +use http::header::Entry; +use http::header::IntoHeaderName; +use http::HeaderMap; +use http::HeaderValue; use reqwest::header::CONTENT_TYPE; use url::Host; use url::Url; @@ -43,21 +47,21 @@ impl ClientBuilder { } } + /// Set the HTTP client. Without this a default HTTP client is used. + pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self { + self.http_client = Some(http_client); + self + } + /// Add middleware to the client that will be called on each request. /// Middlewares are invoked in the order they are added as part of the /// request cycle. - pub fn with_middleware(self, middleware: M) -> Self + pub fn with_middleware(mut self, middleware: M) -> Self where M: Middleware, { - let mut mw = self.middleware; - mw.push(Box::new(middleware)); - Self { - base_url: self.base_url, - http_client: self.http_client, - handlers: self.handlers, - middleware: mw, - } + self.middleware.push(Box::new(middleware)); + self } /// Add a handler for a service using the default host. @@ -83,9 +87,16 @@ impl ClientBuilder { self } - /// Set the HTTP client. Without this a default HTTP client is used. - pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self { - self.http_client = Some(http_client); + /// Set a default header for use in direct mode. + pub fn with_default_header(mut self, key: K, value: HeaderValue) -> Self + where + K: IntoHeaderName, + { + if let Some(handlers) = &mut self.handlers { + handlers.default_headers.insert(key, value); + } else { + panic!("you must use `ClientBuilder::direct()` to register handler default headers"); + } self } @@ -315,9 +326,15 @@ impl<'a> Next<'a> { } async fn execute_handlers( - req: reqwest::Request, + mut req: reqwest::Request, request_handlers: &RequestHandlers, ) -> Result { + let req_headers = req.headers_mut(); + for (key, value) in &request_handlers.default_headers { + if let Entry::Vacant(entry) = req_headers.entry(key) { + entry.insert(value.clone()); + } + } let url = req.url().clone(); let Some(mut segments) = url.path_segments() else { return Err(crate::bad_route(format!( @@ -344,6 +361,7 @@ async fn execute_handlers( #[derive(Clone, Default)] pub struct RequestHandlers { + default_headers: HeaderMap, /// A map of host/service names to handlers. handlers: HashMap>, } @@ -351,6 +369,7 @@ pub struct RequestHandlers { impl RequestHandlers { pub fn new() -> Self { Self { + default_headers: HeaderMap::new(), handlers: HashMap::new(), } }