From 4989ac50afd30967632f203710a46dd203c3bd30 Mon Sep 17 00:00:00 2001 From: Lucio Franco Date: Thu, 24 Aug 2023 14:30:11 -0400 Subject: [PATCH] feat(web): Add `GrpcWebClientService` This adds `grpc-web` support for clients in `tonic-web`. This is done by reusing the server side encoding/decoding but wrapping it in different directions. --- examples/src/grpc-web/client.rs | 72 ++-------- tonic-web/src/call.rs | 237 ++++++++++++++++++++++++++++++-- tonic-web/src/client.rs | 114 +++++++++++++++ tonic-web/src/lib.rs | 3 + 4 files changed, 356 insertions(+), 70 deletions(-) create mode 100644 tonic-web/src/client.rs diff --git a/examples/src/grpc-web/client.rs b/examples/src/grpc-web/client.rs index 6bfd41383..a16ac674a 100644 --- a/examples/src/grpc-web/client.rs +++ b/examples/src/grpc-web/client.rs @@ -1,6 +1,5 @@ -use bytes::{Buf, BufMut, Bytes, BytesMut}; -use hello_world::{HelloReply, HelloRequest}; -use http::header::{ACCEPT, CONTENT_TYPE}; +use hello_world::{greeter_client::GreeterClient, HelloRequest}; +use tonic_web::GrpcWebClientLayer; pub mod hello_world { tonic::include_proto!("helloworld"); @@ -8,67 +7,22 @@ pub mod hello_world { #[tokio::main] async fn main() -> Result<(), Box> { - let msg = HelloRequest { - name: "Bob".to_string(), - }; + // Must use hyper directly... + let client = hyper::Client::builder().build_http(); - // a good old http/1.1 request - let request = http::Request::builder() - .version(http::Version::HTTP_11) - .method(http::Method::POST) - .uri("http://127.0.0.1:3000/helloworld.Greeter/SayHello") - .header(CONTENT_TYPE, "application/grpc-web") - .header(ACCEPT, "application/grpc-web") - .body(hyper::Body::from(encode_body(msg))) - .unwrap(); + let svc = tower::ServiceBuilder::new() + .layer(GrpcWebClientLayer::new()) + .service(client); - let client = hyper::Client::new(); + let mut client = GreeterClient::with_origin(svc, "http://127.0.0.1:3000".try_into()?); - let response = client.request(request).await.unwrap(); + let request = tonic::Request::new(HelloRequest { + name: "Tonic".into(), + }); - assert_eq!( - response.headers().get(CONTENT_TYPE).unwrap(), - "application/grpc-web+proto" - ); + let response = client.say_hello(request).await?; - let body = response.into_body(); - let reply = decode_body::(body).await; - - println!("REPLY={:?}", reply); + println!("RESPONSE={:?}", response); Ok(()) } - -// one byte for the compression flag plus four bytes for the length -const GRPC_HEADER_SIZE: usize = 5; - -fn encode_body(msg: T) -> Bytes -where - T: prost::Message, -{ - let msg_len = msg.encoded_len(); - let mut buf = BytesMut::with_capacity(GRPC_HEADER_SIZE + msg_len); - - // compression flag, 0 means "no compression" - buf.put_u8(0); - buf.put_u32(msg_len as u32); - - msg.encode(&mut buf).unwrap(); - buf.freeze() -} - -async fn decode_body(body: hyper::Body) -> T -where - T: Default + prost::Message, -{ - let mut body = hyper::body::to_bytes(body).await.unwrap(); - - // ignore the compression flag - body.advance(1); - - let len = body.get_u32(); - #[allow(clippy::let_and_return)] - let msg = T::decode(&mut body.split_to(len as usize)).unwrap(); - - msg -} diff --git a/tonic-web/src/call.rs b/tonic-web/src/call.rs index eb135cbc7..ad6474ed2 100644 --- a/tonic-web/src/call.rs +++ b/tonic-web/src/call.rs @@ -5,7 +5,7 @@ use std::task::{Context, Poll}; use base64::Engine as _; use bytes::{Buf, BufMut, Bytes, BytesMut}; use futures_core::ready; -use http::{header, HeaderMap, HeaderValue}; +use http::{header, HeaderMap, HeaderName, HeaderValue}; use http_body::{Body, SizeHint}; use pin_project::pin_project; use tokio_stream::Stream; @@ -13,6 +13,9 @@ use tonic::Status; use self::content_types::*; +// A grpc header is u8 (flag) + u32 (msg len) +const GRPC_HEADER_SIZE: usize = 1 + 4; + pub(crate) mod content_types { use http::{header::CONTENT_TYPE, HeaderMap}; @@ -43,8 +46,9 @@ const GRPC_WEB_TRAILERS_BIT: u8 = 0b10000000; #[derive(Copy, Clone, PartialEq, Debug)] enum Direction { - Request, - Response, + Decode, + Encode, + Empty, } #[derive(Copy, Clone, PartialEq, Debug)] @@ -53,35 +57,78 @@ pub(crate) enum Encoding { None, } +/// HttpBody adapter for the grpc web based services. +#[derive(Debug)] #[pin_project] -pub(crate) struct GrpcWebCall { +pub struct GrpcWebCall { #[pin] inner: B, buf: BytesMut, direction: Direction, encoding: Encoding, poll_trailers: bool, + client: bool, + trailers: Option, +} + +impl Default for GrpcWebCall { + fn default() -> Self { + Self { + inner: Default::default(), + buf: Default::default(), + direction: Direction::Empty, + encoding: Encoding::None, + poll_trailers: Default::default(), + client: Default::default(), + trailers: Default::default(), + } + } } impl GrpcWebCall { pub(crate) fn request(inner: B, encoding: Encoding) -> Self { - Self::new(inner, Direction::Request, encoding) + Self::new(inner, Direction::Decode, encoding) } pub(crate) fn response(inner: B, encoding: Encoding) -> Self { - Self::new(inner, Direction::Response, encoding) + Self::new(inner, Direction::Encode, encoding) + } + + pub(crate) fn client_request(inner: B) -> Self { + Self::new_client(inner, Direction::Encode, Encoding::None) + } + + pub(crate) fn client_response(inner: B) -> Self { + Self::new_client(inner, Direction::Decode, Encoding::None) + } + + fn new_client(inner: B, direction: Direction, encoding: Encoding) -> Self { + GrpcWebCall { + inner, + buf: BytesMut::with_capacity(match (direction, encoding) { + (Direction::Encode, Encoding::Base64) => BUFFER_SIZE, + _ => 0, + }), + direction, + encoding, + poll_trailers: true, + client: true, + trailers: None, + } } fn new(inner: B, direction: Direction, encoding: Encoding) -> Self { GrpcWebCall { inner, buf: BytesMut::with_capacity(match (direction, encoding) { - (Direction::Response, Encoding::Base64) => BUFFER_SIZE, + (Direction::Encode, Encoding::Base64) => BUFFER_SIZE, _ => 0, }), direction, encoding, poll_trailers: true, + client: false, + trailers: None, } } @@ -192,12 +239,43 @@ where type Error = Status; fn poll_data( - self: Pin<&mut Self>, + mut self: Pin<&mut Self>, cx: &mut Context<'_>, ) -> Poll>> { + if self.client && self.direction == Direction::Decode { + let buf = ready!(self.as_mut().poll_decode(cx)); + + return if let Some(Ok(mut buf)) = buf { + // We found some trailers so extract them since we + // want to return them via `poll_trailers`. + if let Some(len) = find_trailers(&buf[..]) { + // Extract up to len of where the trailers are at + let msg_buf = buf.copy_to_bytes(len); + match decode_trailers_frame(buf) { + Ok(Some(trailers)) => { + self.project().trailers.replace(trailers); + } + Err(e) => return Poll::Ready(Some(Err(e))), + _ => {} + } + + if msg_buf.has_remaining() { + return Poll::Ready(Some(Ok(msg_buf))); + } else { + return Poll::Ready(None); + } + } + + Poll::Ready(Some(Ok(buf))) + } else { + Poll::Ready(buf) + }; + } + match self.direction { - Direction::Request => self.poll_decode(cx), - Direction::Response => self.poll_encode(cx), + Direction::Decode => self.poll_decode(cx), + Direction::Encode => self.poll_encode(cx), + Direction::Empty => Poll::Ready(None), } } @@ -205,7 +283,8 @@ where self: Pin<&mut Self>, _: &mut Context<'_>, ) -> Poll>, Self::Error>> { - Poll::Ready(Ok(None)) + let trailers = self.project().trailers.take(); + Poll::Ready(Ok(trailers)) } fn is_end_stream(&self) -> bool { @@ -268,6 +347,56 @@ fn encode_trailers(trailers: HeaderMap) -> Vec { }) } +fn decode_trailers_frame(mut buf: Bytes) -> Result, Status> { + if buf.remaining() < GRPC_HEADER_SIZE { + return Ok(None); + } + + buf.get_u8(); + buf.get_u32(); + + let mut map = HeaderMap::new(); + let mut temp_buf = buf.clone(); + + let mut trailers = Vec::new(); + let mut cursor_pos = 0; + + for (i, b) in buf.iter().enumerate() { + if b == &b'\r' && buf.get(i + 1) == Some(&b'\n') { + let trailer = temp_buf.copy_to_bytes(i - cursor_pos); + cursor_pos = i; + trailers.push(trailer); + if temp_buf.has_remaining() { + temp_buf.get_u8(); + temp_buf.get_u8(); + } + } + } + + for trailer in trailers { + let mut s = trailer.split(|b| b == &b':'); + let key = s + .next() + .ok_or_else(|| Status::internal("trailers couldn't parse key"))?; + let value = s + .next() + .ok_or_else(|| Status::internal("trailers couldn't parse value"))?; + + let value = value + .split(|b| b == &b'\r') + .next() + .ok_or_else(|| Status::internal("trailers was not escaped"))?; + + let header_key = HeaderName::try_from(key) + .map_err(|e| Status::internal(format!("Unable to parse HeaderName: {}", e)))?; + let header_value = HeaderValue::try_from(value) + .map_err(|e| Status::internal(format!("Unable to parse HeaderValue: {}", e)))?; + map.insert(header_key, header_value); + } + + Ok(Some(map)) +} + fn make_trailers_frame(trailers: HeaderMap) -> Vec { let trailers = encode_trailers(trailers); let len = trailers.len(); @@ -281,6 +410,41 @@ fn make_trailers_frame(trailers: HeaderMap) -> Vec { frame } +/// Search some buffer for grpc-web trailers headers and return +/// its location in the original buf. If `None` is returned we did +/// not find a trailers in this buffer either because its incomplete +/// or the buffer jsut contained grpc message frames. +fn find_trailers(buf: &[u8]) -> Option { + let mut len = 0; + let mut temp_buf = &buf[..]; + + loop { + // To check each frame, there must be at least GRPC_HEADER_SIZE + // amount of bytes available otherwise the buffer is incomplete. + if temp_buf.is_empty() || temp_buf.len() < GRPC_HEADER_SIZE { + return None; + } + + let header = temp_buf.get_u8(); + + if header == GRPC_WEB_TRAILERS_BIT { + return Some(len); + } + + let msg_len = temp_buf.get_u32(); + + len += msg_len as usize + 4 + 1; + + // If the msg len of a non-grpc-web trailer frame is larger than + // the overall buffer we know within that buffer there are no trailers. + if len > buf.len() { + return None; + } + + temp_buf = &buf[len as usize..]; + } +} + #[cfg(test)] mod tests { use super::*; @@ -305,4 +469,55 @@ mod tests { assert_eq!(Encoding::from_accept(&headers), case.1, "{}", case.0); } } + + #[test] + fn decode_trailers() { + let mut headers = HeaderMap::new(); + headers.insert("grpc-status", 0.try_into().unwrap()); + headers.insert("grpc-message", "this is a message".try_into().unwrap()); + + let trailers = make_trailers_frame(headers.clone()); + + let buf = Bytes::from(trailers); + + let map = decode_trailers_frame(buf).unwrap().unwrap(); + + assert_eq!(headers, map); + } + + #[test] + fn find_trailers_non_buffered() { + // Byte version of this: + // b"\x80\0\0\0\x0fgrpc-status:0\r\n" + let buf = vec![ + 128, 0, 0, 0, 15, 103, 114, 112, 99, 45, 115, 116, 97, 116, 117, 115, 58, 48, 13, 10, + ]; + + let out = find_trailers(&buf[..]); + + assert_eq!(out, Some(0)); + } + + #[test] + fn find_trailers_buffered() { + // Byte version of this: + // b"\0\0\0\0L\n$975738af-1a17-4aea-b887-ed0bbced6093\x1a$da609e9b-f470-4cc0-a691-3fd6a005a436\x80\0\0\0\x0fgrpc-status:0\r\n" + let buf = vec![ + 0, 0, 0, 0, 76, 10, 36, 57, 55, 53, 55, 51, 56, 97, 102, 45, 49, 97, 49, 55, 45, 52, + 97, 101, 97, 45, 98, 56, 56, 55, 45, 101, 100, 48, 98, 98, 99, 101, 100, 54, 48, 57, + 51, 26, 36, 100, 97, 54, 48, 57, 101, 57, 98, 45, 102, 52, 55, 48, 45, 52, 99, 99, 48, + 45, 97, 54, 57, 49, 45, 51, 102, 100, 54, 97, 48, 48, 53, 97, 52, 51, 54, 128, 0, 0, 0, + 15, 103, 114, 112, 99, 45, 115, 116, 97, 116, 117, 115, 58, 48, 13, 10, + ]; + + let out = find_trailers(&buf[..]); + + assert_eq!(out, Some(81)); + + let trailers = decode_trailers_frame(Bytes::copy_from_slice(&buf[81..])) + .unwrap() + .unwrap(); + let status = trailers.get("grpc-status").unwrap(); + assert_eq!(status.to_str().unwrap(), "0") + } } diff --git a/tonic-web/src/client.rs b/tonic-web/src/client.rs new file mode 100644 index 000000000..f3fe8d993 --- /dev/null +++ b/tonic-web/src/client.rs @@ -0,0 +1,114 @@ +use bytes::Bytes; +use futures_core::ready; +use http::header::CONTENT_TYPE; +use http::{Request, Response, Version}; +use http_body::Body; +use pin_project::pin_project; +use std::error::Error; +use std::future::Future; +use std::pin::Pin; +use std::task::{Context, Poll}; +use tower_layer::Layer; +use tower_service::Service; +use tracing::debug; + +use crate::call::content_types::GRPC_WEB; +use crate::call::GrpcWebCall; + +/// Layer implementing the grpc-web protocol for clients. +#[derive(Debug, Clone)] +pub struct GrpcWebClientLayer { + _priv: (), +} + +impl GrpcWebClientLayer { + /// Create a new grpc-web for clients layer. + pub fn new() -> GrpcWebClientLayer { + Self { _priv: () } + } +} + +impl Default for GrpcWebClientLayer { + fn default() -> Self { + Self::new() + } +} + +impl Layer for GrpcWebClientLayer { + type Service = GrpcWebClientService; + + fn layer(&self, inner: S) -> Self::Service { + GrpcWebClientService::new(inner) + } +} + +/// A [`Service`] that wraps some inner http service that will +/// coerce requests coming from [`tonic::client::Grpc`] into proper +/// `grpc-web` requests. +#[derive(Debug, Clone)] +pub struct GrpcWebClientService { + inner: S, +} + +impl GrpcWebClientService { + /// Create a new grpc-web for clients service. + pub fn new(inner: S) -> Self { + Self { inner } + } +} + +impl Service> for GrpcWebClientService +where + S: Service>, Response = Response>, + B1: Body, + B2: Body, + B2::Error: Error, +{ + type Response = Response>; + type Error = S::Error; + type Future = ResponseFuture; + + fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll> { + self.inner.poll_ready(cx) + } + + fn call(&mut self, mut req: Request) -> Self::Future { + if req.version() == Version::HTTP_2 { + debug!("coercing HTTP2 request to HTTP1.1"); + + *req.version_mut() = Version::HTTP_11; + } + + req.headers_mut() + .insert(CONTENT_TYPE, GRPC_WEB.try_into().unwrap()); + + let req = req.map(GrpcWebCall::client_request); + + let fut = self.inner.call(req); + + ResponseFuture { inner: fut } + } +} + +/// Response future for the [`GrpcWebService`]. +#[allow(missing_debug_implementations)] +#[pin_project] +#[must_use = "futures do nothing unless polled"] +pub struct ResponseFuture { + #[pin] + inner: F, +} + +impl Future for ResponseFuture +where + B: Body, + F: Future, E>>, +{ + type Output = Result>, E>; + + fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll { + let res = ready!(self.project().inner.poll(cx)); + + Poll::Ready(res.map(|r| r.map(GrpcWebCall::client_response))) + } +} diff --git a/tonic-web/src/lib.rs b/tonic-web/src/lib.rs index 7e96f80ca..4942faba2 100644 --- a/tonic-web/src/lib.rs +++ b/tonic-web/src/lib.rs @@ -97,10 +97,13 @@ #![doc(html_root_url = "https://docs.rs/tonic-web/0.9.2")] #![doc(issue_tracker_base_url = "https://github.com/hyperium/tonic/issues/")] +pub use call::GrpcWebCall; +pub use client::{GrpcWebClientLayer, GrpcWebClientService}; pub use layer::GrpcWebLayer; pub use service::{GrpcWebService, ResponseFuture}; mod call; +mod client; mod layer; mod service;