Skip to content

Commit

Permalink
feat: Add gRPC interceptors (#232)
Browse files Browse the repository at this point in the history
This change introduces proper gRPC interceptors that are avilable
regardless of the transport used. Each codegen service now produces an
additional method called `with_interceptor` that accepts a
`Interceptor`.

All examples have been updated to use this new style and interop has a
custom `tower::Service` middleware to echo the headers. There is also a
new `interceptor` example that shows basic usage.

BREAKING CHANGE: removed `interceptor_fn` and `intercep_headers_fn` from `transport` in favor of using `tonic::Interceptor`.
  • Loading branch information
LucioFranco committed Jan 14, 2020
1 parent 0fa2bf1 commit eba7ec7
Show file tree
Hide file tree
Showing 25 changed files with 456 additions and 273 deletions.
4 changes: 1 addition & 3 deletions Cargo.toml
Expand Up @@ -2,13 +2,11 @@
members = [
"tonic",
"tonic-build",

# Non-published crates
"examples",
"interop",

# Tests
"tests/included_service",
"tests/same_name",
"tests/wellknown",
]
]
19 changes: 11 additions & 8 deletions examples/Cargo.toml
Expand Up @@ -86,27 +86,30 @@ path = "src/uds/client.rs"
name = "uds-server"
path = "src/uds/server.rs"

[[bin]]
name = "interceptor-client"
path = "src/interceptor/client.rs"

[[bin]]
name = "interceptor-server"
path = "src/interceptor/server.rs"

[dependencies]
tonic = { path = "../tonic", features = ["tls"] }
prost = "0.6"

tokio = { version = "0.2", features = ["rt-threaded", "time", "stream", "fs", "macros", "uds"] }
futures = { version = "0.3", default-features = false, features = ["alloc"]}
futures = { version = "0.3", default-features = false, features = ["alloc"] }
async-stream = "0.2"
http = "0.2"
tower = "0.3"

tower = "0.3"
# Required for routeguide
serde = { version = "1.0", features = ["derive"] }
serde_json = "1.0"
rand = "0.7"

# Tracing
tracing = "0.1"
tracing-subscriber = { version = "0.2.0-alpha", features = ["tracing-log"] }
tracing-subscriber = { version = "0.2.0-alpha", features = ["tracing-log"] }
tracing-attributes = "0.1"
tracing-futures = "0.2"

# Required for wellknown types
prost-types = "0.6"

Expand Down
22 changes: 9 additions & 13 deletions examples/src/authentication/client.rs
Expand Up @@ -2,23 +2,19 @@ pub mod pb {
tonic::include_proto!("grpc.examples.echo");
}

use http::header::HeaderValue;
use pb::{echo_client::EchoClient, EchoRequest};
use tonic::transport::Channel;
use tonic::{metadata::MetadataValue, transport::Channel, Request};

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let channel = Channel::from_static("http://[::1]:50051")
.intercept_headers(|headers| {
headers.insert(
"authorization",
HeaderValue::from_static("Bearer some-secret-token"),
);
})
.connect()
.await?;

let mut client = EchoClient::new(channel);
let channel = Channel::from_static("http://[::1]:50051").connect().await?;

let token = MetadataValue::from_str("Bearer some-auth-token")?;

let mut client = EchoClient::with_interceptor(channel, move |mut req: Request<()>| {
req.metadata_mut().insert("authorization", token.clone());
Ok(req)
});

