Skip to content

Commit

Permalink
feat: support async_trait-less generation
Browse files Browse the repository at this point in the history
  • Loading branch information
jun-sheaf committed May 25, 2024
1 parent a0159d3 commit a34b259
Show file tree
Hide file tree
Showing 8 changed files with 212 additions and 77 deletions.
1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ members = [
"codegen",
"interop", # Tests
"tests/disable_comments",
"tests/emit_async_trait",
"tests/included_service",
"tests/same_name",
"tests/service_named_service",
Expand Down
2 changes: 1 addition & 1 deletion codegen/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -8,4 +8,4 @@ version = "0.1.0"

[dependencies]
tempfile = "3.8.0"
tonic-build = {path = "../tonic-build", default-features = false, features = ["prost", "cleanup-markdown"]}
tonic-build = {path = "../tonic-build", default-features = false, features = ["prost", "async_trait", "cleanup-markdown"]}
26 changes: 26 additions & 0 deletions tests/emit_async_trait/Cargo.toml
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
[package]
edition = "2021"
license = "MIT"
name = "emit_async_trait"
publish = false
version = "0.1.0"

# See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html
[features]
# default = ["async_trait"]
async_trait = ["tonic-build/async_trait"]

[dependencies]
tokio = { version = "1.0", features = ["macros", "rt-multi-thread", "net"] }
tokio-stream = { version = "0.1", features = ["net"] }
prost = "0.12"
tonic = { path = "../../tonic" }

[build-dependencies]
tonic-build = { path = "../../tonic-build", default-features = false, features = [
"transport",
"prost",
] }

[package.metadata.cargo-machete]
ignored = ["prost"]
5 changes: 5 additions & 0 deletions tests/emit_async_trait/build.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
fn main() {
tonic_build::configure()
.compile(&["proto/test.proto"], &["proto"])
.unwrap();
}
12 changes: 12 additions & 0 deletions tests/emit_async_trait/proto/test.proto
Original file line number Diff line number Diff line change
@@ -0,0 +1,12 @@
syntax = "proto3";

package test;

import "google/protobuf/empty.proto";

service Test {
rpc Unary(google.protobuf.Empty) returns (google.protobuf.Empty);
rpc ServerStream(google.protobuf.Empty) returns (stream google.protobuf.Empty);
rpc ClientStream(stream google.protobuf.Empty) returns (google.protobuf.Empty);
rpc BidirectionalStream(stream google.protobuf.Empty) returns (stream google.protobuf.Empty);
}
39 changes: 39 additions & 0 deletions tests/emit_async_trait/src/lib.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
#![allow(unused_imports)]

use std::pin::Pin;
use tokio_stream::{Stream, StreamExt};
use tonic::{Request, Response, Status, Streaming};

tonic::include_proto!("test");

#[derive(Debug, Default)]
struct Svc;

#[cfg_attr(feature = "async_trait", tonic::async_trait)]
impl test_server::Test for Svc {
type ServerStreamStream = Pin<Box<dyn Stream<Item = Result<(), Status>> + Send + 'static>>;
type BidirectionalStreamStream =
Pin<Box<dyn Stream<Item = Result<(), Status>> + Send + 'static>>;

async fn unary(&self, _: Request<()>) -> Result<Response<()>, Status> {
Err(Status::permission_denied(""))
}

async fn server_stream(
&self,
_: Request<()>,
) -> Result<Response<Self::ServerStreamStream>, Status> {
Err(Status::permission_denied(""))
}

async fn client_stream(&self, _: Request<Streaming<()>>) -> Result<Response<()>, Status> {
Err(Status::permission_denied(""))
}

async fn bidirectional_stream(
&self,
_: Request<Streaming<()>>,
) -> Result<Response<Self::BidirectionalStreamStream>, Status> {
Err(Status::permission_denied(""))
}
}
3 changes: 2 additions & 1 deletion tonic-build/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,11 @@ quote = "1.0"
syn = "2.0"

[features]
default = ["transport", "prost"]
default = ["transport", "prost", "async_trait"]
prost = ["prost-build"]
cleanup-markdown = ["prost", "prost-build/cleanup-markdown"]
transport = []
async_trait = []

[package.metadata.docs.rs]
all-features = true
201 changes: 126 additions & 75 deletions tonic-build/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,15 @@ fn generate_trait<T: Service>(
service.name()
));

