diff --git a/Cargo.toml b/Cargo.toml index 4f3f82a7e..6ebf17872 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -24,6 +24,7 @@ members = [ "tonic-web/tests/integration", "tests/service_named_result", "tests/use_arc_self", + "tests/default_stubs", ] resolver = "2" diff --git a/tests/default_stubs/Cargo.toml b/tests/default_stubs/Cargo.toml new file mode 100644 index 000000000..c0989eefe --- /dev/null +++ b/tests/default_stubs/Cargo.toml @@ -0,0 +1,20 @@ +[package] +authors = ["Jordan Singh "] +edition = "2021" +license = "MIT" +name = "default_stubs" +publish = false +version = "0.1.0" + +[dependencies] +futures = "0.3" +tokio = {version = "1.0", features = ["macros", "rt-multi-thread", "net"]} +tokio-stream = {version = "0.1", features = ["net"]} +prost = "0.11" +tonic = {path = "../../tonic"} + +[build-dependencies] +tonic-build = {path = "../../tonic-build" } + +[package.metadata.cargo-machete] +ignored = ["prost"] diff --git a/tests/default_stubs/build.rs b/tests/default_stubs/build.rs new file mode 100644 index 000000000..b38b670e7 --- /dev/null +++ b/tests/default_stubs/build.rs @@ -0,0 +1,9 @@ +fn main() { + tonic_build::configure() + .compile(&["proto/test.proto"], &["proto"]) + .unwrap(); + tonic_build::configure() + .generate_default_stubs(true) + .compile(&["proto/test_default.proto"], &["proto"]) + .unwrap(); +} diff --git a/tests/default_stubs/proto/test.proto b/tests/default_stubs/proto/test.proto new file mode 100644 index 000000000..3fd6787a7 --- /dev/null +++ b/tests/default_stubs/proto/test.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +package test; + +import "google/protobuf/empty.proto"; + +service Test { + rpc Unary(google.protobuf.Empty) returns (google.protobuf.Empty); + rpc ServerStream(google.protobuf.Empty) returns (stream google.protobuf.Empty); + rpc ClientStream(stream google.protobuf.Empty) returns (google.protobuf.Empty); + rpc BidirectionalStream(stream google.protobuf.Empty) returns (stream google.protobuf.Empty); +} \ No newline at end of file diff --git a/tests/default_stubs/proto/test_default.proto b/tests/default_stubs/proto/test_default.proto new file mode 100644 index 000000000..b5294802b --- /dev/null +++ b/tests/default_stubs/proto/test_default.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +package test_default; + +import "google/protobuf/empty.proto"; + +service TestDefault { + rpc Unary(google.protobuf.Empty) returns (google.protobuf.Empty); + rpc ServerStream(google.protobuf.Empty) returns (stream google.protobuf.Empty); + rpc ClientStream(stream google.protobuf.Empty) returns (google.protobuf.Empty); + rpc BidirectionalStream(stream google.protobuf.Empty) returns (stream google.protobuf.Empty); +} \ No newline at end of file diff --git a/tests/default_stubs/src/lib.rs b/tests/default_stubs/src/lib.rs new file mode 100644 index 000000000..4f0b287c1 --- /dev/null +++ b/tests/default_stubs/src/lib.rs @@ -0,0 +1,47 @@ +#![allow(unused_imports)] + +mod test_defaults; + +use futures::{Stream, StreamExt}; +use std::pin::Pin; +use tonic::{Request, Response, Status, Streaming}; + +tonic::include_proto!("test"); +tonic::include_proto!("test_default"); + +#[derive(Debug, Default)] +struct Svc; + +#[tonic::async_trait] +impl test_server::Test for Svc { + type ServerStreamStream = Pin> + Send + 'static>>; + type BidirectionalStreamStream = + Pin> + Send + 'static>>; + + async fn unary(&self, _: Request<()>) -> Result, Status> { + Err(Status::permission_denied("")) + } + + async fn server_stream( + &self, + _: Request<()>, + ) -> Result, Status> { + Err(Status::permission_denied("")) + } + + async fn client_stream(&self, _: Request>) -> Result, Status> { + Err(Status::permission_denied("")) + } + + async fn bidirectional_stream( + &self, + _: Request>, + ) -> Result, Status> { + Err(Status::permission_denied("")) + } +} + +#[tonic::async_trait] +impl test_default_server::TestDefault for Svc { + // Default unimplemented stubs provided here. +} diff --git a/tests/default_stubs/src/test_defaults.rs b/tests/default_stubs/src/test_defaults.rs new file mode 100644 index 000000000..32bed1be1 --- /dev/null +++ b/tests/default_stubs/src/test_defaults.rs @@ -0,0 +1,112 @@ +#![allow(unused_imports)] + +use crate::*; +use std::net::SocketAddr; +use tokio::net::TcpListener; +use tonic::transport::Server; + +#[cfg(test)] +fn echo_requests_iter() -> impl Stream { + tokio_stream::iter(1..usize::MAX).map(|_| ()) +} + +#[tokio::test()] +async fn test_default_stubs() { + use tonic::Code; + + let addrs = run_services_in_background().await; + + // First validate pre-existing functionality (trait has no default implementation, we explicitly return PermissionDenied in lib.rs). + let mut client = test_client::TestClient::connect(format!("http://{}", addrs.0)) + .await + .unwrap(); + assert_eq!( + client.unary(()).await.unwrap_err().code(), + Code::PermissionDenied + ); + assert_eq!( + client.server_stream(()).await.unwrap_err().code(), + Code::PermissionDenied + ); + assert_eq!( + client + .client_stream(echo_requests_iter().take(5)) + .await + .unwrap_err() + .code(), + Code::PermissionDenied + ); + assert_eq!( + client + .bidirectional_stream(echo_requests_iter().take(5)) + .await + .unwrap_err() + .code(), + Code::PermissionDenied + ); + + // Then validate opt-in new functionality (trait has default implementation of returning Unimplemented). + let mut client_default_stubs = test_client::TestClient::connect(format!("http://{}", addrs.1)) + .await + .unwrap(); + assert_eq!( + client_default_stubs.unary(()).await.unwrap_err().code(), + Code::Unimplemented + ); + assert_eq!( + client_default_stubs + .server_stream(()) + .await + .unwrap_err() + .code(), + Code::Unimplemented + ); + assert_eq!( + client_default_stubs + .client_stream(echo_requests_iter().take(5)) + .await + .unwrap_err() + .code(), + Code::Unimplemented + ); + assert_eq!( + client_default_stubs + .bidirectional_stream(echo_requests_iter().take(5)) + .await + .unwrap_err() + .code(), + Code::Unimplemented + ); +} + +#[cfg(test)] +async fn run_services_in_background() -> (SocketAddr, SocketAddr) { + let svc = test_server::TestServer::new(Svc {}); + let svc_default_stubs = test_default_server::TestDefaultServer::new(Svc {}); + + let listener = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr = listener.local_addr().unwrap(); + + let listener_default_stubs = TcpListener::bind("127.0.0.1:0").await.unwrap(); + let addr_default_stubs = listener_default_stubs.local_addr().unwrap(); + + tokio::spawn(async move { + Server::builder() + .add_service(svc) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new(listener)) + .await + .unwrap(); + }); + + tokio::spawn(async move { + Server::builder() + .add_service(svc_default_stubs) + .serve_with_incoming(tokio_stream::wrappers::TcpListenerStream::new( + listener_default_stubs, + )) + .await + .unwrap(); + }); + + (addr, addr_default_stubs) +} diff --git a/tonic-build/src/code_gen.rs b/tonic-build/src/code_gen.rs index feb402779..0306b753a 100644 --- a/tonic-build/src/code_gen.rs +++ b/tonic-build/src/code_gen.rs @@ -13,6 +13,7 @@ pub struct CodeGenBuilder { build_transport: bool, disable_comments: HashSet, use_arc_self: bool, + generate_default_stubs: bool, } impl CodeGenBuilder { @@ -64,6 +65,12 @@ impl CodeGenBuilder { self } + /// Enable or disable returning automatic unimplemented gRPC error code for generated traits. + pub fn generate_default_stubs(&mut self, generate_default_stubs: bool) -> &mut Self { + self.generate_default_stubs = generate_default_stubs; + self + } + /// Generate client code based on `Service`. /// /// This takes some `Service` and will generate a `TokenStream` that contains @@ -93,6 +100,7 @@ impl CodeGenBuilder { &self.attributes, &self.disable_comments, self.use_arc_self, + self.generate_default_stubs, ) } } @@ -106,6 +114,7 @@ impl Default for CodeGenBuilder { build_transport: true, disable_comments: HashSet::default(), use_arc_self: false, + generate_default_stubs: false, } } } diff --git a/tonic-build/src/prost.rs b/tonic-build/src/prost.rs index 6bbefaad2..1249055d9 100644 --- a/tonic-build/src/prost.rs +++ b/tonic-build/src/prost.rs @@ -40,6 +40,7 @@ pub fn configure() -> Builder { emit_rerun_if_changed: std::env::var_os("CARGO").is_some(), disable_comments: HashSet::default(), use_arc_self: false, + generate_default_stubs: false, } } @@ -174,6 +175,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator { .attributes(self.builder.server_attributes.clone()) .disable_comments(self.builder.disable_comments.clone()) .use_arc_self(self.builder.use_arc_self) + .generate_default_stubs(self.builder.generate_default_stubs) .generate_server(&service, &self.builder.proto_path); self.servers.extend(server); @@ -249,6 +251,7 @@ pub struct Builder { pub(crate) emit_rerun_if_changed: bool, pub(crate) disable_comments: HashSet, pub(crate) use_arc_self: bool, + pub(crate) generate_default_stubs: bool, out_dir: Option, } @@ -510,6 +513,17 @@ impl Builder { self } + /// Enable or disable directing service generation to providing a default implementation for service methods. + /// When this is false all gRPC methods must be explicitly implemented. + /// When this is true any unimplemented service methods will return 'unimplemented' gRPC error code. + /// When this is true all streaming server request RPC types explicitly use tonic::codegen::BoxStream type. + /// + /// This defaults to `false`. + pub fn generate_default_stubs(mut self, enable: bool) -> Self { + self.generate_default_stubs = enable; + self + } + /// Compile the .proto files and execute code generation. pub fn compile( self, diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index 1bdb624b8..9e42789ef 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -17,6 +17,7 @@ pub(crate) fn generate_internal( attributes: &Attributes, disable_comments: &HashSet, use_arc_self: bool, + generate_default_stubs: bool, ) -> TokenStream { let methods = generate_methods( service, @@ -24,6 +25,7 @@ pub(crate) fn generate_internal( proto_path, compile_well_known_types, use_arc_self, + generate_default_stubs, ); let server_service = quote::format_ident!("{}Server", service.name()); @@ -37,6 +39,7 @@ pub(crate) fn generate_internal( server_trait.clone(), disable_comments, use_arc_self, + generate_default_stubs, ); let package = if emit_package { service.package() } else { "" }; // Transport based implementations @@ -214,6 +217,7 @@ fn generate_trait( server_trait: Ident, disable_comments: &HashSet, use_arc_self: bool, + generate_default_stubs: bool, ) -> TokenStream { let methods = generate_trait_methods( service, @@ -222,6 +226,7 @@ fn generate_trait( compile_well_known_types, disable_comments, use_arc_self, + generate_default_stubs, ); let trait_doc = generate_doc_comment(format!( " Generated trait containing gRPC methods that should be implemented for use with {}Server.", @@ -244,6 +249,7 @@ fn generate_trait_methods( compile_well_known_types: bool, disable_comments: &HashSet, use_arc_self: bool, + generate_default_stubs: bool, ) -> TokenStream { let mut stream = TokenStream::new(); @@ -266,22 +272,53 @@ fn generate_trait_methods( quote!(&self) }; - let method = match (method.client_streaming(), method.server_streaming()) { - (false, false) => { + let method = match ( + method.client_streaming(), + method.server_streaming(), + generate_default_stubs, + ) { + (false, false, true) => { + quote! { + #method_doc + async fn #name(#self_param, request: tonic::Request<#req_message>) + -> std::result::Result, tonic::Status> { + Err(tonic::Status::unimplemented("Not yet implemented")) + } + } + } + (false, false, false) => { quote! { #method_doc async fn #name(#self_param, request: tonic::Request<#req_message>) -> std::result::Result, tonic::Status>; } } - (true, false) => { + (true, false, true) => { + quote! { + #method_doc + async fn #name(#self_param, request: tonic::Request>) + -> std::result::Result, tonic::Status> { + Err(tonic::Status::unimplemented("Not yet implemented")) + } + } + } + (true, false, false) => { quote! { #method_doc async fn #name(#self_param, request: tonic::Request>) -> std::result::Result, tonic::Status>; } } - (false, true) => { + (false, true, true) => { + quote! { + #method_doc + async fn #name(#self_param, request: tonic::Request<#req_message>) + -> std::result::Result>, tonic::Status> { + Err(tonic::Status::unimplemented("Not yet implemented")) + } + } + } + (false, true, false) => { let stream = quote::format_ident!("{}Stream", method.identifier()); let stream_doc = generate_doc_comment(format!( " Server streaming response type for the {} method.", @@ -297,7 +334,16 @@ fn generate_trait_methods( -> std::result::Result, tonic::Status>; } } - (true, true) => { + (true, true, true) => { + quote! { + #method_doc + async fn #name(#self_param, request: tonic::Request>) + -> std::result::Result>, tonic::Status> { + Err(tonic::Status::unimplemented("Not yet implemented")) + } + } + } + (true, true, false) => { let stream = quote::format_ident!("{}Stream", method.identifier()); let stream_doc = generate_doc_comment(format!( " Server streaming response type for the {} method.", @@ -341,6 +387,7 @@ fn generate_methods( proto_path: &str, compile_well_known_types: bool, use_arc_self: bool, + generate_default_stubs: bool, ) -> TokenStream { let mut stream = TokenStream::new(); @@ -367,6 +414,7 @@ fn generate_methods( ident.clone(), server_trait, use_arc_self, + generate_default_stubs, ), (true, false) => generate_client_streaming( method, @@ -384,6 +432,7 @@ fn generate_methods( ident.clone(), server_trait, use_arc_self, + generate_default_stubs, ), }; @@ -464,6 +513,7 @@ fn generate_server_streaming( method_ident: Ident, server_trait: Ident, use_arc_self: bool, + generate_default_stubs: bool, ) -> TokenStream { let codec_name = syn::parse_str::(method.codec_path()).unwrap(); @@ -471,7 +521,12 @@ fn generate_server_streaming( let (request, response) = method.request_response_name(proto_path, compile_well_known_types); - let response_stream = quote::format_ident!("{}Stream", method.identifier()); + let response_stream = if !generate_default_stubs { + let stream = quote::format_ident!("{}Stream", method.identifier()); + quote!(type ResponseStream = T::#stream) + } else { + quote!(type ResponseStream = BoxStream<#response>) + }; let inner_arg = if use_arc_self { quote!(inner) @@ -485,7 +540,7 @@ fn generate_server_streaming( impl tonic::server::ServerStreamingService<#request> for #service_ident { type Response = #response; - type ResponseStream = T::#response_stream; + #response_stream; type Future = BoxFuture, tonic::Status>; fn call(&mut self, request: tonic::Request<#request>) -> Self::Future { @@ -585,6 +640,7 @@ fn generate_streaming( method_ident: Ident, server_trait: Ident, use_arc_self: bool, + generate_default_stubs: bool, ) -> TokenStream { let codec_name = syn::parse_str::(method.codec_path()).unwrap(); @@ -592,7 +648,12 @@ fn generate_streaming( let (request, response) = method.request_response_name(proto_path, compile_well_known_types); - let response_stream = quote::format_ident!("{}Stream", method.identifier()); + let response_stream = if !generate_default_stubs { + let stream = quote::format_ident!("{}Stream", method.identifier()); + quote!(type ResponseStream = T::#stream) + } else { + quote!(type ResponseStream = BoxStream<#response>) + }; let inner_arg = if use_arc_self { quote!(inner) @@ -607,7 +668,7 @@ fn generate_streaming( impl tonic::server::StreamingService<#request> for #service_ident { type Response = #response; - type ResponseStream = T::#response_stream; + #response_stream; type Future = BoxFuture, tonic::Status>; fn call(&mut self, request: tonic::Request>) -> Self::Future {