Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
45 changes: 32 additions & 13 deletions crates/twirp/src/client.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,10 @@ use std::sync::Arc;
use std::vec;

use async_trait::async_trait;
use http::header::Entry;
use http::header::IntoHeaderName;
use http::HeaderMap;
use http::HeaderValue;
use reqwest::header::CONTENT_TYPE;
use url::Host;
use url::Url;
Expand Down Expand Up @@ -43,21 +47,21 @@ impl ClientBuilder {
}
}

/// Set the HTTP client. Without this a default HTTP client is used.
pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
self.http_client = Some(http_client);
self
}

/// Add middleware to the client that will be called on each request.
/// Middlewares are invoked in the order they are added as part of the
/// request cycle.
pub fn with_middleware<M>(self, middleware: M) -> Self
pub fn with_middleware<M>(mut self, middleware: M) -> Self
where
M: Middleware,
{
let mut mw = self.middleware;
mw.push(Box::new(middleware));
Self {
base_url: self.base_url,
http_client: self.http_client,
handlers: self.handlers,
middleware: mw,
}
self.middleware.push(Box::new(middleware));
self
}

/// Add a handler for a service using the default host.
Expand All @@ -83,9 +87,16 @@ impl ClientBuilder {
self
}

/// Set the HTTP client. Without this a default HTTP client is used.
pub fn with_http_client(mut self, http_client: reqwest::Client) -> Self {
self.http_client = Some(http_client);
/// Set a default header for use in direct mode.
pub fn with_default_header<K>(mut self, key: K, value: HeaderValue) -> Self
where
K: IntoHeaderName,
{
if let Some(handlers) = &mut self.handlers {
handlers.default_headers.insert(key, value);
} else {
panic!("you must use `ClientBuilder::direct()` to register handler default headers");
}
self
}

Expand Down Expand Up @@ -315,9 +326,15 @@ impl<'a> Next<'a> {
}

async fn execute_handlers(
req: reqwest::Request,
mut req: reqwest::Request,
request_handlers: &RequestHandlers,
) -> Result<reqwest::Response> {
let req_headers = req.headers_mut();
for (key, value) in &request_handlers.default_headers {
if let Entry::Vacant(entry) = req_headers.entry(key) {
entry.insert(value.clone());
}
}
let url = req.url().clone();
let Some(mut segments) = url.path_segments() else {
return Err(crate::bad_route(format!(
Expand All @@ -344,13 +361,15 @@ async fn execute_handlers(

#[derive(Clone, Default)]
pub struct RequestHandlers {
default_headers: HeaderMap,
/// A map of host/service names to handlers.
handlers: HashMap<String, Arc<dyn DirectHandler>>,
}

impl RequestHandlers {
pub fn new() -> Self {
Self {
default_headers: HeaderMap::new(),
handlers: HashMap::new(),
}
}
Expand Down