From aaaa442b6e02f4771fb8250a5c3d82ce6b5f4286 Mon Sep 17 00:00:00 2001 From: Randolf J Date: Sat, 25 May 2024 11:33:54 -0700 Subject: [PATCH] feat: support `async_trait`-less generation --- Cargo.toml | 1 + codegen/Cargo.toml | 2 +- tests/emit_async_trait/Cargo.toml | 26 +++ tests/emit_async_trait/build.rs | 5 + tests/emit_async_trait/proto/test.proto | 12 ++ tests/emit_async_trait/src/lib.rs | 39 +++++ tonic-build/Cargo.toml | 3 +- tonic-build/src/server.rs | 201 +++++++++++++++--------- 8 files changed, 212 insertions(+), 77 deletions(-) create mode 100644 tests/emit_async_trait/Cargo.toml create mode 100644 tests/emit_async_trait/build.rs create mode 100644 tests/emit_async_trait/proto/test.proto create mode 100644 tests/emit_async_trait/src/lib.rs diff --git a/Cargo.toml b/Cargo.toml index d0db8e6a9..89bf36ad7 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -10,6 +10,7 @@ members = [ "codegen", "interop", # Tests "tests/disable_comments", + "tests/emit_async_trait", "tests/included_service", "tests/same_name", "tests/service_named_service", diff --git a/codegen/Cargo.toml b/codegen/Cargo.toml index f4d9f3290..95d48d263 100644 --- a/codegen/Cargo.toml +++ b/codegen/Cargo.toml @@ -8,4 +8,4 @@ version = "0.1.0" [dependencies] tempfile = "3.8.0" -tonic-build = {path = "../tonic-build", default-features = false, features = ["prost", "cleanup-markdown"]} +tonic-build = {path = "../tonic-build", default-features = false, features = ["prost", "async_trait", "cleanup-markdown"]} diff --git a/tests/emit_async_trait/Cargo.toml b/tests/emit_async_trait/Cargo.toml new file mode 100644 index 000000000..b1780c01c --- /dev/null +++ b/tests/emit_async_trait/Cargo.toml @@ -0,0 +1,26 @@ +[package] +edition = "2021" +license = "MIT" +name = "emit_async_trait" +publish = false +version = "0.1.0" + +# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html +[features] +# default = ["async_trait"] +async_trait = ["tonic-build/async_trait"] + +[dependencies] +tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] } +tokio-stream = { version = "0.1", features = ["net"] } +prost = "0.12" +tonic = { path = "../../tonic" } + +[build-dependencies] +tonic-build = { path = "../../tonic-build", default-features = false, features = [ + "transport", + "prost", +] } + +[package.metadata.cargo-machete] +ignored = ["prost"] diff --git a/tests/emit_async_trait/build.rs b/tests/emit_async_trait/build.rs new file mode 100644 index 000000000..bb62509ad --- /dev/null +++ b/tests/emit_async_trait/build.rs @@ -0,0 +1,5 @@ +fn main() { + tonic_build::configure() + .compile(&["proto/test.proto"], &["proto"]) + .unwrap(); +} diff --git a/tests/emit_async_trait/proto/test.proto b/tests/emit_async_trait/proto/test.proto new file mode 100644 index 000000000..3fd6787a7 --- /dev/null +++ b/tests/emit_async_trait/proto/test.proto @@ -0,0 +1,12 @@ +syntax = "proto3"; + +package test; + +import "google/protobuf/empty.proto"; + +service Test { + rpc Unary(google.protobuf.Empty) returns (google.protobuf.Empty); + rpc ServerStream(google.protobuf.Empty) returns (stream google.protobuf.Empty); + rpc ClientStream(stream google.protobuf.Empty) returns (google.protobuf.Empty); + rpc BidirectionalStream(stream google.protobuf.Empty) returns (stream google.protobuf.Empty); +} \ No newline at end of file diff --git a/tests/emit_async_trait/src/lib.rs b/tests/emit_async_trait/src/lib.rs new file mode 100644 index 000000000..e807b4fe0 --- /dev/null +++ b/tests/emit_async_trait/src/lib.rs @@ -0,0 +1,39 @@ +#![allow(unused_imports)] + +use std::pin::Pin; +use tokio_stream::{Stream, StreamExt}; +use tonic::{Request, Response, Status, Streaming}; + +tonic::include_proto!("test"); + +#[derive(Debug, Default)] +struct Svc; + +#[cfg_attr(feature = "async_trait", tonic::async_trait)] +impl test_server::Test for Svc { + type ServerStreamStream = Pin> + Send + 'static>>; + type BidirectionalStreamStream = + Pin> + Send + 'static>>; + + async fn unary(&self, _: Request<()>) -> Result, Status> { + Err(Status::permission_denied("")) + } + + async fn server_stream( + &self, + _: Request<()>, + ) -> Result, Status> { + Err(Status::permission_denied("")) + } + + async fn client_stream(&self, _: Request>) -> Result, Status> { + Err(Status::permission_denied("")) + } + + async fn bidirectional_stream( + &self, + _: Request>, + ) -> Result, Status> { + Err(Status::permission_denied("")) + } +} diff --git a/tonic-build/Cargo.toml b/tonic-build/Cargo.toml index b24db759e..2a9b58b09 100644 --- a/tonic-build/Cargo.toml +++ b/tonic-build/Cargo.toml @@ -22,10 +22,11 @@ quote = "1.0" syn = "2.0" [features] -default = ["transport", "prost"] +default = ["transport", "prost", "async_trait"] prost = ["prost-build"] cleanup-markdown = ["prost", "prost-build/cleanup-markdown"] transport = [] +async_trait = [] [package.metadata.docs.rs] all-features = true diff --git a/tonic-build/src/server.rs b/tonic-build/src/server.rs index d9ab1ad6b..13fde9b12 100644 --- a/tonic-build/src/server.rs +++ b/tonic-build/src/server.rs @@ -235,9 +235,15 @@ fn generate_trait( service.name() )); + let async_trait = if cfg!(feature = "async_trait") { + quote!(#[async_trait]) + } else { + TokenStream::new() + }; + quote! { #trait_doc - #[async_trait] + #async_trait pub trait #server_trait : Send + Sync + 'static { #methods } @@ -274,91 +280,136 @@ fn generate_trait_methods( quote!(&self) }; - let method = match ( - method.client_streaming(), - method.server_streaming(), - generate_default_stubs, - ) { - (false, false, true) => { - quote! { - #method_doc - async fn #name(#self_param, request: tonic::Request<#req_message>) - -> std::result::Result, tonic::Status> { - Err(tonic::Status::unimplemented("Not yet implemented")) + let r#async = if cfg!(feature = "async_trait") { + quote!(async) + } else { + TokenStream::new() + }; + + let ret_type = |ret| { + if cfg!(feature = "async_trait") { + ret + } else { + let lifetime = if use_arc_self { + quote!('static) + } else { + quote!('_) + }; + quote!(impl std::future::Future + Send + #lifetime) + } + }; + + let method = { + match ( + method.client_streaming(), + method.server_streaming(), + generate_default_stubs, + ) { + (false, false, true) => { + let return_type = ret_type( + quote!(std::result::Result, tonic::Status>), + ); + quote! { + #method_doc + #r#async fn #name(#self_param, request: tonic::Request<#req_message>) + -> #return_type { + Err(tonic::Status::unimplemented("Not yet implemented")) + } } } - } - (false, false, false) => { - quote! { - #method_doc - async fn #name(#self_param, request: tonic::Request<#req_message>) - -> std::result::Result, tonic::Status>; + (false, false, false) => { + let return_type = ret_type( + quote!(std::result::Result, tonic::Status>), + ); + quote! { + #method_doc + #r#async fn #name(#self_param, request: tonic::Request<#req_message>) + -> #return_type; + } } - } - (true, false, true) => { - quote! { - #method_doc - async fn #name(#self_param, request: tonic::Request>) - -> std::result::Result, tonic::Status> { - Err(tonic::Status::unimplemented("Not yet implemented")) + (true, false, true) => { + let return_type = ret_type( + quote!(std::result::Result, tonic::Status>), + ); + quote! { + #method_doc + #r#async fn #name(#self_param, request: tonic::Request>) + -> #return_type { + Err(tonic::Status::unimplemented("Not yet implemented")) + } } } - } - (true, false, false) => { - quote! { - #method_doc - async fn #name(#self_param, request: tonic::Request>) - -> std::result::Result, tonic::Status>; + (true, false, false) => { + let return_type = ret_type( + quote!(std::result::Result, tonic::Status>), + ); + quote! { + #method_doc + #r#async fn #name(#self_param, request: tonic::Request>) + -> #return_type; + } } - } - (false, true, true) => { - quote! { - #method_doc - async fn #name(#self_param, request: tonic::Request<#req_message>) - -> std::result::Result>, tonic::Status> { - Err(tonic::Status::unimplemented("Not yet implemented")) + (false, true, true) => { + let return_type = ret_type( + quote!(std::result::Result>, tonic::Status>), + ); + quote! { + #method_doc + #r#async fn #name(#self_param, request: tonic::Request<#req_message>) + -> #return_type { + Err(tonic::Status::unimplemented("Not yet implemented")) + } } } - } - (false, true, false) => { - let stream = quote::format_ident!("{}Stream", method.identifier()); - let stream_doc = generate_doc_comment(format!( - " Server streaming response type for the {} method.", - method.identifier() - )); - - quote! { - #stream_doc - type #stream: tonic::codegen::tokio_stream::Stream> + Send + 'static; - - #method_doc - async fn #name(#self_param, request: tonic::Request<#req_message>) - -> std::result::Result, tonic::Status>; + (false, true, false) => { + let stream = quote::format_ident!("{}Stream", method.identifier()); + let stream_doc = generate_doc_comment(format!( + " Server streaming response type for the {} method.", + method.identifier() + )); + let return_type = ret_type( + quote!(std::result::Result, tonic::Status>), + ); + + quote! { + #stream_doc + type #stream: tonic::codegen::tokio_stream::Stream> + Send + 'static; + + #method_doc + #r#async fn #name(#self_param, request: tonic::Request<#req_message>) + -> #return_type; + } } - } - (true, true, true) => { - quote! { - #method_doc - async fn #name(#self_param, request: tonic::Request>) - -> std::result::Result>, tonic::Status> { - Err(tonic::Status::unimplemented("Not yet implemented")) + (true, true, true) => { + let return_type = ret_type( + quote!(std::result::Result>, tonic::Status>), + ); + quote! { + #method_doc + #r#async fn #name(#self_param, request: tonic::Request>) + -> #return_type { + Err(tonic::Status::unimplemented("Not yet implemented")) + } } } - } - (true, true, false) => { - let stream = quote::format_ident!("{}Stream", method.identifier()); - let stream_doc = generate_doc_comment(format!( - " Server streaming response type for the {} method.", - method.identifier() - )); - - quote! { - #stream_doc - type #stream: tonic::codegen::tokio_stream::Stream> + Send + 'static; - - #method_doc - async fn #name(#self_param, request: tonic::Request>) - -> std::result::Result, tonic::Status>; + (true, true, false) => { + let stream = quote::format_ident!("{}Stream", method.identifier()); + let stream_doc = generate_doc_comment(format!( + " Server streaming response type for the {} method.", + method.identifier() + )); + let return_type = ret_type( + quote!(std::result::Result, tonic::Status>), + ); + + quote! { + #stream_doc + type #stream: tonic::codegen::tokio_stream::Stream> + Send + 'static; + + #method_doc + #r#async fn #name(#self_param, request: tonic::Request>) + -> #return_type; + } } } };