Skip to content

Commit

Permalink
feat(tonic): add Request and Response extensions (#642)
Browse files Browse the repository at this point in the history
Adds `tonic::Extensions` which is a newtype around `http::Extensions`.

Request extensions can be set by interceptors with
`Request::extensions_mut` and retrieved from RPCs with
`Request::extensions`. Extensions can also be set in tower middleware
and will be carried through to the RPC.

Since response extensions cannot be set by interceptors the main use
case is to set them in RPCs and retrieve them in tower middlewares.
Figured that might be useful.

Fixes #255
  • Loading branch information
davidpdrsn committed May 13, 2021
1 parent 74ad0a9 commit 352b0f5
Show file tree
Hide file tree
Showing 8 changed files with 321 additions and 13 deletions.
15 changes: 14 additions & 1 deletion examples/src/interceptor/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ impl Greeter for MyGreeter {
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
let extension = request.extensions().get::<MyExtension>().unwrap();
println!("extension data = {}", extension.some_piece_of_data);

let reply = hello_world::HelloReply {
message: format!("Hello {}!", request.into_inner().name),
};
Expand All @@ -40,7 +43,17 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
/// This function will get called on each inbound request, if a `Status`
/// is returned, it will cancel the request and return that status to the
/// client.
fn intercept(req: Request<()>) -> Result<Request<()>, Status> {
fn intercept(mut req: Request<()>) -> Result<Request<()>, Status> {
println!("Intercepting request: {:?}", req);

// Set an extension that can be retrieved by `say_hello`
req.extensions_mut().insert(MyExtension {
some_piece_of_data: "foo".to_string(),
});

Ok(req)
}

struct MyExtension {
some_piece_of_data: String,
}
3 changes: 3 additions & 0 deletions tests/integration_tests/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,9 @@ bytes = "1.0"
[dev-dependencies]
tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] }
tokio-stream = { version = "0.1.5", features = ["net"] }
tower-service = "0.3"
hyper = "0.14"
futures = "0.3"

[build-dependencies]
tonic-build = { path = "../../tonic-build" }
144 changes: 144 additions & 0 deletions tests/integration_tests/tests/extensions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,144 @@
use futures_util::FutureExt;
use hyper::{Body, Request as HyperRequest, Response as HyperResponse};
use integration_tests::pb::{test_client, test_server, Input, Output};
use std::{
task::{Context, Poll},
time::Duration,
};
use tokio::sync::oneshot;
use tonic::{
body::BoxBody,
transport::{Endpoint, NamedService, Server},
Request, Response, Status,
};
use tower_service::Service;

struct ExtensionValue(i32);

#[tokio::test]
async fn setting_extension_from_interceptor() {
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
let value = req.extensions().get::<ExtensionValue>().unwrap();
assert_eq!(value.0, 42);

Ok(Response::new(Output {}))
}
}

let svc = test_server::TestServer::with_interceptor(Svc, |mut req: Request<()>| {
req.extensions_mut().insert(ExtensionValue(42));
Ok(req)
});

let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_shutdown("127.0.0.1:1323".parse().unwrap(), rx.map(drop))
.await
.unwrap();
});

tokio::time::sleep(Duration::from_millis(100)).await;

let channel = Endpoint::from_static("http://127.0.0.1:1323")
.connect()
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

client.unary_call(Input {}).await.unwrap();

tx.send(()).unwrap();

jh.await.unwrap();
}

#[tokio::test]
async fn setting_extension_from_tower() {
struct Svc;

#[tonic::async_trait]
impl test_server::Test for Svc {
async fn unary_call(&self, req: Request<Input>) -> Result<Response<Output>, Status> {
let value = req.extensions().get::<ExtensionValue>().unwrap();
assert_eq!(value.0, 42);

Ok(Response::new(Output {}))
}
}

let svc = InterceptedService {
inner: test_server::TestServer::new(Svc),
};

let (tx, rx) = oneshot::channel::<()>();

let jh = tokio::spawn(async move {
Server::builder()
.add_service(svc)
.serve_with_shutdown("127.0.0.1:1324".parse().unwrap(), rx.map(drop))
.await
.unwrap();
});

tokio::time::sleep(Duration::from_millis(100)).await;

let channel = Endpoint::from_static("http://127.0.0.1:1324")
.connect()
.await
.unwrap();

let mut client = test_client::TestClient::new(channel);

client.unary_call(Input {}).await.unwrap();

tx.send(()).unwrap();

jh.await.unwrap();
}

