Skip to content

Commit

Permalink
feat(channel): Make channel feature additive (#1574)
Browse files Browse the repository at this point in the history
* feat(channel): Make channel feature additive

* chore(channel): Move add_origin service module to channel module

* chore(channel): Move user_agent service to channel module

* chore(channel): Move reconnect service to channel module

* chore(channel): Move connection service to channel module

* chore(channel): Move discover service to channel module

* feat(channel): Remove unused Connected implement for BoxedIo

* chore(channel): Move channel part of io service to channel module

* chore(channel): Move channel part of tls service to channel module

* chore(channel): Move connector service to channel module

* chore(channel): Move executor service to channel module

* chore(channel): Clean up scope of channel service module

* chore(channel): Clean up importing in channel service

* chore(doc): Update feature flag document about transport and channel
  • Loading branch information
tottoto committed Jun 19, 2024
1 parent e2c506a commit b947e1a
Show file tree
Hide file tree
Showing 22 changed files with 242 additions and 210 deletions.
4 changes: 2 additions & 2 deletions examples/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -276,13 +276,13 @@ tracing = ["dep:tracing", "dep:tracing-subscriber"]
uds = ["tokio-stream/net", "dep:tower", "dep:hyper", "dep:hyper-util"]
streaming = ["tokio-stream", "dep:h2"]
mock = ["tokio-stream", "dep:tower", "dep:hyper-util"]
tower = ["dep:hyper", "dep:hyper-util", "dep:tower", "dep:http"]
tower = ["dep:hyper", "dep:hyper-util", "dep:tower", "tower?/timeout", "dep:http"]
json-codec = ["dep:serde", "dep:serde_json", "dep:bytes"]
compression = ["tonic/gzip"]
tls = ["tonic/tls"]
tls-rustls = ["dep:hyper", "dep:hyper-util", "dep:hyper-rustls", "dep:tower", "tower-http/util", "tower-http/add-extension", "dep:rustls-pemfile", "dep:tokio-rustls", "dep:pin-project", "dep:http-body-util"]
dynamic-load-balance = ["dep:tower"]
timeout = ["tokio/time", "dep:tower"]
timeout = ["tokio/time", "dep:tower", "tower?/timeout"]
tls-client-auth = ["tonic/tls"]
types = ["dep:tonic-types"]
h2c = ["dep:hyper", "dep:tower", "dep:http", "dep:hyper-util"]
Expand Down
27 changes: 17 additions & 10 deletions tonic/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,24 +26,29 @@ version = "0.11.0"
codegen = ["dep:async-trait"]
gzip = ["dep:flate2"]
zstd = ["dep:zstd"]
default = ["transport", "codegen", "prost"]
default = ["channel", "codegen", "prost"]
prost = ["dep:prost"]
tls = ["dep:rustls-pemfile", "transport", "dep:tokio-rustls", "dep:tokio", "tokio?/rt", "tokio?/macros"]
tls-roots = ["tls-roots-common", "dep:rustls-native-certs"]
tls-roots-common = ["tls"]
tls-roots-common = ["tls", "channel"]
tls-webpki-roots = ["tls-roots-common", "dep:webpki-roots"]
router = ["dep:axum"]
transport = [
"router",
"dep:async-stream",
"channel",
"dep:h2",
"dep:hyper", "dep:hyper-util", "dep:hyper-timeout",
"dep:hyper", "dep:hyper-util",
"dep:socket2",
"dep:tokio", "tokio?/macros", "tokio?/net", "tokio?/time",
"dep:tower",
"dep:tower", "tower?/util", "tower?/limit",
]
channel = [
"transport",
"dep:hyper", "hyper?/client",
"dep:hyper-util", "hyper-util?/client-legacy",
"dep:tower", "tower?/balance", "tower?/buffer", "tower?/discover", "tower?/load", "tower?/make",
"dep:hyper-timeout",
]
channel = []

# [[bench]]
# name = "bench_main"
Expand Down Expand Up @@ -71,13 +76,12 @@ async-trait = {version = "0.1.13", optional = true}
# transport
async-stream = {version = "0.3", optional = true}
h2 = {version = "0.4", optional = true}
hyper = {version = "1", features = ["full"], optional = true}
hyper-util = { version = ">=0.1.4, <0.2", features = ["full"], optional = true }
hyper-timeout = {version = "0.5", optional = true}
hyper = {version = "1", features = ["http1", "http2", "server"], optional = true}
hyper-util = { version = ">=0.1.4, <0.2", features = ["service", "server-auto", "tokio"], optional = true }
socket2 = { version = ">=0.4.7, <0.6.0", optional = true, features = ["all"] }
tokio = {version = "1", default-features = false, optional = true}
tokio-stream = { version = "0.1", features = ["net"] }
tower = {version = "0.4.7", default-features = false, features = ["balance", "buffer", "discover", "limit", "load", "make", "timeout", "util"], optional = true}
tower = {version = "0.4.7", default-features = false, optional = true}
axum = {version = "0.7", default-features = false, optional = true}

# rustls
Expand All @@ -90,6 +94,9 @@ webpki-roots = { version = "0.26", optional = true }
flate2 = {version = "1.0", optional = true}
zstd = { version = "0.13.0", optional = true }

# channel
hyper-timeout = {version = "0.5", optional = true}

[dev-dependencies]
bencher = "0.1.5"
quickcheck = "1.0"
Expand Down
7 changes: 3 additions & 4 deletions tonic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,10 +16,9 @@
//!
//! # Feature Flags
//!
//! - `transport`: Enables the fully featured, batteries included client and server
//! implementation based on [`hyper`], [`tower`] and [`tokio`]. Enabled by default.
//! - `channel`: Enables just the full featured channel/client portion of the `transport`
//! feature.
//! - `transport`: Enables just the full featured server portion of the `channel` feature.
//! - `channel`: Enables the fully featured, batteries included client and server
//! implementation based on [`hyper`], [`tower`] and [`tokio`]. Enabled by default.
//! - `codegen`: Enables all the required exports and optional dependencies required
//! for [`tonic-build`]. Enabled by default.
//! - `tls`: Enables the `rustls` based TLS options for the `transport` feature. Not
Expand Down
6 changes: 4 additions & 2 deletions tonic/src/status.rs
Original file line number Diff line number Diff line change
Expand Up @@ -617,8 +617,10 @@ fn find_status_in_source_chain(err: &(dyn Error + 'static)) -> Option<Status> {
// matches the spec of:
// > The service is currently unavailable. This is most likely a transient condition that
// > can be corrected if retried with a backoff.
#[cfg(feature = "transport")]
if let Some(connect) = err.downcast_ref::<crate::transport::ConnectError>() {
#[cfg(feature = "channel")]
if let Some(connect) =
err.downcast_ref::<crate::transport::channel::service::ConnectError>()
{
return Some(Status::unavailable(connect.to_string()));
}

Expand Down
8 changes: 4 additions & 4 deletions tonic/src/transport/channel/endpoint.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,10 @@
use super::super::service;
#[cfg(feature = "tls")]
use super::service::TlsConnector;
use super::service::{self, Executor, SharedExec};
use super::Channel;
#[cfg(feature = "tls")]
use super::ClientTlsConfig;
#[cfg(feature = "tls")]
use crate::transport::service::TlsConnector;
use crate::transport::{service::SharedExec, Error, Executor};
use crate::transport::Error;
use bytes::Bytes;
use http::{uri::Uri, HeaderValue};
use hyper::rt;
Expand Down
4 changes: 2 additions & 2 deletions tonic/src/transport/channel/mod.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
//! Client implementation and builder.

mod endpoint;
pub(crate) mod service;
#[cfg(feature = "tls")]
#[cfg_attr(docsrs, doc(cfg(feature = "tls")))]
mod tls;
Expand All @@ -9,9 +10,8 @@ pub use endpoint::Endpoint;
#[cfg(feature = "tls")]
pub use tls::ClientTlsConfig;

use super::service::{Connection, DynamicServiceStream, SharedExec};
use self::service::{Connection, DynamicServiceStream, Executor, SharedExec};
use crate::body::BoxBody;
use crate::transport::Executor;
use bytes::Bytes;
use http::{
uri::{InvalidUri, Uri},
Expand Down
File renamed without changes.
Original file line number Diff line number Diff line change
@@ -1,8 +1,7 @@
use super::SharedExec;
use super::{grpc_timeout::GrpcTimeout, reconnect::Reconnect, AddOrigin, UserAgent};
use super::{AddOrigin, Reconnect, SharedExec, UserAgent};
use crate::{
body::{boxed, BoxBody},
transport::{BoxFuture, Endpoint},
transport::{service::GrpcTimeout, BoxFuture, Endpoint},
};
use http::Uri;
use hyper::rt;
Expand Down Expand Up @@ -36,7 +35,7 @@ impl Connection {
C::Future: Unpin + Send,
C::Response: rt::Read + rt::Write + Unpin + Send + 'static,
{
let mut settings: Builder<super::SharedExec> = Builder::new(endpoint.executor.clone())
let mut settings: Builder<SharedExec> = Builder::new(endpoint.executor.clone())
.initial_stream_window_size(endpoint.init_stream_window_size)
.initial_connection_window_size(endpoint.init_connection_window_size)
.keep_alive_interval(endpoint.http2_keep_alive_interval)
Expand Down Expand Up @@ -158,12 +157,12 @@ impl tower::Service<http::Request<BoxBody>> for SendRequest {

struct MakeSendRequestService<C> {
connector: C,
executor: super::SharedExec,
settings: Builder<super::SharedExec>,
executor: SharedExec,
settings: Builder<SharedExec>,
}

impl<C> MakeSendRequestService<C> {
fn new(connector: C, executor: SharedExec, settings: Builder<super::SharedExec>) -> Self {
fn new(connector: C, executor: SharedExec, settings: Builder<SharedExec>) -> Self {
Self {
connector,
executor,
Expand Down
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
use super::super::BoxFuture;
use super::io::BoxedIo;
use super::BoxedIo;
#[cfg(feature = "tls")]
use super::tls::TlsConnector;
use super::TlsConnector;
use crate::transport::BoxFuture;
use http::Uri;
use std::fmt;
use std::task::{Context, Poll};
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
use super::connection::Connection;
use crate::transport::Endpoint;
use super::super::{Connection, Endpoint};

use hyper_util::client::legacy::connect::HttpConnector;
use std::{
Expand Down
File renamed without changes.
67 changes: 67 additions & 0 deletions tonic/src/transport/channel/service/io.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,67 @@
use std::io::{self, IoSlice};
use std::pin::Pin;
use std::task::{Context, Poll};

use hyper::rt;
use hyper_util::client::legacy::connect::{Connected as HyperConnected, Connection};

pub(in crate::transport) trait Io:
rt::Read + rt::Write + Send + 'static
{
}

impl<T> Io for T where T: rt::Read + rt::Write + Send + 'static {}

pub(crate) struct BoxedIo(Pin<Box<dyn Io>>);

impl BoxedIo {
pub(in crate::transport) fn new<I: Io>(io: I) -> Self {
BoxedIo(Box::pin(io))
}
}

impl Connection for BoxedIo {
fn connected(&self) -> HyperConnected {
HyperConnected::new()
}
}

impl rt::Read for BoxedIo {
fn poll_read(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: rt::ReadBufCursor<'_>,
) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_read(cx, buf)
}
}

impl rt::Write for BoxedIo {
fn poll_write(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
buf: &[u8],
) -> Poll<io::Result<usize>> {
Pin::new(&mut self.0).poll_write(cx, buf)
}

fn poll_flush(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_flush(cx)
}

fn poll_shutdown(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<io::Result<()>> {
Pin::new(&mut self.0).poll_shutdown(cx)
}

fn poll_write_vectored(
mut self: Pin<&mut Self>,
cx: &mut Context<'_>,
bufs: &[IoSlice<'_>],
) -> Poll<Result<usize, io::Error>> {
Pin::new(&mut self.0).poll_write_vectored(cx, bufs)
}

fn is_write_vectored(&self) -> bool {
self.0.is_write_vectored()
}
}
28 changes: 28 additions & 0 deletions tonic/src/transport/channel/service/mod.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,28 @@
mod add_origin;
use self::add_origin::AddOrigin;

mod user_agent;
use self::user_agent::UserAgent;

mod reconnect;
use self::reconnect::Reconnect;

mod connection;
pub(super) use self::connection::Connection;

mod discover;
pub(super) use self::discover::DynamicServiceStream;

mod io;
use self::io::BoxedIo;

mod connector;
pub(crate) use self::connector::{ConnectError, Connector};

mod executor;
pub(super) use self::executor::{Executor, SharedExec};

#[cfg(feature = "tls")]
mod tls;
#[cfg(feature = "tls")]
pub(super) use self::tls::TlsConnector;
File renamed without changes.
83 changes: 83 additions & 0 deletions tonic/src/transport/channel/service/tls.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,83 @@
use std::fmt;
use std::io::Cursor;
use std::sync::Arc;

use hyper_util::rt::TokioIo;
use tokio::io::{AsyncRead, AsyncWrite};
use tokio_rustls::{
rustls::{pki_types::ServerName, ClientConfig, RootCertStore},
TlsConnector as RustlsConnector,
};

use super::io::BoxedIo;
use crate::transport::service::tls::{add_certs_from_pem, load_identity, TlsError, ALPN_H2};
use crate::transport::tls::{Certificate, Identity};

#[derive(Clone)]
pub(crate) struct TlsConnector {
config: Arc<ClientConfig>,
domain: Arc<ServerName<'static>>,
assume_http2: bool,
}

impl TlsConnector {
pub(crate) fn new(
ca_certs: Vec<Certificate>,
identity: Option<Identity>,
domain: &str,
assume_http2: bool,
) -> Result<Self, crate::Error> {
let builder = ClientConfig::builder();
let mut roots = RootCertStore::empty();

#[cfg(feature = "tls-roots")]
roots.add_parsable_certificates(rustls_native_certs::load_native_certs()?);

#[cfg(feature = "tls-webpki-roots")]
roots.extend(webpki_roots::TLS_SERVER_ROOTS.iter().cloned());

for cert in ca_certs {
add_certs_from_pem(&mut Cursor::new(cert), &mut roots)?;
}

let builder = builder.with_root_certificates(roots);
let mut config = match identity {
Some(identity) => {
let (client_cert, client_key) = load_identity(identity)?;
builder.with_client_auth_cert(client_cert, client_key)?
}
None => builder.with_no_client_auth(),
};

config.alpn_protocols.push(ALPN_H2.into());
Ok(Self {
config: Arc::new(config),
domain: Arc::new(ServerName::try_from(domain)?.to_owned()),
assume_http2,
})
}

pub(crate) async fn connect<I>(&self, io: I) -> Result<BoxedIo, crate::Error>
where
I: AsyncRead + AsyncWrite + Send + Unpin + 'static,
{
let io = RustlsConnector::from(self.config.clone())
.connect(self.domain.as_ref().to_owned(), io)
.await?;

// Generally we require ALPN to be negotiated, but if the user has
// explicitly set `assume_http2` to true, we'll allow it to be missing.
let (_, session) = io.get_ref();
let alpn_protocol = session.alpn_protocol();
if !(alpn_protocol == Some(ALPN_H2) || self.assume_http2) {
return Err(TlsError::H2NotNegotiated.into());
}
Ok(BoxedIo::new(TokioIo::new(io)))
}
}

impl fmt::Debug for TlsConnector {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TlsConnector").finish()
}
}
File renamed without changes.
2 changes: 1 addition & 1 deletion tonic/src/transport/channel/tls.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
use super::service::TlsConnector;
use crate::transport::{
service::TlsConnector,
tls::{Certificate, Identity},
Error,
};
Expand Down
Loading

0 comments on commit b947e1a

Please sign in to comment.