let async_trait = if cfg!(feature = "async_trait") {
quote!(#[async_trait])
} else {
TokenStream::new()
};

quote! {
#trait_doc
#[async_trait]
#async_trait
pub trait #server_trait : Send + Sync + 'static {
#methods
}
Expand Down Expand Up @@ -274,91 +280,136 @@ fn generate_trait_methods<T: Service>(
quote!(&self)
};

let method = match (
method.client_streaming(),
method.server_streaming(),
generate_default_stubs,
) {
(false, false, true) => {
quote! {
#method_doc
async fn #name(#self_param, request: tonic::Request<#req_message>)
-> std::result::Result<tonic::Response<#res_message>, tonic::Status> {
Err(tonic::Status::unimplemented("Not yet implemented"))
let r#async = if cfg!(feature = "async_trait") {
quote!(async)
} else {
TokenStream::new()
};

let ret_type = |ret| {
if cfg!(feature = "async_trait") {
ret
} else {
let lifetime = if use_arc_self {
quote!('static)
} else {
quote!('_)
};
quote!(impl std::future::Future<Output = #ret> + Send + #lifetime)
}
};

let method = {
match (
method.client_streaming(),
method.server_streaming(),
generate_default_stubs,
) {
(false, false, true) => {
let return_type = ret_type(
quote!(std::result::Result<tonic::Response<#res_message>, tonic::Status>),
);
quote! {
#method_doc
#r#async fn #name(#self_param, request: tonic::Request<#req_message>)
-> #return_type {
Err(tonic::Status::unimplemented("Not yet implemented"))
}
}
}
}
(false, false, false) => {
quote! {
#method_doc
async fn #name(#self_param, request: tonic::Request<#req_message>)
-> std::result::Result<tonic::Response<#res_message>, tonic::Status>;
(false, false, false) => {
let return_type = ret_type(
quote!(std::result::Result<tonic::Response<#res_message>, tonic::Status>),
);
quote! {
#method_doc
#r#async fn #name(#self_param, request: tonic::Request<#req_message>)
-> #return_type;
}
}
}
(true, false, true) => {
quote! {
#method_doc
async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
-> std::result::Result<tonic::Response<#res_message>, tonic::Status> {
Err(tonic::Status::unimplemented("Not yet implemented"))
(true, false, true) => {
let return_type = ret_type(
quote!(std::result::Result<tonic::Response<#res_message>, tonic::Status>),
);
quote! {
#method_doc
#r#async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
-> #return_type {
Err(tonic::Status::unimplemented("Not yet implemented"))
}
}
}
}
(true, false, false) => {
quote! {
#method_doc
async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
-> std::result::Result<tonic::Response<#res_message>, tonic::Status>;
(true, false, false) => {
let return_type = ret_type(
quote!(std::result::Result<tonic::Response<#res_message>, tonic::Status>),
);
quote! {
#method_doc
#r#async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
-> #return_type;
}
}
}
(false, true, true) => {
quote! {
#method_doc
async fn #name(#self_param, request: tonic::Request<#req_message>)
-> std::result::Result<tonic::Response<BoxStream<#res_message>>, tonic::Status> {
Err(tonic::Status::unimplemented("Not yet implemented"))
(false, true, true) => {
let return_type = ret_type(
quote!(std::result::Result<tonic::Response<BoxStream<#res_message>>, tonic::Status>),
);
quote! {
#method_doc
#r#async fn #name(#self_param, request: tonic::Request<#req_message>)
-> #return_type {
Err(tonic::Status::unimplemented("Not yet implemented"))
}
}
}
}
(false, true, false) => {
let stream = quote::format_ident!("{}Stream", method.identifier());
let stream_doc = generate_doc_comment(format!(
" Server streaming response type for the {} method.",
method.identifier()
));

quote! {
#stream_doc
type #stream: tonic::codegen::tokio_stream::Stream<Item = std::result::Result<#res_message, tonic::Status>> + Send + 'static;

#method_doc
async fn #name(#self_param, request: tonic::Request<#req_message>)
-> std::result::Result<tonic::Response<Self::#stream>, tonic::Status>;
(false, true, false) => {
let stream = quote::format_ident!("{}Stream", method.identifier());
let stream_doc = generate_doc_comment(format!(
" Server streaming response type for the {} method.",
method.identifier()
));
let return_type = ret_type(
quote!(std::result::Result<tonic::Response<Self::#stream>, tonic::Status>),
);

quote! {
#stream_doc
type #stream: tonic::codegen::tokio_stream::Stream<Item = std::result::Result<#res_message, tonic::Status>> + Send + 'static;

#method_doc
#r#async fn #name(#self_param, request: tonic::Request<#req_message>)
-> #return_type;
}
}
}
(true, true, true) => {
quote! {
#method_doc
async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
-> std::result::Result<tonic::Response<BoxStream<#res_message>>, tonic::Status> {
Err(tonic::Status::unimplemented("Not yet implemented"))
(true, true, true) => {
let return_type = ret_type(
quote!(std::result::Result<tonic::Response<BoxStream<#res_message>>, tonic::Status>),
);
quote! {
#method_doc
#r#async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
-> #return_type {
Err(tonic::Status::unimplemented("Not yet implemented"))
}
}
}
}
(true, true, false) => {
let stream = quote::format_ident!("{}Stream", method.identifier());
let stream_doc = generate_doc_comment(format!(
" Server streaming response type for the {} method.",
method.identifier()
));

quote! {
#stream_doc
type #stream: tonic::codegen::tokio_stream::Stream<Item = std::result::Result<#res_message, tonic::Status>> + Send + 'static;

#method_doc
async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
-> std::result::Result<tonic::Response<Self::#stream>, tonic::Status>;
(true, true, false) => {
let stream = quote::format_ident!("{}Stream", method.identifier());
let stream_doc = generate_doc_comment(format!(
" Server streaming response type for the {} method.",
method.identifier()
));
let return_type = ret_type(
quote!(std::result::Result<tonic::Response<Self::#stream>, tonic::Status>),
);

quote! {
#stream_doc
type #stream: tonic::codegen::tokio_stream::Stream<Item = std::result::Result<#res_message, tonic::Status>> + Send + 'static;

#method_doc
#r#async fn #name(#self_param, request: tonic::Request<tonic::Streaming<#req_message>>)
-> #return_type;
}
}
}
};
Expand Down

0 comments on commit a34b259

Please sign in to comment.