From f07e28e7da351709d9728d680a2ee95a36472c75 Mon Sep 17 00:00:00 2001 From: Yu Li Date: Tue, 27 Feb 2024 15:05:32 +0800 Subject: [PATCH 1/2] feat(volo-http): support https Signed-off-by: Yu Li --- Cargo.lock | 2 + volo-http/Cargo.toml | 11 +- volo-http/src/client/meta.rs | 58 +++-- volo-http/src/client/mod.rs | 276 +++++++++++++++++------- volo-http/src/client/request_builder.rs | 20 +- volo-http/src/client/transport.rs | 128 +++++++---- volo-http/src/client/utils.rs | 7 +- volo-http/src/context/client.rs | 27 ++- volo-http/src/context/server.rs | 6 +- volo-http/src/error/client.rs | 7 +- volo-http/src/server/mod.rs | 113 +++++----- 11 files changed, 442 insertions(+), 213 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index b3f4ecc9..9491f46d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -3036,6 +3036,8 @@ dependencies = [ "sonic-rs", "thiserror", "tokio", + "tokio-native-tls", + "tokio-rustls 0.25.0", "tracing", "volo", ] diff --git a/volo-http/Cargo.toml b/volo-http/Cargo.toml index 5cbc3998..dbc32a3f 100644 --- a/volo-http/Cargo.toml +++ b/volo-http/Cargo.toml @@ -52,6 +52,10 @@ tracing.workspace = true # server optional matchit = { workspace = true, optional = true } +# tls optional +tokio-rustls = { workspace = true, optional = true } +tokio-native-tls = { workspace = true, optional = true } + # cookie support cookie = { workspace = true, optional = true, features = ["percent-encode"] } @@ -72,11 +76,16 @@ default = [] default_client = ["client", "json"] default_server = ["server", "query", "form", "json"] -full = ["client", "server", "cookie", "query", "form", "json"] +full = ["client", "server", "rustls", "cookie", "query", "form", "json"] client = ["hyper/client"] # client core server = ["hyper/server", "dep:matchit"] # server core +__tls = [] +rustls = ["__tls", "dep:tokio-rustls", "volo/rustls"] +native-tls = ["__tls", "dep:tokio-native-tls", "volo/native-tls"] +native-tls-vendored = ["native-tls", "volo/native-tls-vendored"] + cookie = ["dep:cookie"] __serde = ["dep:serde"] # a private feature for enabling `serde` by `serde_xxx` diff --git a/volo-http/src/client/meta.rs b/volo-http/src/client/meta.rs index e86f7477..09b65dae 100644 --- a/volo-http/src/client/meta.rs +++ b/volo-http/src/client/meta.rs @@ -1,11 +1,12 @@ -use http::{header, HeaderValue}; +use std::error::Error; + +use http::header; +use http_body::Body; use motore::service::Service; use volo::context::Context; use crate::{ - context::{client::Host, ClientContext}, - error::client::ClientError, - request::ClientRequest, + context::ClientContext, error::client::ClientError, request::ClientRequest, response::ClientResponse, }; @@ -26,7 +27,9 @@ where + Send + Sync + 'static, - B: Send + 'static, + B: Body + Send + 'static, + B::Data: Send, + B::Error: Into> + 'static, { type Response = S::Response; type Error = S::Error; @@ -36,21 +39,36 @@ where cx: &mut ClientContext, mut req: ClientRequest, ) -> Result { - let config = cx.rpc_info().config(); - let host = match config.host { - Host::CalleeName => Some(HeaderValue::from_str( - cx.rpc_info().callee().service_name_ref(), - )), - Host::TargetAddress => cx - .rpc_info() - .callee() - .address() - .map(|addr| HeaderValue::from_str(&format!("{}", addr))), - Host::None => None, - }; - if let Some(Ok(val)) = host { - req.headers_mut().insert(header::HOST, val); + // `Content-Length` must be set here because the body may be changed in previous layer(s). + let exact_len = req.body().size_hint().exact(); + if let Some(len) = exact_len { + if len > 0 && req.headers().get(header::CONTENT_LENGTH).is_none() { + req.headers_mut().insert(header::CONTENT_LENGTH, len.into()); + } } - self.inner.call(cx, req).await + + let stat_enable = cx.rpc_info().config().stat_enable; + + if stat_enable { + if let Some(req_size) = exact_len { + cx.common_stats.set_req_size(req_size); + } + } + + tracing::trace!("sending request: {} {}", req.method(), req.uri()); + tracing::trace!("headers: {:?}", req.headers()); + + let res = self.inner.call(cx, req).await; + + if stat_enable { + if let Ok(response) = res.as_ref() { + cx.stats.set_status_code(response.status()); + if let Some(resp_size) = response.size_hint().exact() { + cx.common_stats.set_resp_size(resp_size); + } + } + } + + res } } diff --git a/volo-http/src/client/mod.rs b/volo-http/src/client/mod.rs index f65f772d..a0cc3029 100644 --- a/volo-http/src/client/mod.rs +++ b/volo-http/src/client/mod.rs @@ -1,21 +1,23 @@ -use std::{cell::RefCell, error::Error, sync::Arc}; +use std::{cell::RefCell, error::Error, sync::Arc, time::Duration}; use faststr::FastStr; use http::{ header::{self, HeaderMap, HeaderName, HeaderValue}, - Method, + Method, Uri, }; use metainfo::{MetaInfo, METAINFO}; use motore::{ layer::{Identity, Layer, Stack}, - make::MakeConnection, service::Service, }; use paste::paste; use volo::{ client::MkClient, context::Context, - net::{dial::DefaultMakeTransport, Address}, + net::{ + dial::{DefaultMakeTransport, MakeTransport}, + Address, + }, }; use self::{ @@ -26,7 +28,7 @@ use self::{ }; use crate::{ context::{ - client::{Config, Host, UserAgent}, + client::{CalleeName, CallerName, Config}, ClientContext, }, error::client::{builder_error, ClientError}, @@ -41,57 +43,60 @@ pub mod utils; const PKG_NAME_WITH_VER: &str = concat!(env!("CARGO_PKG_NAME"), '/', env!("CARGO_PKG_VERSION")); -pub type ClientMetaService = MetaService>; +pub type ClientMetaService = MetaService; -pub struct ClientBuilder { +pub struct ClientBuilder { config: Config, http_config: ClientConfig, + transport_config: volo::net::dial::Config, callee_name: FastStr, caller_name: FastStr, headers: HeaderMap, - user_agent: UserAgent, layer: L, mk_client: MkC, - mk_conn: MkT, + #[cfg(feature = "__tls")] + tls_config: Option, } -impl ClientBuilder { +impl ClientBuilder { /// Create a new client builder. pub fn new() -> Self { Self { config: Default::default(), http_config: Default::default(), + transport_config: Default::default(), callee_name: FastStr::empty(), caller_name: FastStr::empty(), headers: Default::default(), - user_agent: Default::default(), layer: Identity::new(), mk_client: DefaultMkClient, - mk_conn: DefaultMakeTransport::new(), + #[cfg(feature = "__tls")] + tls_config: None, } } } -impl Default for ClientBuilder { +impl Default for ClientBuilder { fn default() -> Self { Self::new() } } -impl ClientBuilder { +impl ClientBuilder { /// This is unstable now and may be changed in the future. #[doc(hidden)] - pub fn client_maker(self, new_mk_client: MkC2) -> ClientBuilder { + pub fn client_maker(self, new_mk_client: MkC2) -> ClientBuilder { ClientBuilder { config: self.config, http_config: self.http_config, + transport_config: self.transport_config, callee_name: self.callee_name, caller_name: self.caller_name, headers: self.headers, - user_agent: self.user_agent, layer: self.layer, mk_client: new_mk_client, - mk_conn: self.mk_conn, + #[cfg(feature = "__tls")] + tls_config: None, } } @@ -106,17 +111,18 @@ impl ClientBuilder { /// The current order is: foo -> bar (the request will come to foo first, and then bar). /// /// After we call `.layer(baz)`, we will get: foo -> bar -> baz. - pub fn layer(self, layer: Inner) -> ClientBuilder, MkC, MkT> { + pub fn layer(self, layer: Inner) -> ClientBuilder, MkC> { ClientBuilder { config: self.config, http_config: self.http_config, + transport_config: self.transport_config, callee_name: self.callee_name, caller_name: self.caller_name, headers: self.headers, - user_agent: self.user_agent, layer: Stack::new(layer, self.layer), mk_client: self.mk_client, - mk_conn: self.mk_conn, + #[cfg(feature = "__tls")] + tls_config: None, } } @@ -131,17 +137,18 @@ impl ClientBuilder { /// The current order is: foo -> bar (the request will come to foo first, and then bar). /// /// After we call `.layer_front(baz)`, we will get: baz -> foo -> bar. - pub fn layer_front(self, layer: Front) -> ClientBuilder, MkC, MkT> { + pub fn layer_front(self, layer: Front) -> ClientBuilder, MkC> { ClientBuilder { config: self.config, http_config: self.http_config, + transport_config: self.transport_config, callee_name: self.callee_name, caller_name: self.caller_name, headers: self.headers, - user_agent: self.user_agent, layer: Stack::new(self.layer, layer), mk_client: self.mk_client, - mk_conn: self.mk_conn, + #[cfg(feature = "__tls")] + tls_config: None, } } @@ -178,6 +185,16 @@ impl ClientBuilder { Ok(self) } + /// Set tls config for the client. + #[cfg(feature = "__tls")] + pub fn set_tls_config(mut self, tls_config: T) -> Self + where + T: Into, + { + self.tls_config = Some(Into::into(tls_config)); + self + } + /// Get a reference to the default headers of the client. pub fn headers(&self) -> &HeaderMap { &self.headers @@ -210,6 +227,39 @@ impl ClientBuilder { &mut self.http_config } + /// Get a reference to the transport configuration of the client. + pub fn transport_config(&self) -> &volo::net::dial::Config { + &self.transport_config + } + + /// Get a mutable reference to the transport configuration of the client. + pub fn transport_config_mut(&mut self) -> &mut volo::net::dial::Config { + &mut self.transport_config + } + + /// Set mode for setting `Host` in request headers, and server name when using TLS. + /// + /// Default is callee name. + pub fn set_callee_name_mode(&mut self, mode: CalleeName) -> &mut Self { + self.config.callee_name = mode; + self + } + + /// Set mode for setting `User-Agent` in request headers. + /// + /// Default is the current crate name and version. + pub fn set_caller_name_mode(&mut self, mode: CalleeName) -> &mut Self { + self.config.callee_name = mode; + self + } + + /// This is unstable now and may be changed in the future. + #[doc(hidden)] + pub fn stat_enable(&mut self, enable: bool) -> &mut Self { + self.config.stat_enable = enable; + self + } + /// Set whether HTTP/1 connections will write header names as title case at /// the socket level. /// @@ -253,64 +303,65 @@ impl ClientBuilder { self } - /// Set mode for setting `User-Agent` in request headers. - /// - /// Default is generated crate name with version, e.g., `volo-http/0.1.0` - pub fn set_user_agent(&mut self, ua: UserAgent) -> &mut Self { - self.user_agent = ua; + /// Set the maximum idle time for a connection. + pub fn set_connect_timeout(&mut self, timeout: Duration) -> &mut Self { + self.transport_config.connect_timeout = Some(timeout); self } - /// Set mode for setting `Host` in request headers. - /// - /// Default is callee name. - pub fn set_host(&mut self, host: Host) -> &mut Self { - self.config.host = host; + /// Set the maximum idle time for reading data from the connection. + pub fn set_read_timeout(&mut self, timeout: Duration) -> &mut Self { + self.transport_config.read_timeout = Some(timeout); self } - /// This is unstable now and may be changed in the future. - #[doc(hidden)] - pub fn stat_enable(&mut self, enable: bool) -> &mut Self { - self.config.stat_enable = enable; + /// Set the maximum idle time for writing data to the connection. + pub fn set_write_timeout(&mut self, timeout: Duration) -> &mut Self { + self.transport_config.write_timeout = Some(timeout); self } /// Build the HTTP client. pub fn build(mut self) -> MkC::Target where - L: Layer>>, + L: Layer>, L::Service: Send + Sync + 'static, MkC: MkClient>, - MkT: MakeConnection
, { - let transport = ClientTransport::new(self.http_config, self.mk_conn); + let mut default_mk_conn = DefaultMakeTransport::new(); + default_mk_conn.set_connect_timeout(self.transport_config.connect_timeout); + default_mk_conn.set_read_timeout(self.transport_config.read_timeout); + default_mk_conn.set_write_timeout(self.transport_config.write_timeout); + + let transport = ClientTransport::new( + self.http_config, + default_mk_conn, + #[cfg(feature = "__tls")] + self.tls_config.unwrap_or_default(), + ); let service = self.layer.layer(MetaService::new(transport)); - if self.headers.get(header::USER_AGENT).is_some() { - self.user_agent = UserAgent::None; - } - match self.user_agent { - UserAgent::PkgNameWithVersion => self.headers.insert( - header::USER_AGENT, - HeaderValue::from_static(PKG_NAME_WITH_VER), - ), - UserAgent::CallerNameWithVersion if !self.caller_name.is_empty() => { - self.headers.insert( - header::USER_AGENT, - HeaderValue::from_str(&format!( - "{}/{}", - self.caller_name, - env!("CARGO_PKG_VERSION") - )) - .expect("Invalid caller name"), - ) + + let caller_name = match &self.config.caller_name { + CallerName::PkgNameWithVersion => FastStr::from_static_str(PKG_NAME_WITH_VER), + CallerName::OriginalCallerName => self.caller_name.clone(), + CallerName::CallerNameWithVersion if !self.caller_name.is_empty() => { + FastStr::from_string(format!( + "{}/{}", + self.caller_name, + env!("CARGO_PKG_VERSION") + )) } - UserAgent::Specified(val) if !val.is_empty() => self.headers.insert( - header::USER_AGENT, - val.try_into().expect("Invalid value for User-Agent"), - ), - _ => None, + CallerName::Specified(val) => val.to_owned(), + _ => FastStr::empty(), }; + + if !caller_name.is_empty() && self.headers.get(header::USER_AGENT).is_none() { + self.headers.insert( + header::USER_AGENT, + HeaderValue::from_str(caller_name.as_str()).expect("Invalid caller name"), + ); + } + let client_inner = ClientInner { callee_name: self.callee_name, caller_name: self.caller_name, @@ -351,9 +402,16 @@ macro_rules! method_requests { }; } +impl Client<()> { + /// Create a new client builder. + pub fn builder() -> ClientBuilder { + ClientBuilder::new() + } +} + impl Client { /// Create a builder for building a request. - pub fn builder(&self) -> RequestBuilder { + pub fn request_builder(&self) -> RequestBuilder { RequestBuilder::new(self) } @@ -378,11 +436,46 @@ impl Client { impl Client { /// Send a request to the target address. + /// + /// This is a low-level method and you should build the `uri` and `request`, and get the + /// address by yourself. + /// + /// For simple usage, you can use the `get`, `post` and other methods directly. + /// + /// # Example + /// + /// ```ignore + /// use std::net::SocketAddr; + /// + /// use http::{Method, Uri}; + /// use volo::net::Address; + /// use volo_http::{body::Body, client::Client, request::ClientRequest}; + /// + /// let client = Client::builder().build(); + /// let addr: SocketAddr = "[::]:8080".parse().unwrap(); + /// let addr = Address::from(addr); + /// let resp = client + /// .send_request( + /// Uri::from_static("http://localhost:8080/"), + /// addr, + /// ClientRequest::builder() + /// .method(Method::GET) + /// .uri("/") + /// .body(Body::empty()) + /// .expect("build request failed"), + /// ) + /// .await + /// .expect("request failed") + /// .into_string() + /// .await + /// .expect("response failed to convert to string"); + /// println!("{resp:?}"); + /// ``` pub async fn send_request( &self, - host: Option<&str>, + uri: Uri, target: Address, - request: ClientRequest, + mut request: ClientRequest, ) -> Result where S: Service, Response = ClientResponse, Error = ClientError> @@ -392,18 +485,55 @@ impl Client { B: Send + 'static, { let caller_name = self.inner.caller_name.clone(); - let callee_name = if !self.inner.callee_name.is_empty() { - self.inner.callee_name.clone() - } else { - match host { - Some(host) => FastStr::from(host.to_owned()), - None => FastStr::empty(), - } + let callee_name = match self.inner.config.callee_name { + CalleeName::TargetName => match uri.host() { + // IPv6 address in URI has square brackets, but we does not need it as a + // "host name". + Some(host) => FastStr::from( + host.trim_start_matches('[') + .trim_end_matches(']') + .to_owned(), + ), + None => match &target { + Address::Ip(addr) => FastStr::from(addr.ip().to_string()), + #[cfg(target_family = "unix")] + Address::Unix(_) => FastStr::empty(), + }, + }, + CalleeName::OriginalCalleeName => self.inner.callee_name.clone(), + CalleeName::None => FastStr::empty(), }; + tracing::trace!( + "create a request with caller_name: {caller_name}, callee_name: {callee_name}" + ); + + if request.headers().get(header::HOST).is_none() && uri.host().is_some() { + let mut host = uri.host().unwrap().to_string(); + if let Some(port) = uri.port() { + host.push(':'); + host.push_str(port.as_str()); + } + if let Ok(value) = HeaderValue::from_str(&host) { + request.headers_mut().insert(header::HOST, value); + } else { + tracing::info!( + "failed to insert `Host` to headers, `{host}` is not a valid header value" + ); + } + } + let mut cx = ClientContext::new(target, true); cx.rpc_info_mut().caller_mut().set_service_name(caller_name); cx.rpc_info_mut().callee_mut().set_service_name(callee_name); cx.rpc_info_mut().set_config(self.inner.config.clone()); + #[cfg(feature = "__tls")] + { + cx.rpc_info_mut().config_mut().is_tls = match uri.scheme() { + Some(scheme) => scheme == &http::uri::Scheme::HTTPS, + None => false, + }; + } + self.call(&mut cx, request).await } } diff --git a/volo-http/src/client/request_builder.rs b/volo-http/src/client/request_builder.rs index 47e70509..c1b54e54 100644 --- a/volo-http/src/client/request_builder.rs +++ b/volo-http/src/client/request_builder.rs @@ -1,6 +1,6 @@ use std::error::Error; -use http::{HeaderMap, HeaderName, HeaderValue, Method, Request, Uri, Version}; +use http::{uri::PathAndQuery, HeaderMap, HeaderName, HeaderValue, Method, Request, Uri, Version}; use motore::service::Service; use volo::net::Address; @@ -10,9 +10,7 @@ use crate::{ client::utils::{parse_address, resolve}, context::ClientContext, error::{ - client::{ - bad_host_name, builder_error, no_uri, uri_without_path, ClientErrorInner, Result, - }, + client::{bad_host_name, builder_error, no_uri, ClientErrorInner, Result}, ClientError, }, request::ClientRequest, @@ -41,7 +39,11 @@ impl<'a, S> RequestBuilder<'a, S, Body> { method: Method, uri: Uri, ) -> Result { - let rela_uri = uri.path_and_query().ok_or(uri_without_path())?.to_owned(); + let rela_uri = uri + .path_and_query() + .map(PathAndQuery::as_str) + .unwrap_or("/") + .to_owned(); Ok(Self { client, @@ -98,8 +100,8 @@ impl<'a, S, B> RequestBuilder<'a, S, B> { pub fn uri(mut self, uri: Uri) -> Result { let rela_uri = uri .path_and_query() - .ok_or(uri_without_path())? - .to_owned() + .map(PathAndQuery::to_owned) + .unwrap_or_else(|| PathAndQuery::from_static("/")) .into(); self.uri = Some(uri); *self.request.uri_mut() = rela_uri; @@ -201,8 +203,6 @@ where } }, }; - self.client - .send_request(uri.host(), target, self.request) - .await + self.client.send_request(uri, target, self.request).await } } diff --git a/volo-http/src/client/transport.rs b/volo-http/src/client/transport.rs index 21e741de..70b6056b 100644 --- a/volo-http/src/client/transport.rs +++ b/volo-http/src/client/transport.rs @@ -1,11 +1,15 @@ use std::error::Error; -use http::header; use http_body::Body; use hyper::client::conn::http1; use hyper_util::rt::TokioIo; use motore::{make::MakeConnection, service::Service}; -use volo::{context::Context, net::Address}; +#[cfg(feature = "__tls")] +use volo::net::tls::Connector; +use volo::{ + context::Context, + net::{conn::Conn, dial::DefaultMakeTransport}, +}; use crate::{ context::ClientContext, @@ -14,13 +18,19 @@ use crate::{ response::ClientResponse, }; -pub struct ClientTransport { +pub struct ClientTransport { client: http1::Builder, - mk_conn: MkT, + mk_conn: DefaultMakeTransport, + #[cfg(feature = "__tls")] + tls_connector: volo::net::tls::TlsConnector, } -impl ClientTransport { - pub fn new(config: ClientConfig, mk_conn: MkT) -> Self { +impl ClientTransport { + pub fn new( + config: ClientConfig, + mk_conn: DefaultMakeTransport, + #[cfg(feature = "__tls")] tls_connector: volo::net::tls::TlsConnector, + ) -> Self { let mut builder = http1::Builder::new(); builder .title_case_headers(config.title_case_headers) @@ -32,40 +42,92 @@ impl ClientTransport { Self { client: builder, mk_conn, + #[cfg(feature = "__tls")] + tls_connector, } } + #[cfg(feature = "__tls")] + async fn make_connection(&self, cx: &ClientContext) -> Result { + let target_addr = cx.rpc_info().callee().address().ok_or_else(no_address)?; + let is_tls = cx.rpc_info().config().is_tls; + match target_addr { + volo::net::Address::Ip(_) if is_tls => { + let target_name = cx.rpc_info().callee().service_name_ref(); + tracing::debug!("connecting to tls target: {target_addr:?}, name: {target_name:?}"); + let conn = self + .mk_conn + .make_connection(target_addr) + .await + .map_err(|err| { + tracing::warn!("failed to make connection, error: {err}"); + request_error(err) + })?; + let tcp_stream = match conn.stream { + volo::net::conn::ConnStream::Tcp(tcp_stream) => tcp_stream, + _ => unreachable!(), + }; + self.tls_connector + .connect(target_name, tcp_stream) + .await + .map_err(|err| { + tracing::warn!("failed to make tls connection, error: {err}"); + request_error(err) + }) + } + _ => { + tracing::debug!("fallback to non-tls target: {target_addr:?}"); + self.mk_conn + .make_connection(target_addr) + .await + .map_err(|err| { + tracing::warn!("failed to make connection, error: {err}"); + request_error(err) + }) + } + } + } + + #[cfg(not(feature = "__tls"))] + async fn make_connection(&self, cx: &ClientContext) -> Result { + let target_addr = cx.rpc_info().callee().address().ok_or_else(no_address)?; + tracing::debug!("connecting to target: {target_addr:?}"); + self.mk_conn + .make_connection(target_addr) + .await + .map_err(|err| { + tracing::warn!("failed to make connection, error: {err}"); + request_error(err) + }) + } + async fn request( &self, - target: Address, + cx: &ClientContext, req: ClientRequest, ) -> Result where - MkT: MakeConnection
+ Send + Sync, - MkT::Connection: 'static, - MkT::Error: Error + Send + Sync + 'static, B: Body + Send + 'static, B::Data: Send, B::Error: Into> + 'static, { - let conn = self - .mk_conn - .make_connection(target) - .await - .map_err(request_error)?; + let conn = self.make_connection(cx).await?; let io = TokioIo::new(conn); - let (mut sender, conn) = self.client.handshake(io).await.map_err(request_error)?; + let (mut sender, conn) = self.client.handshake(io).await.map_err(|err| { + tracing::warn!("failed to handshake, error: {err}"); + request_error(err) + })?; tokio::spawn(conn); - let resp = sender.send_request(req).await.map_err(request_error)?; + let resp = sender.send_request(req).await.map_err(|err| { + tracing::warn!("failed to send request, error: {err}"); + request_error(err) + })?; Ok(resp) } } -impl Service> for ClientTransport +impl Service> for ClientTransport where - MkT: MakeConnection
+ Send + Sync, - MkT::Connection: 'static, - MkT::Error: Error + Send + Sync + 'static, B: Body + Send + 'static, B::Data: Send, B::Error: Into> + 'static, @@ -76,39 +138,21 @@ where async fn call( &self, cx: &mut ClientContext, - mut req: ClientRequest, + req: ClientRequest, ) -> Result { - // `Content-Length` must be set here because the body may be changed in previous layer(s). - if let Some(len) = req.body().size_hint().exact() { - if req.headers().get(header::CONTENT_LENGTH).is_none() { - req.headers_mut().insert(header::CONTENT_LENGTH, len.into()); - } - } - - let target = cx.rpc_info.callee().address().ok_or_else(no_address)?; let stat_enable = cx.rpc_info().config().stat_enable; if stat_enable { - if let Some(req_size) = req.size_hint().exact() { - cx.common_stats.set_req_size(req_size); - } cx.stats.record_transport_start_at(); } - let resp = self.request(target, req).await; + let res = self.request(cx, req).await; if stat_enable { cx.stats.record_transport_end_at(); - - if let Ok(response) = resp.as_ref() { - cx.stats.set_status_code(response.status()); - if let Some(resp_size) = response.size_hint().exact() { - cx.common_stats.set_resp_size(resp_size); - } - } } - resp + res } } diff --git a/volo-http/src/client/utils.rs b/volo-http/src/client/utils.rs index edd1a2a5..3720ad36 100644 --- a/volo-http/src/client/utils.rs +++ b/volo-http/src/client/utils.rs @@ -70,8 +70,13 @@ fn get_port(uri: &Uri) -> Result { } pub fn parse_address(uri: &Uri) -> Result
{ - let host = uri.host().ok_or(uri_without_host())?; + let host = uri + .host() + .ok_or(uri_without_host())? + .trim_start_matches('[') + .trim_end_matches(']'); let port = get_port(uri)?; + tracing::warn!("host: {host}, port: {port}"); match host.parse::() { Ok(addr) => Ok(Address::from(SocketAddr::new(addr, port))), Err(e) => Err(builder_error(e)), diff --git a/volo-http/src/context/client.rs b/volo-http/src/context/client.rs index d9e9a7af..1b2ac844 100644 --- a/volo-http/src/context/client.rs +++ b/volo-http/src/context/client.rs @@ -1,4 +1,5 @@ use chrono::{DateTime, Local}; +use faststr::FastStr; use http::StatusCode; use paste::paste; use volo::{ @@ -57,39 +58,47 @@ impl ClientStats { #[derive(Clone, Debug)] pub struct Config { - pub host: Host, + pub caller_name: CallerName, + pub callee_name: CalleeName, pub stat_enable: bool, + #[cfg(feature = "__tls")] + pub is_tls: bool, } impl Default for Config { fn default() -> Self { Self { - host: Host::default(), + caller_name: CallerName::default(), + callee_name: CalleeName::default(), stat_enable: true, + #[cfg(feature = "__tls")] + is_tls: false, } } } #[derive(Clone, Debug, Default)] -pub enum UserAgent { +pub enum CallerName { /// The crate name and version of the current crate. #[default] PkgNameWithVersion, + /// The original caller name. + OriginalCallerName, /// The caller name and version of the current crate. CallerNameWithVersion, /// A specified String for the user agent. - Specified(String), + Specified(FastStr), /// Do not set `User-Agent` by the client. None, } #[derive(Clone, Debug, Default)] -pub enum Host { - /// The callee name. +pub enum CalleeName { + /// The target authority of URI. #[default] - CalleeName, - /// The target address. - TargetAddress, + TargetName, + /// The original callee name. + OriginalCalleeName, /// Do not set `Host` by the client. None, } diff --git a/volo-http/src/context/server.rs b/volo-http/src/context/server.rs index 5688f5e7..f3fb7837 100644 --- a/volo-http/src/context/server.rs +++ b/volo-http/src/context/server.rs @@ -145,11 +145,7 @@ impl RequestPartsExt for Parts { } fn scheme(&self) -> Scheme { - self.uri - .scheme() - // volo-http supports http only now - .unwrap_or(&Scheme::HTTP) - .to_owned() + self.uri.scheme().unwrap_or(&Scheme::HTTP).to_owned() } fn host(&self) -> Option<&str> { diff --git a/volo-http/src/error/client.rs b/volo-http/src/error/client.rs index f719a11a..9c99004d 100644 --- a/volo-http/src/error/client.rs +++ b/volo-http/src/error/client.rs @@ -72,10 +72,10 @@ macro_rules! error_kind { } macro_rules! client_error_inner { - ($($kind:ident => $name:ident => $msg:literal,)+) => { + ($($(#[$attr:meta])* $kind:ident => $name:ident => $msg:literal,)+) => { #[derive(Debug)] pub enum ClientErrorInner { - $($name,)+ + $($(#[$attr])* $name,)+ Other(BoxError), } @@ -83,6 +83,7 @@ macro_rules! client_error_inner { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { match self { $( + $(#[$attr])* Self::$name => f.write_str($msg), )+ Self::Other(err) => write!(f, "{}", err), @@ -94,6 +95,7 @@ macro_rules! client_error_inner { paste! { $( + $(#[$attr])* pub(crate) fn [<$name:snake>]() -> ClientError { ClientError::new(Kind::$kind, ClientErrorInner::$name) } @@ -112,6 +114,5 @@ client_error_inner! { Builder => UriWithoutHost => "host not found in uri", Builder => BadScheme => "bad scheme", Builder => BadHostName => "bad host name", - Builder => UriWithoutPath => "path not found in uri", Request => NoAddress => "missing target address", } diff --git a/volo-http/src/server/mod.rs b/volo-http/src/server/mod.rs index 0ca4d6eb..cbe051a2 100644 --- a/volo-http/src/server/mod.rs +++ b/volo-http/src/server/mod.rs @@ -16,9 +16,11 @@ use motore::{ BoxError, }; use scopeguard::defer; -use tokio::sync::Notify; +use tokio::{sync::Notify, task::JoinHandle}; use tracing::{info, trace}; use volo::net::{conn::Conn, incoming::Incoming, Address, MakeIncoming}; +#[cfg(feature = "__tls")] +use volo::net::{conn::ConnStream, tls::Acceptor, tls::ServerTlsConfig}; use crate::{ context::{server::Config, ServerContext}, @@ -54,6 +56,8 @@ pub struct Server { http_config: ServerConfig, stat_tracer: Vec, shutdown_hooks: Vec BoxFuture<'static, ()> + Send>>, + #[cfg(feature = "__tls")] + tls_config: Option, } impl Server { @@ -66,11 +70,22 @@ impl Server { http_config: ServerConfig::default(), stat_tracer: Vec::new(), shutdown_hooks: Vec::new(), + #[cfg(feature = "__tls")] + tls_config: None, } } } impl Server { + #[cfg(feature = "__tls")] + /// Enable TLS with the specified configuration. + /// + /// If not set, the server will not use TLS. + pub fn tls_config(mut self, config: impl Into) -> Self { + self.tls_config = Some(config.into()); + self + } + /// Register shutdown hook. /// /// Hook functions will be called just before volo's own gracefull existing code starts, @@ -102,6 +117,8 @@ impl Server { http_config: self.http_config, stat_tracer: self.stat_tracer, shutdown_hooks: self.shutdown_hooks, + #[cfg(feature = "__tls")] + tls_config: self.tls_config, } } @@ -124,6 +141,8 @@ impl Server { http_config: self.http_config, stat_tracer: self.stat_tracer, shutdown_hooks: self.shutdown_hooks, + #[cfg(feature = "__tls")] + tls_config: self.tls_config, } } @@ -223,9 +242,9 @@ impl Server { /// The main entry point for the server. pub async fn run(self, mk_incoming: MI) -> Result<(), BoxError> where - S: Service, + S: Service + Send + Sync + 'static, S::Response: IntoResponse, - L: Layer, + L: Layer + Send + Sync + 'static, L::Service: Service + Send + Sync + 'static, >::Response: IntoResponse, @@ -248,35 +267,51 @@ impl Server { let (exit_notify_inner, exit_flag_inner) = (exit_notify.clone(), exit_flag.clone()); // spawn accept loop - let handler = tokio::spawn(async move { + let handler: JoinHandle> = tokio::spawn(async move { let exit_flag = exit_flag_inner.clone(); loop { if *exit_flag.read() { break Ok(()); } - match incoming.accept().await { - Ok(Some(conn)) => { - let peer = conn - .info - .peer_addr - .clone() - .expect("http address should have one"); - - trace!("[VOLO] accept connection from: {:?}", peer); - - tokio::task::spawn(handle_conn( - conn, - service.clone(), - self.config, - stat_tracer.clone(), - exit_notify_inner.clone(), - conn_cnt.clone(), - peer, - )); + let conn = match incoming.accept().await? { + Some(conn) => conn, + None => break Ok(()), + }; + #[cfg(feature = "__tls")] + let conn = { + let Conn { stream, info } = conn; + match (stream, &self.tls_config) { + (ConnStream::Tcp(stream), Some(tls_config)) => { + let stream = match tls_config.acceptor.accept(stream).await { + Ok(conn) => conn, + Err(err) => { + trace!("[VOLO] tls handshake error: {err:?}"); + continue; + } + }; + Conn { stream, info } + } + (stream, _) => Conn { stream, info }, } - Ok(None) => break Ok(()), - Err(e) => break Err(Box::new(e)), - } + }; + + let peer = conn + .info + .peer_addr + .clone() + .expect("no peer address found in server connection"); + + trace!("[VOLO] accept connection from: {:?}", peer); + + tokio::task::spawn(handle_conn( + conn, + service.clone(), + self.config, + stat_tracer.clone(), + exit_notify_inner.clone(), + conn_cnt.clone(), + peer, + )); } }); @@ -295,17 +330,7 @@ impl Server { _ = sigint.recv() => {} _ = sighup.recv() => {} _ = sigterm.recv() => {} - res = handler => { - match res { - Ok(res) => { - match res { - Ok(()) => {} - Err(e) => return Err(Box::new(e)) - }; - } - Err(e) => return Err(Box::new(e)), - } - } + res = handler => res??, } } @@ -313,17 +338,7 @@ impl Server { #[cfg(target_family = "windows")] tokio::select! { _ = tokio::signal::ctrl_c() => {} - res = handler => { - match res { - Ok(res) => { - match res { - Ok(()) => {} - Err(e) => return Err(Box::new(e)) - }; - } - Err(e) => return Err(Box::new(e)), - } - } + res = handler => res??, } if !self.shutdown_hooks.is_empty() { @@ -433,7 +448,7 @@ async fn handle_conn( } result = &mut http_conn => { if let Err(err) = result { - tracing::debug!("[VOLO] http connection error: {:?}", err); + tracing::debug!("[VOLO] connection error: {:?}", err); } }, } From 928c08d694cd02184af1b4abc14f10952f82549b Mon Sep 17 00:00:00 2001 From: Yu Li Date: Fri, 8 Mar 2024 20:34:32 +0800 Subject: [PATCH 2/2] chore(volo-http): add https examples Signed-off-by: Yu Li --- examples/Cargo.toml | 15 ++++++- examples/src/http/example-http-client.rs | 13 +++++- examples/src/http/http-tls-client.rs | 32 +++++++++++++++ examples/src/http/http-tls-server.rs | 50 ++++++++++++++++++++++++ 4 files changed, 108 insertions(+), 2 deletions(-) create mode 100644 examples/src/http/http-tls-client.rs create mode 100644 examples/src/http/http-tls-server.rs diff --git a/examples/Cargo.toml b/examples/Cargo.toml index 0e81b3c9..dddf5569 100644 --- a/examples/Cargo.toml +++ b/examples/Cargo.toml @@ -87,6 +87,16 @@ path = "src/http/example-http-server.rs" name = "example-http-client" path = "src/http/example-http-client.rs" +[[bin]] +name = "http-tls-server" +path = "src/http/http-tls-server.rs" +required-features = ["__tls"] + +[[bin]] +name = "http-tls-client" +path = "src/http/http-tls-client.rs" +required-features = ["__tls"] + [dependencies] anyhow.workspace = true async-stream.workspace = true @@ -108,7 +118,7 @@ pilota.workspace = true volo = { path = "../volo" } volo-grpc = { path = "../volo-grpc" } volo-thrift = { path = "../volo-thrift", features = ["multiplex"] } -volo-http = { path = "../volo-http", features = ["full"] } +volo-http = { path = "../volo-http", features = ["default_client", "default_server", "cookie"] } volo-gen = { path = "./volo-gen" } @@ -118,14 +128,17 @@ rustls = [ "__tls", "volo/rustls", "volo-grpc/rustls", + "volo-http/rustls", ] native-tls = [ "__tls", "volo/native-tls", "volo-grpc/native-tls", + "volo-http/native-tls", ] native-tls-vendored = [ "__tls", "volo/native-tls-vendored", "volo-grpc/native-tls-vendored", + "volo-http/native-tls-vendored", ] diff --git a/examples/src/http/example-http-client.rs b/examples/src/http/example-http-client.rs index 8692dc15..bd04a36a 100644 --- a/examples/src/http/example-http-client.rs +++ b/examples/src/http/example-http-client.rs @@ -21,7 +21,18 @@ async fn main() -> Result<(), BoxError> { tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); // simple `get` function and dns resolve - println!("{}", get("http://www.126.com/").await?.into_string().await?); + println!( + "{}", + get("http://httpbin.org/get").await?.into_string().await? + ); + + // HTTPS `get` + // + // If tls is not enabled, the `httpbin.org` will response 400 Bad Request. + println!( + "{}", + get("https://httpbin.org/get").await?.into_string().await? + ); // create client by builder let client = ClientBuilder::new() diff --git a/examples/src/http/http-tls-client.rs b/examples/src/http/http-tls-client.rs new file mode 100644 index 00000000..13d478ab --- /dev/null +++ b/examples/src/http/http-tls-client.rs @@ -0,0 +1,32 @@ +use volo::net::tls::TlsConnector; +use volo_http::{body::BodyConversion, client::Client}; + +#[volo::main] +async fn main() { + let subscriber = tracing_subscriber::FmtSubscriber::builder() + .with_max_level(tracing::Level::TRACE) + .finish(); + tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); + + let data_dir = std::path::PathBuf::from_iter([std::env!("CARGO_MANIFEST_DIR"), "data"]); + let connector = TlsConnector::builder() + .enable_default_root_certs(false) + .add_pem_from_file(data_dir.join("tls/ca.pem")) + .expect("failed to read ca.pem") + .build() + .expect("failed to build TlsConnector"); + + let client = Client::builder().set_tls_config(connector).build(); + + let resp = client + .get("https://[::1]:8080/") + .expect("invalid uri") + .send() + .await + .expect("request failed") + .into_string() + .await + .expect("response failed to convert to string"); + + println!("{resp}"); +} diff --git a/examples/src/http/http-tls-server.rs b/examples/src/http/http-tls-server.rs new file mode 100644 index 00000000..9baef45f --- /dev/null +++ b/examples/src/http/http-tls-server.rs @@ -0,0 +1,50 @@ +//! Test it with: +//! +//! ```bash +//! curl -v --cacert examples/data/tls/ca.pem https://127.0.0.1:8080/ +//! ``` +//! +//! Or use the tls client directly. + +use std::{net::SocketAddr, time::Duration}; + +use volo::net::tls::ServerTlsConfig; +use volo_http::server::{ + layer::TimeoutLayer, + route::{get, Router}, + Server, +}; + +async fn index() -> &'static str { + "It Works!\n" +} + +#[volo::main] +async fn main() { + let subscriber = tracing_subscriber::FmtSubscriber::builder() + .with_max_level(tracing::Level::TRACE) + .finish(); + tracing::subscriber::set_global_default(subscriber).expect("setting default subscriber failed"); + + let data_dir = std::path::PathBuf::from_iter([std::env!("CARGO_MANIFEST_DIR"), "data"]); + let tls_config = ServerTlsConfig::from_pem_file( + data_dir.join("tls/server.pem"), + data_dir.join("tls/server.key"), + ) + .expect("failed to load certs"); + + let app = Router::new() + .route("/", get(index)) + .layer(TimeoutLayer::new(Duration::from_secs(5))); + + let addr: SocketAddr = "[::]:8080".parse().unwrap(); + let addr = volo::net::Address::from(addr); + + println!("Listening on {addr}"); + + Server::new(app) + .tls_config(tls_config) + .run(addr) + .await + .unwrap(); +}