let request = tonic::Request::new(EchoRequest {
message: "hello".into(),
Expand Down
41 changes: 11 additions & 30 deletions examples/src/authentication/server.rs
Expand Up @@ -5,8 +5,7 @@ pub mod pb {
use futures::Stream;
use pb::{EchoRequest, EchoResponse};
use std::pin::Pin;
use tonic::{body::BoxBody, transport::Server, Request, Response, Status, Streaming};
use tower::Service;
use tonic::{metadata::MetadataValue, transport::Server, Request, Response, Status, Streaming};

type EchoResult<T> = Result<Response<T>, Status>;
type ResponseStream = Pin<Box<dyn Stream<Item = Result<EchoResponse, Status>> + Send + Sync>>;
Expand Down Expand Up @@ -52,36 +51,18 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr = "[::1]:50051".parse().unwrap();
let server = EchoServer::default();

Server::builder()
.interceptor_fn(move |svc, req| {
let auth_header = req.headers().get("authorization").clone();
let svc = pb::echo_server::EchoServer::with_interceptor(server, check_auth);

let authed = if let Some(auth_header) = auth_header {
auth_header == "Bearer some-secret-token"
} else {
false
};
Server::builder().add_service(svc).serve(addr).await?;

let fut = svc.call(req);
Ok(())
}

async move {
if authed {
fut.await
} else {
// Cancel the inner future since we never await it
// the IO never gets registered.
drop(fut);
let res = http::Response::builder()
.header("grpc-status", "16")
.body(BoxBody::empty())
.unwrap();
Ok(res)
}
}
})
.add_service(pb::echo_server::EchoServer::new(server))
.serve(addr)
.await?;
fn check_auth(req: Request<()>) -> Result<Request<()>, Status> {
let token = MetadataValue::from_str("Bearer some-secret-token").unwrap();

Ok(())
match req.metadata().get("authorization") {
Some(t) if token == t => Ok(req),
_ => Err(Status::unauthenticated("No valid auth token")),
}
}
13 changes: 7 additions & 6 deletions examples/src/gcp/client.rs
Expand Up @@ -3,8 +3,8 @@ pub mod api {
}

use api::{publisher_client::PublisherClient, ListTopicsRequest};
use http::header::HeaderValue;
use tonic::{
metadata::MetadataValue,
transport::{Certificate, Channel, ClientTlsConfig},
Request,
};
Expand All @@ -23,7 +23,7 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.ok_or("Expected a project name as the first argument.".to_string())?;

let bearer_token = format!("Bearer {}", token);
let header_value = HeaderValue::from_str(&bearer_token)?;
let header_value = MetadataValue::from_str(&bearer_token)?;

let certs = tokio::fs::read("examples/data/gcp/roots.pem").await?;

Expand All @@ -32,14 +32,15 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
.domain_name("pubsub.googleapis.com");

let channel = Channel::from_static(ENDPOINT)
.intercept_headers(move |headers| {
headers.insert("authorization", header_value.clone());
})
.tls_config(tls_config)
.connect()
.await?;

let mut service = PublisherClient::new(channel);
let mut service = PublisherClient::with_interceptor(channel, move |mut req: Request<()>| {
req.metadata_mut()
.insert("authorization", header_value.clone());
Ok(req)
});

let response = service
.list_topics(Request::new(ListTopicsRequest {
Expand Down
34 changes: 34 additions & 0 deletions examples/src/interceptor/client.rs
@@ -0,0 +1,34 @@
use hello_world::greeter_client::GreeterClient;
use hello_world::HelloRequest;
use tonic::{transport::Endpoint, Request, Status};

pub mod hello_world {
tonic::include_proto!("helloworld");
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let channel = Endpoint::from_static("http://[::1]:50051")
.connect()
.await?;

let mut client = GreeterClient::with_interceptor(channel, intercept);

let request = tonic::Request::new(HelloRequest {
name: "Tonic".into(),
});

let response = client.say_hello(request).await?;

println!("RESPONSE={:?}", response);

Ok(())
}

/// This function will get called on each outbound request. Returning a
/// `Status` here will cancel the request and have that status returned to
/// the caller.
fn intercept(req: Request<()>) -> Result<Request<()>, Status> {
println!("Intercepting request: {:?}", req);
Ok(req)
}
46 changes: 46 additions & 0 deletions examples/src/interceptor/server.rs
@@ -0,0 +1,46 @@
use tonic::{transport::Server, Request, Response, Status};

use hello_world::greeter_server::{Greeter, GreeterServer};
use hello_world::{HelloReply, HelloRequest};

pub mod hello_world {
tonic::include_proto!("helloworld");
}

#[derive(Default)]
pub struct MyGreeter {}

#[tonic::async_trait]
impl Greeter for MyGreeter {
async fn say_hello(
&self,
request: Request<HelloRequest>,
) -> Result<Response<HelloReply>, Status> {
let reply = hello_world::HelloReply {
message: format!("Hello {}!", request.into_inner().name),
};
Ok(Response::new(reply))
}
}

#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
let addr = "[::1]:50051".parse().unwrap();
let greeter = MyGreeter::default();

let svc = GreeterServer::with_interceptor(greeter, intercept);

println!("GreeterServer listening on {}", addr);

Server::builder().add_service(svc).serve(addr).await?;

Ok(())
}

/// This function will get called on each inbound request, if a `Status`
/// is returned, it will cancel the request and return that status to the
/// client.
fn intercept(req: Request<()>) -> Result<Request<()>, Status> {
println!("Intercepting request: {:?}", req);
Ok(req)
}
3 changes: 1 addition & 2 deletions examples/src/uds/client.rs
Expand Up @@ -5,11 +5,10 @@ pub mod hello_world {
}

use hello_world::{greeter_client::GreeterClient, HelloRequest};
use http::Uri;
use std::convert::TryFrom;
#[cfg(unix)]
use tokio::net::UnixStream;
use tonic::transport::Endpoint;
use tonic::transport::{Endpoint, Uri};
use tower::service_fn;

#[cfg(unix)]
Expand Down
3 changes: 1 addition & 2 deletions interop/Cargo.toml
Expand Up @@ -26,10 +26,9 @@ futures-util = "0.3"
async-stream = "0.2"
tower = "0.3"
http-body = "0.3"

hyper = "0.13"
console = "0.9"
structopt = "0.3"

tracing = "0.1"
tracing-subscriber = "0.2.0-alpha"
tracing-log = "0.1.0"
Expand Down
38 changes: 6 additions & 32 deletions interop/src/bin/server.rs
@@ -1,10 +1,7 @@
use http::header::HeaderName;
use structopt::StructOpt;
use tonic::body::BoxBody;
use tonic::client::GrpcService;
use tonic::transport::Server;
use tonic::transport::{Identity, ServerTlsConfig};
use tonic_interop::{server, MergeTrailers};
use tonic_interop::server;

#[derive(StructOpt)]
struct Opts {
Expand All @@ -20,33 +17,7 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {

let addr = "127.0.0.1:10000".parse().unwrap();

let mut builder = Server::builder().interceptor_fn(|svc, req| {
let echo_header = req
.headers()
.get("x-grpc-test-echo-initial")
.map(Clone::clone);

let echo_trailer = req
.headers()
.get("x-grpc-test-echo-trailing-bin")
.map(Clone::clone)
.map(|v| (HeaderName::from_static("x-grpc-test-echo-trailing-bin"), v));

let call = svc.call(req);

async move {
let mut res = call.await?;

if let Some(echo_header) = echo_header {
res.headers_mut()
.insert("x-grpc-test-echo-initial", echo_header);
}

Ok(res
.map(|b| MergeTrailers::new(b, echo_trailer))
.map(BoxBody::new))
}
});
let mut builder = Server::builder();

if matches.use_tls {
let cert = tokio::fs::read("interop/data/server1.pem").await?;
Expand All @@ -60,8 +31,11 @@ async fn main() -> std::result::Result<(), Box<dyn std::error::Error>> {
let unimplemented_service =
server::UnimplementedServiceServer::new(server::UnimplementedService::default());

// Wrap this test_service with a service that will echo headers as trailers.
let test_service_svc = server::EchoHeadersSvc::new(test_service);

builder
.add_service(test_service)
.add_service(test_service_svc)
.add_service(unimplemented_service)
.serve(addr)
.await?;
Expand Down

0 comments on commit eba7ec7

Please sign in to comment.