Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
125 changes: 17 additions & 108 deletions juniper_hyper/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -8,8 +8,8 @@ use hyper::{
Body, Method, Request, Response, StatusCode,
};
use juniper::{
http::GraphQLRequest as JuniperGraphQLRequest, serde::Deserialize, DefaultScalarValue,
GraphQLType, GraphQLTypeAsync, InputValue, RootNode, ScalarValue,
http::{GraphQLBatchRequest, GraphQLRequest as JuniperGraphQLRequest},
GraphQLSubscriptionType, GraphQLType, GraphQLTypeAsync, InputValue, RootNode, ScalarValue,
};
use serde_json::error::Error as SerdeError;
use std::{error::Error, fmt, string::FromUtf8Error, sync::Arc};
Expand Down Expand Up @@ -61,7 +61,7 @@ where
CtxT: Send + Sync + 'static,
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync + 'static,
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync + 'static,
SubscriptionT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync + 'static,
SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + Sync,
QueryT::TypeInfo: Send + Sync,
MutationT::TypeInfo: Send + Sync,
SubscriptionT::TypeInfo: Send + Sync,
Expand Down Expand Up @@ -89,10 +89,10 @@ where

fn parse_get_req<S: ScalarValue>(
req: Request<Body>,
) -> Result<GraphQLRequest<S>, GraphQLRequestError> {
) -> Result<GraphQLBatchRequest<S>, GraphQLRequestError> {
req.uri()
.query()
.map(|q| gql_request_from_get(q).map(GraphQLRequest::Single))
.map(|q| gql_request_from_get(q).map(GraphQLBatchRequest::Single))
.unwrap_or_else(|| {
Err(GraphQLRequestError::Invalid(
"'query' parameter is missing".to_string(),
Expand All @@ -102,15 +102,16 @@ fn parse_get_req<S: ScalarValue>(

async fn parse_post_req<S: ScalarValue>(
body: Body,
) -> Result<GraphQLRequest<S>, GraphQLRequestError> {
) -> Result<GraphQLBatchRequest<S>, GraphQLRequestError> {
let chunk = hyper::body::to_bytes(body)
.await
.map_err(GraphQLRequestError::BodyHyper)?;

let input = String::from_utf8(chunk.iter().cloned().collect())
.map_err(GraphQLRequestError::BodyUtf8)?;

serde_json::from_str::<GraphQLRequest<S>>(&input).map_err(GraphQLRequestError::BodyJSONError)
serde_json::from_str::<GraphQLBatchRequest<S>>(&input)
.map_err(GraphQLRequestError::BodyJSONError)
}

pub async fn graphiql(graphql_endpoint: &str) -> Result<Response<Body>, hyper::Error> {
Expand Down Expand Up @@ -142,7 +143,7 @@ fn render_error(err: GraphQLRequestError) -> Response<Body> {
async fn execute_request<CtxT, QueryT, MutationT, SubscriptionT, S>(
root_node: Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>,
context: Arc<CtxT>,
request: GraphQLRequest<S>,
request: GraphQLBatchRequest<S>,
) -> Response<Body>
where
S: ScalarValue + Send + Sync + 'static,
Expand All @@ -154,8 +155,9 @@ where
MutationT::TypeInfo: Send + Sync,
SubscriptionT::TypeInfo: Send + Sync,
{
let (is_ok, body) = request.execute_sync(root_node, context);
let code = if is_ok {
let res = request.execute_sync(&*root_node, &context);
let body = Body::from(serde_json::to_string_pretty(&res).unwrap());
let code = if res.is_ok() {
StatusCode::OK
} else {
StatusCode::BAD_REQUEST
Expand All @@ -172,20 +174,21 @@ where
async fn execute_request_async<CtxT, QueryT, MutationT, SubscriptionT, S>(
root_node: Arc<RootNode<'static, QueryT, MutationT, SubscriptionT, S>>,
context: Arc<CtxT>,
request: GraphQLRequest<S>,
request: GraphQLBatchRequest<S>,
) -> Response<Body>
where
S: ScalarValue + Send + Sync + 'static,
CtxT: Send + Sync + 'static,
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync + 'static,
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync + 'static,
SubscriptionT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync + 'static,
SubscriptionT: GraphQLSubscriptionType<S, Context = CtxT> + Send + Sync,
QueryT::TypeInfo: Send + Sync,
MutationT::TypeInfo: Send + Sync,
SubscriptionT::TypeInfo: Send + Sync,
{
let (is_ok, body) = request.execute(root_node, context).await;
let code = if is_ok {
let res = request.execute(&*root_node, &context).await;
let body = Body::from(serde_json::to_string_pretty(&res).unwrap());
let code = if res.is_ok() {
StatusCode::OK
} else {
StatusCode::BAD_REQUEST
Expand Down Expand Up @@ -263,100 +266,6 @@ fn new_html_response(code: StatusCode) -> Response<Body> {
resp
}

#[derive(serde_derive::Deserialize)]
#[serde(untagged)]
#[serde(bound = "InputValue<S>: Deserialize<'de>")]
enum GraphQLRequest<S = DefaultScalarValue>
where
S: ScalarValue,
{
Single(JuniperGraphQLRequest<S>),
Batch(Vec<JuniperGraphQLRequest<S>>),
}

impl<S> GraphQLRequest<S>
where
S: ScalarValue,
{
fn execute_sync<'a, CtxT: 'a, QueryT, MutationT, SubscriptionT>(
self,
root_node: Arc<RootNode<'a, QueryT, MutationT, SubscriptionT, S>>,
context: Arc<CtxT>,
) -> (bool, hyper::Body)
where
S: 'a + Send + Sync,
QueryT: GraphQLType<S, Context = CtxT> + 'a,
MutationT: GraphQLType<S, Context = CtxT> + 'a,
SubscriptionT: GraphQLType<S, Context = CtxT> + 'a,
{
match self {
GraphQLRequest::Single(request) => {
let res = request.execute_sync(&root_node, &context);
let is_ok = res.is_ok();
let body = Body::from(serde_json::to_string_pretty(&res).unwrap());
(is_ok, body)
}
GraphQLRequest::Batch(requests) => {
let results: Vec<_> = requests
.into_iter()
.map(move |request| {
let root_node = root_node.clone();
let res = request.execute_sync(&root_node, &context);
let is_ok = res.is_ok();
let body = serde_json::to_string_pretty(&res).unwrap();
(is_ok, body)
})
.collect();

let is_ok = !results.iter().any(|&(is_ok, _)| !is_ok);
let bodies: Vec<_> = results.into_iter().map(|(_, body)| body).collect();
let body = hyper::Body::from(format!("[{}]", bodies.join(",")));
(is_ok, body)
}
}
}

async fn execute<'a, CtxT: 'a, QueryT, MutationT, SubscriptionT>(
self,
root_node: Arc<RootNode<'a, QueryT, MutationT, SubscriptionT, S>>,
context: Arc<CtxT>,
) -> (bool, hyper::Body)
where
S: Send + Sync,
QueryT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync,
MutationT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync,
SubscriptionT: GraphQLTypeAsync<S, Context = CtxT> + Send + Sync,
QueryT::TypeInfo: Send + Sync,
MutationT::TypeInfo: Send + Sync,
SubscriptionT::TypeInfo: Send + Sync,
CtxT: Send + Sync,
{
match self {
GraphQLRequest::Single(request) => {
let res = request.execute(&*root_node, &context).await;
let is_ok = res.is_ok();
let body = Body::from(serde_json::to_string_pretty(&res).unwrap());
(is_ok, body)
}
GraphQLRequest::Batch(requests) => {
let futures = requests
.iter()
.map(|request| request.execute(&*root_node, &context))
.collect::<Vec<_>>();
let results = futures::future::join_all(futures).await;

let is_ok = results.iter().all(|res| res.is_ok());
let bodies: Vec<_> = results
.into_iter()
.map(|res| serde_json::to_string_pretty(&res).unwrap())
.collect();
let body = hyper::Body::from(format!("[{}]", bodies.join(",")));
(is_ok, body)
}
}
}
}

#[derive(Debug)]
enum GraphQLRequestError {
BodyHyper(hyper::Error),
Expand Down