#[derive(Debug, Clone)]
struct InterceptedService<S> {
inner: S,
}

impl<S> Service<HyperRequest<Body>> for InterceptedService<S>
where
S: Service<HyperRequest<Body>, Response = HyperResponse<BoxBody>>
+ NamedService
+ Clone
+ Send
+ 'static,
S::Future: Send + 'static,
{
type Response = S::Response;
type Error = S::Error;
type Future = futures::future::BoxFuture<'static, Result<Self::Response, Self::Error>>;

fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
self.inner.poll_ready(cx)
}

fn call(&mut self, mut req: HyperRequest<Body>) -> Self::Future {
let clone = self.inner.clone();
let mut inner = std::mem::replace(&mut self.inner, clone);

req.extensions_mut().insert(ExtensionValue(42));

Box::pin(async move {
let response = inner.call(req).await?;
Ok(response)
})
}
}

impl<S: NamedService> NamedService for InterceptedService<S> {
const NAME: &'static str = S::NAME;
}
5 changes: 3 additions & 2 deletions tonic/src/client/grpc.rs
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,8 @@ impl<T> Grpc<T> {
M1: Send + Sync + 'static,
M2: Send + Sync + 'static,
{
let (mut parts, body) = self.streaming(request, path, codec).await?.into_parts();
let (mut parts, body, extensions) =
self.streaming(request, path, codec).await?.into_parts();

futures_util::pin_mut!(body);

Expand All @@ -114,7 +115,7 @@ impl<T> Grpc<T> {
parts.merge(trailers);
}

Ok(Response::from_parts(parts, message))
Ok(Response::from_parts(parts, message, extensions))
}

/// Send a server side streaming gRPC request.
Expand Down
71 changes: 71 additions & 0 deletions tonic/src/extensions.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
use std::fmt;

/// A type map of protocol extensions.
///
/// `Extensions` can be used by [`Interceptor`] and [`Request`] to store extra data derived from
/// the underlying protocol.
///
/// [`Interceptor`]: crate::Interceptor
/// [`Request`]: crate::Request
pub struct Extensions {
inner: http::Extensions,
}

impl Extensions {
pub(crate) fn new() -> Self {
Self {
inner: http::Extensions::new(),
}
}

/// Insert a type into this `Extensions`.
///
/// If a extension of this type already existed, it will
/// be returned.
#[inline]
pub fn insert<T: Send + Sync + 'static>(&mut self, val: T) -> Option<T> {
self.inner.insert(val)
}

/// Get a reference to a type previously inserted on this `Extensions`.
#[inline]
pub fn get<T: Send + Sync + 'static>(&self) -> Option<&T> {
self.inner.get()
}

/// Get a mutable reference to a type previously inserted on this `Extensions`.
#[inline]
pub fn get_mut<T: Send + Sync + 'static>(&mut self) -> Option<&mut T> {
self.inner.get_mut()
}

/// Remove a type from this `Extensions`.
///
/// If a extension of this type existed, it will be returned.
#[inline]
pub fn remove<T: Send + Sync + 'static>(&mut self) -> Option<T> {
self.inner.remove()
}

/// Clear the `Extensions` of all inserted extensions.
#[inline]
pub fn clear(&mut self) {
self.inner.clear()
}

#[inline]
pub(crate) fn from_http(http: http::Extensions) -> Self {
Self { inner: http }
}

#[inline]
pub(crate) fn into_http(self) -> http::Extensions {
self.inner
}
}

impl fmt::Debug for Extensions {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("Extensions").finish()
}
}
3 changes: 3 additions & 0 deletions tonic/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@
//! [`transport`]: transport/index.html

#![recursion_limit = "256"]
#![allow(clippy::inconsistent_struct_constructor)]
#![warn(
missing_debug_implementations,
missing_docs,
Expand All @@ -87,6 +88,7 @@ pub mod server;
#[cfg_attr(docsrs, doc(cfg(feature = "transport")))]
pub mod transport;

mod extensions;
mod interceptor;
mod macros;
mod request;
Expand All @@ -100,6 +102,7 @@ pub use async_trait::async_trait;

#[doc(inline)]
pub use codec::Streaming;
pub use extensions::Extensions;
pub use interceptor::Interceptor;
pub use request::{IntoRequest, IntoStreamingRequest, Request};
pub use response::Response;
Expand Down
Loading

0 comments on commit 352b0f5

Please sign in to comment.