Skip to content

Commit

Permalink
Switch to Extensions as parameter
Browse files Browse the repository at this point in the history
  • Loading branch information
nickpresta committed Mar 22, 2024
1 parent e781370 commit cb2f260
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 109 deletions.
13 changes: 8 additions & 5 deletions crates/twirp-build/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ pub struct ServiceGenerator;

impl prost_build::ServiceGenerator for ServiceGenerator {
fn finalize_package(&mut self, _package: &str, buf: &mut String) {
buf.insert_str(0, "use twirp::server::{Request, Response};\n");
buf.insert_str(0, "use twirp::Extensions;\n");
}

fn generate(&mut self, service: prost_build::Service, buf: &mut String) {
Expand All @@ -33,7 +33,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
for m in &service.methods {
writeln!(
buf,
" async fn {}(&self, req: Request<{}>) -> Result<Response<{}>, twirp::TwirpErrorResponse>;",
" async fn {}(self: std::sync::Arc<Self>, extensions: &mut Extensions, req: {}) -> Result<{}, twirp::TwirpErrorResponse>;",
m.name, m.input_type, m.output_type,
)
.unwrap();
Expand All @@ -47,7 +47,7 @@ impl prost_build::ServiceGenerator for ServiceGenerator {
where
T: {service_name} + Send + Sync + 'static,
{{
twirp::details::TwirpRouterBuilder::new(api)"#,
twirp::details::TwirpRouterBuilder::new(api.clone())"#,
)
.unwrap();
for m in &service.methods {
Expand All @@ -56,8 +56,11 @@ where
let rust_method_name = &m.name;
writeln!(
buf,
r#" .route("/{uri}", |api: std::sync::Arc<T>, req: Request<{req_type}>| async move {{
api.{rust_method_name}(req).await
r#" .route("/{uri}", move |api: std::sync::Arc<T>, extensions: &mut Extensions, req: {req_type}| {{
async move {{
let api = api.clone();
api.{rust_method_name}(extensions, req).await
}}
}})"#,
)
.unwrap();
Expand Down
8 changes: 4 additions & 4 deletions crates/twirp/src/details.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ use std::future::Future;
use axum::extract::{Request, State};
use axum::Router;

use crate::{server, TwirpErrorResponse};
use crate::{server, Extensions, TwirpErrorResponse};

/// Builder object used by generated code to build a Twirp service.
///
Expand All @@ -30,11 +30,11 @@ where
/// Add a handler for an `rpc` to the router.
///
/// The generated code passes a closure that calls the method, like
/// `|api: Arc<HaberdasherApiServer>, req: MakeHatRequest| async move { api.make_hat(req) }`.
/// `|api: Arc<HaberdasherApiServer>, extensions: &mut Extensions, req: MakeHatRequest| async move { api.make_hat(extensions, req) }`.
pub fn route<F, Fut, Req, Res>(self, url: &str, f: F) -> Self
where
F: Fn(S, server::Request<Req>) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<server::Response<Res>, TwirpErrorResponse>> + Send,
F: Fn(S, &mut Extensions, Req) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<Res, TwirpErrorResponse>> + Send,
Req: prost::Message + Default + serde::de::DeserializeOwned,
Res: prost::Message + serde::Serialize,
{
Expand Down
2 changes: 1 addition & 1 deletion crates/twirp/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,7 @@ pub mod details;

pub use client::{Client, ClientBuilder, ClientError, Middleware, Next, Result};
pub use error::*; // many constructors like `invalid_argument()`
pub use server::{Request, Response};
pub use server::Extensions;

// Re-export this crate's dependencies that users are likely to code against. These can be used to
// import the exact versions of these libraries `twirp` is built with -- useful if your project is
Expand Down
103 changes: 21 additions & 82 deletions crates/twirp/src/server.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,8 @@ use std::fmt::Debug;
use axum::body::Body;
use axum::response::IntoResponse;
use futures::Future;
use http::{Extensions, HeaderMap};
pub use http::Extensions;
use http::HeaderMap;
use http_body_util::BodyExt;
use serde::de::DeserializeOwned;
use serde::Serialize;
Expand All @@ -17,68 +18,6 @@ use tokio::time::{Duration, Instant};
use crate::headers::{CONTENT_TYPE_JSON, CONTENT_TYPE_PROTOBUF};
use crate::{error, serialize_proto_message, GenericError, TwirpErrorResponse};

#[derive(Debug, Clone)]
pub struct Request<T> {
message: T,
extensions: Extensions,
}

#[derive(Debug, Clone)]
pub struct Response<T> {
message: T,
extensions: Extensions,
}

impl<T> Request<T> {
pub fn new(message: T) -> Self {
Request {
message,
extensions: Extensions::new(),
}
}

pub fn into_inner(self) -> T {
self.message
}

pub fn into_parts(self) -> (T, Extensions) {
(self.message, self.extensions)
}

pub fn extensions(&self) -> &Extensions {
&self.extensions
}

pub fn extensions_mut(&mut self) -> &mut Extensions {
&mut self.extensions
}
}

impl<T> Response<T> {
pub fn new(message: T) -> Self {
Response {
message,
extensions: Extensions::new(),
}
}

pub fn into_inner(self) -> T {
self.message
}

pub fn into_parts(self) -> (T, Extensions) {
(self.message, self.extensions)
}

pub fn extensions(&self) -> &Extensions {
&self.extensions
}

pub fn extensions_mut(&mut self) -> &mut Extensions {
&mut self.extensions
}
}

// TODO: Properly implement JsonPb (de)serialization as it is slightly different
// than standard JSON.
#[derive(Debug, Clone, Copy, Default)]
Expand Down Expand Up @@ -107,18 +46,20 @@ pub(crate) async fn handle_request<S, F, Fut, Req, Resp>(
f: F,
) -> hyper::Response<Body>
where
F: FnOnce(S, Request<Req>) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<Response<Resp>, TwirpErrorResponse>> + Send,
F: FnOnce(S, &mut Extensions, Req) -> Fut + Clone + Sync + Send + 'static,
Fut: Future<Output = Result<Resp, TwirpErrorResponse>> + Send,
Req: prost::Message + Default + serde::de::DeserializeOwned,
Resp: prost::Message + serde::Serialize,
{
let mut timings = req
.extensions()
let (mut parts, body) = req.into_parts();
let extensions = &mut parts.extensions;
let mut timings = extensions
.get::<Timings>()
.copied()
.unwrap_or_else(|| Timings::new(Instant::now()));

let (req, resp_fmt) = match parse_request(req, &mut timings).await {
let headers = &parts.headers;
let (message, resp_fmt) = match parse_request(body, headers, &mut timings).await {
Ok(pair) => pair,
Err(err) => {
// This is the only place we use tracing (would be nice to remove)
Expand All @@ -131,10 +72,10 @@ where
}
};

let res = f(service, req).await;
let res = f(service, extensions, message).await;
timings.set_response_handled();

let mut resp = match write_response(res, resp_fmt) {
let mut resp = match write_response(res, extensions, resp_fmt) {
Ok(resp) => resp,
Err(err) => {
let mut twirp_err = error::unknown("error serializing response");
Expand All @@ -144,19 +85,19 @@ where
};
timings.set_response_written();

// resp.extensions_mut().extend(extensions);
resp.extensions_mut().insert(timings);
resp
}

async fn parse_request<T>(
req: hyper::Request<Body>,
body: Body,
headers: &HeaderMap,
timings: &mut Timings,
) -> Result<(Request<T>, BodyFormat), GenericError>
) -> Result<(T, BodyFormat), GenericError>
where
T: prost::Message + Default + DeserializeOwned,
{
let (parts, body) = req.into_parts();
let headers = &parts.headers;
let format = BodyFormat::from_content_type(headers);
let bytes = body.collect().await?.to_bytes();
timings.set_received();
Expand All @@ -166,32 +107,30 @@ where
};
timings.set_parsed();

let mut request = Request::new(message);
request.extensions_mut().extend(parts.extensions);
Ok((request, format))
Ok((message, format))
}

fn write_response<T>(
response: Result<Response<T>, TwirpErrorResponse>,
response: Result<T, TwirpErrorResponse>,
response_extensions: &Extensions,
response_format: BodyFormat,
) -> Result<hyper::Response<Body>, GenericError>
where
T: prost::Message + Serialize,
{
let res = match response {
Ok(response) => {
let (message, response_extensions) = response.into_parts();
let mut builder = hyper::Response::builder();
if let Some(extensions_mut) = builder.extensions_mut() {
extensions_mut.extend(response_extensions);
extensions_mut.extend(response_extensions.clone());
}

match response_format {
BodyFormat::Pb => builder
.header(hyper::header::CONTENT_TYPE, CONTENT_TYPE_PROTOBUF)
.body(Body::from(serialize_proto_message(message)))?,
.body(Body::from(serialize_proto_message(response)))?,
BodyFormat::JsonPb => {
let data = serde_json::to_string(&message)?;
let data = serde_json::to_string(&response)?;
builder
.header(hyper::header::CONTENT_TYPE, CONTENT_TYPE_JSON)
.body(Body::from(data))?
Expand Down
32 changes: 15 additions & 17 deletions example/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@ use twirp::axum::{
response::Response,
routing::get,
};
use twirp::{invalid_argument, server, Router, TwirpErrorResponse};
use twirp::{invalid_argument, Extensions, Router, TwirpErrorResponse};

pub mod service {
pub mod haberdash {
Expand Down Expand Up @@ -49,10 +49,10 @@ struct HaberdasherApiServer;
#[async_trait]
impl haberdash::HaberdasherApi for HaberdasherApiServer {
async fn make_hat(
&self,
request: server::Request<MakeHatRequest>,
) -> Result<server::Response<MakeHatResponse>, TwirpErrorResponse> {
let (req, extensions) = request.into_parts();
self: std::sync::Arc<Self>,
extensions: &mut Extensions,
req: MakeHatRequest,
) -> Result<MakeHatResponse, TwirpErrorResponse> {
let value = extensions.get::<MyMiddlewareValue>();
println!("got request extension value: {:?}", value);

Expand All @@ -65,16 +65,16 @@ impl haberdash::HaberdasherApi for HaberdasherApiServer {
.duration_since(UNIX_EPOCH)
.unwrap_or_default();

let mut response = server::Response::new(MakeHatResponse {
let response = MakeHatResponse {
color: "black".to_string(),
name: "top hat".to_string(),
size: req.inches,
timestamp: Some(prost_wkt_types::Timestamp {
seconds: ts.as_secs() as i64,
nanos: 0,
}),
});
response.extensions_mut().insert(MyMiddlewareValue {
};
extensions.insert(MyMiddlewareValue {
value: value.map_or_else(|| 0, |f| f.value) + 1,
});
Ok(response)
Expand Down Expand Up @@ -119,21 +119,19 @@ mod test {

#[tokio::test]
async fn success() {
let api = HaberdasherApiServer {};
let res = api
.make_hat(server::Request::new(MakeHatRequest { inches: 1 }))
.await;
let api = std::sync::Arc::new(HaberdasherApiServer {});
let extensions = &mut Extensions::new();
let res = api.make_hat(extensions, MakeHatRequest { inches: 1 }).await;
assert!(res.is_ok());
let res = res.unwrap().into_inner();
let res = res.unwrap();
assert_eq!(res.size, 1);
}

#[tokio::test]
async fn invalid_request() {
let api = HaberdasherApiServer {};
let res = api
.make_hat(server::Request::new(MakeHatRequest { inches: 0 }))
.await;
let api = std::sync::Arc::new(HaberdasherApiServer {});
let extensions = &mut Extensions::new();
let res = api.make_hat(extensions, MakeHatRequest { inches: 0 }).await;
assert!(res.is_err());
let err = res.unwrap_err();
assert_eq!(err.code, TwirpErrorCode::InvalidArgument);
Expand Down

0 comments on commit cb2f260

Please sign in to comment.