diff --git a/docs/book/content/types/enums.md b/docs/book/content/types/enums.md index 14122b219..5e23341b3 100644 --- a/docs/book/content/types/enums.md +++ b/docs/book/content/types/enums.md @@ -54,3 +54,17 @@ enum StarWarsEpisode { # fn main() {} ``` + +## Supported Macro Attributes (Derive) + +| Name of Attribute | Container Support | Field Support | +|-------------------|:-----------------:|:----------------:| +| context | ✔ | ? | +| deprecated | ✔ | ✔ | +| description | ✔ | ✔ | +| interfaces | ? | ✘ | +| name | ✔ | ✔ | +| noasync | ✔ | ? | +| scalar | ✘ | ? | +| skip | ? | ✘ | +| ✔: supported | ✘: not supported | ?: not available | diff --git a/integration_tests/juniper_tests/src/codegen/derive_enum.rs b/integration_tests/juniper_tests/src/codegen/derive_enum.rs index 5dc1848ef..88a54a08b 100644 --- a/integration_tests/juniper_tests/src/codegen/derive_enum.rs +++ b/integration_tests/juniper_tests/src/codegen/derive_enum.rs @@ -4,6 +4,10 @@ use fnv::FnvHashMap; #[cfg(test)] use juniper::{self, DefaultScalarValue, FromInputValue, GraphQLType, InputValue, ToInputValue}; +pub struct CustomContext {} + +impl juniper::Context for CustomContext {} + #[derive(juniper::GraphQLEnum, Debug, PartialEq)] #[graphql(name = "Some", description = "enum descr")] enum SomeEnum { @@ -39,6 +43,12 @@ enum OverrideDocEnum { Foo, } +#[derive(juniper::GraphQLEnum)] +#[graphql(context = CustomContext, noasync)] +enum ContextEnum { + A, +} + #[test] fn test_derived_enum() { // Ensure that rename works. @@ -98,3 +108,16 @@ fn test_doc_comment_override() { let meta = OverrideDocEnum::meta(&(), &mut registry); assert_eq!(meta.description(), Some(&"enum override".to_string())); } + +fn test_context(_t: T) +where + T: GraphQLType, +{ + // empty +} + +#[test] +fn test_doc_custom_context() { + test_context(ContextEnum::A); + // test_context(OverrideDocEnum::Foo); does not work +} diff --git a/integration_tests/juniper_tests/src/codegen/derive_union.rs b/integration_tests/juniper_tests/src/codegen/derive_union.rs index 153857bf2..fcec2e74c 100644 --- a/integration_tests/juniper_tests/src/codegen/derive_union.rs +++ b/integration_tests/juniper_tests/src/codegen/derive_union.rs @@ -91,6 +91,13 @@ impl DroidCompat { } } +#[derive(juniper::GraphQLUnion)] +#[graphql(Context = CustomContext)] +pub enum DifferentContext { + A(DroidContext), + B(Droid), +} + // NOTICE: this can not compile due to generic implementation of GraphQLType<__S> // #[derive(juniper::GraphQLUnion)] // pub enum CharacterCompatFail { diff --git a/juniper/CHANGELOG.md b/juniper/CHANGELOG.md index c4da0799d..427c71a63 100644 --- a/juniper/CHANGELOG.md +++ b/juniper/CHANGELOG.md @@ -27,6 +27,8 @@ See [#569](https://github.com/graphql-rust/juniper/pull/569). See [#618](https://github.com/graphql-rust/juniper/pull/618). +- Derive macro `GraphQLEnum` supports custom context (see [#621](https://github.com/graphql-rust/juniper/pull/621)) + ## Breaking Changes - `juniper::graphiql` has moved to `juniper::http::graphiql` diff --git a/juniper_codegen/src/derive_enum.rs b/juniper_codegen/src/derive_enum.rs index e404e79e4..d98e60c03 100644 --- a/juniper_codegen/src/derive_enum.rs +++ b/juniper_codegen/src/derive_enum.rs @@ -1,288 +1,109 @@ use proc_macro2::TokenStream; +use crate::util; use quote::quote; -use syn::{self, Data, DeriveInput, Fields, Variant}; +use syn::{self, Data, Fields}; -use crate::util::*; - -#[derive(Default, Debug)] -struct EnumAttrs { - name: Option, - description: Option, -} - -impl EnumAttrs { - fn from_input(input: &DeriveInput) -> EnumAttrs { - let mut res = EnumAttrs { - name: None, - description: None, - }; - - // Check doc comments for description. - res.description = get_doc_comment(&input.attrs); - - // Check attributes for name and description. - if let Some(items) = get_graphql_attr(&input.attrs) { - for item in items { - if let Some(AttributeValue::String(val)) = - keyed_item_value(&item, "name", AttributeValidation::String) - { - if is_valid_name(&*val) { - res.name = Some(val); - continue; - } else { - panic!( - "Names must match /^[_a-zA-Z][_a-zA-Z0-9]*$/ but \"{}\" does not", - &*val - ); - } - } - if let Some(AttributeValue::String(val)) = - keyed_item_value(&item, "description", AttributeValidation::String) - { - res.description = Some(val); - continue; - } - panic!(format!( - "Unknown enum attribute for #[derive(GraphQLEnum)]: {:?}", - item - )); - } - } - res - } -} - -#[derive(Default)] -struct EnumVariantAttrs { - name: Option, - description: Option, - deprecation: Option, -} - -impl EnumVariantAttrs { - fn from_input(variant: &Variant) -> EnumVariantAttrs { - let mut res = EnumVariantAttrs::default(); - - // Check doc comments for description. - res.description = get_doc_comment(&variant.attrs); - - // Check builtin deprecated attribute for deprecation. - res.deprecation = get_deprecated(&variant.attrs); - - // Check attributes for name and description. - if let Some(items) = get_graphql_attr(&variant.attrs) { - for item in items { - if let Some(AttributeValue::String(val)) = - keyed_item_value(&item, "name", AttributeValidation::String) - { - if is_valid_name(&*val) { - res.name = Some(val); - continue; - } else { - panic!( - "Names must match /^[_a-zA-Z][_a-zA-Z0-9]*$/ but \"{}\" does not", - &*val - ); - } - } - if let Some(AttributeValue::String(val)) = - keyed_item_value(&item, "description", AttributeValidation::String) - { - res.description = Some(val); - continue; - } - if let Some(AttributeValue::String(val)) = - keyed_item_value(&item, "deprecation", AttributeValidation::String) - { - res.deprecation = Some(DeprecationAttr { reason: Some(val) }); - continue; - } - match keyed_item_value(&item, "deprecated", AttributeValidation::String) { - Some(AttributeValue::String(val)) => { - res.deprecation = Some(DeprecationAttr { reason: Some(val) }); - continue; - } - Some(AttributeValue::Bare) => { - res.deprecation = Some(DeprecationAttr { reason: None }); - continue; - } - None => {} - } - panic!(format!( - "Unknown variant attribute for #[derive(GraphQLEnum)]: {:?}", - item - )); - } - } - res +pub fn impl_enum(ast: syn::DeriveInput, is_internal: bool) -> TokenStream { + if !ast.generics.params.is_empty() { + panic!("#[derive(GraphQLEnum) does not support generics or lifetimes"); } -} - -pub fn impl_enum(ast: &syn::DeriveInput, is_internal: bool) -> TokenStream { - let juniper_path = if is_internal { - quote!(crate) - } else { - quote!(juniper) - }; let variants = match ast.data { - Data::Enum(ref enum_data) => enum_data.variants.iter().collect::>(), + Data::Enum(enum_data) => enum_data.variants, _ => { panic!("#[derive(GraphlQLEnum)] may only be applied to enums, not to structs"); } }; // Parse attributes. - let ident = &ast.ident; - let attrs = EnumAttrs::from_input(ast); - let name = attrs.name.unwrap_or_else(|| ast.ident.to_string()); - - let meta_description = match attrs.description { - Some(descr) => quote! { let meta = meta.description(#descr); }, - None => quote! { let meta = meta; }, + let attrs = match util::ObjectAttributes::from_attrs(&ast.attrs) { + Ok(a) => a, + Err(e) => { + panic!("Invalid #[graphql(...)] attribute: {}", e); + } }; - - let mut values = TokenStream::new(); - let mut resolves = TokenStream::new(); - let mut from_inputs = TokenStream::new(); - let mut to_inputs = TokenStream::new(); - - for variant in variants { - match variant.fields { - Fields::Unit => {} - _ => { - panic!(format!( - "Invalid enum variant {}.\nGraphQL enums may only contain unit variants.", - variant.ident - )); - } - }; - - let var_attrs = EnumVariantAttrs::from_input(variant); - let var_ident = &variant.ident; - - // Build value. - let name = var_attrs - .name - .unwrap_or_else(|| crate::util::to_upper_snake_case(&variant.ident.to_string())); - let descr = match var_attrs.description { - Some(s) => quote! { Some(#s.to_string()) }, - None => quote! { None }, - }; - let depr = match var_attrs.deprecation { - Some(DeprecationAttr { reason: Some(s) }) => quote! { - #juniper_path::meta::DeprecationStatus::Deprecated(Some(#s.to_string())) - }, - Some(DeprecationAttr { reason: None }) => quote! { - #juniper_path::meta::DeprecationStatus::Deprecated(None) - }, - None => quote! { - #juniper_path::meta::DeprecationStatus::Current - }, - }; - values.extend(quote! { - #juniper_path::meta::EnumValue{ - name: #name.to_string(), - description: #descr, - deprecation_status: #depr, - }, - }); - - // Build resolve match clause. - resolves.extend(quote! { - &#ident::#var_ident => #juniper_path::Value::scalar(String::from(#name)), - }); - - // Build from_input clause. - from_inputs.extend(quote! { - Some(#name) => Some(#ident::#var_ident), - }); - - // Build to_input clause. - to_inputs.extend(quote! { - &#ident::#var_ident => - #juniper_path::InputValue::scalar(#name.to_string()), - }); + if !attrs.interfaces.is_empty() { + panic!("Invalid #[graphql(...)] attribute 'interfaces': #[derive(GraphQLEnum) does not support 'interfaces'"); + } + if attrs.scalar.is_some() { + panic!("Invalid #[graphql(...)] attribute 'scalar': #[derive(GraphQLEnum) does not support explicit scalars"); } - let _async = quote!( - impl<__S> #juniper_path::GraphQLTypeAsync<__S> for #ident - where - __S: #juniper_path::ScalarValue + Send + Sync, - { - fn resolve_async<'a>( - &'a self, - info: &'a Self::TypeInfo, - selection_set: Option<&'a [#juniper_path::Selection<__S>]>, - executor: &'a #juniper_path::Executor, - ) -> #juniper_path::BoxFuture<'a, #juniper_path::ExecutionResult<__S>> { - use #juniper_path::GraphQLType; - use futures::future; - let v = self.resolve(info, selection_set, executor); - future::FutureExt::boxed(future::ready(v)) - } - } - ); - - let body = quote! { - impl<__S> #juniper_path::GraphQLType<__S> for #ident - where __S: - #juniper_path::ScalarValue, - { - type Context = (); - type TypeInfo = (); - - fn name(_: &()) -> Option<&'static str> { - Some(#name) - } + // Parse attributes. + let ident = &ast.ident; + let name = attrs.name.unwrap_or_else(|| ident.to_string()); + + let mut mapping = std::collections::HashMap::new(); + + let fields = variants + .into_iter() + .filter_map(|field| { + let field_attrs = match util::FieldAttributes::from_attrs( + field.attrs, + util::FieldAttributeParseMode::Object, + ) { + Ok(attrs) => attrs, + Err(e) => panic!("Invalid #[graphql] attribute for field: \n{}", e), + }; + + if field_attrs.skip { + panic!("#[derive(GraphQLEnum)] does not support #[graphql(skip)] on fields"); + } else { + let field_name = field.ident; + let name = field_attrs + .name + .clone() + .unwrap_or_else(|| util::to_upper_snake_case(&field_name.to_string())); + + match mapping.get(&name) { + Some(other_field_name) => + panic!(format!("#[derive(GraphQLEnum)] all variants needs to be unique. Another field name `{}` has the same identifier `{}`, thus `{}` can not be named `{}`. One of the fields is manually renamed!", other_field_name, name, field_name, name)), + None => { + mapping.insert(name.clone(), field_name.clone()); + } + } - fn meta<'r>(_: &(), registry: &mut #juniper_path::Registry<'r, __S>) - -> #juniper_path::meta::MetaType<'r, __S> - where __S: 'r, - { - let meta = registry.build_enum_type::<#ident>(&(), &[ - #values - ]); - #meta_description - meta.into_meta() - } + let resolver_code = quote!( #ident::#field_name ); - fn resolve( - &self, - _: &(), - _: Option<&[#juniper_path::Selection<__S>]>, - _: &#juniper_path::Executor - ) -> #juniper_path::ExecutionResult<__S> { - let v = match self { - #resolves + let _type = match field.fields { + Fields::Unit => syn::parse_str(&field_name.to_string()).unwrap(), + _ => panic!("#[derive(GraphQLEnum)] all fields of the enum must be unnamed"), }; - Ok(v) - } - } - impl<__S: #juniper_path::ScalarValue> #juniper_path::FromInputValue<__S> for #ident { - fn from_input_value(v: &#juniper_path::InputValue<__S>) -> Option<#ident> - { - match v.as_enum_value().or_else(|| { - v.as_string_value() - }) { - #from_inputs - _ => None, - } + Some(util::GraphQLTypeDefinitionField { + name, + _type, + args: Vec::new(), + description: field_attrs.description, + deprecation: field_attrs.deprecation, + resolver_code, + is_type_inferred: true, + is_async: false, + }) } - } + }) + .collect::>(); - impl<__S: #juniper_path::ScalarValue> #juniper_path::ToInputValue<__S> for #ident { - fn to_input_value(&self) -> #juniper_path::InputValue<__S> { - match self { - #to_inputs - } - } - } + if fields.len() == 0 { + panic!("#[derive(GraphQLEnum)] requires at least one variants"); + } - #_async + let definition = util::GraphQLTypeDefiniton { + name, + _type: syn::parse_str(&ast.ident.to_string()).unwrap(), + context: attrs.context, + scalar: None, + description: attrs.description, + fields, + // NOTICE: only unit variants allow -> no generics possible + generics: syn::Generics::default(), + interfaces: None, + include_type_generics: true, + generic_scalar: true, + no_async: attrs.no_async, }; - body + let juniper_crate_name = if is_internal { "crate" } else { "juniper" }; + definition.into_enum_tokens(juniper_crate_name) } diff --git a/juniper_codegen/src/lib.rs b/juniper_codegen/src/lib.rs index 923519a45..9a21c4578 100644 --- a/juniper_codegen/src/lib.rs +++ b/juniper_codegen/src/lib.rs @@ -24,7 +24,7 @@ use proc_macro::TokenStream; #[proc_macro_derive(GraphQLEnum, attributes(graphql))] pub fn derive_enum(input: TokenStream) -> TokenStream { let ast = syn::parse::(input).unwrap(); - let gen = derive_enum::impl_enum(&ast, false); + let gen = derive_enum::impl_enum(ast, false); gen.into() } @@ -32,7 +32,7 @@ pub fn derive_enum(input: TokenStream) -> TokenStream { #[doc(hidden)] pub fn derive_enum_internal(input: TokenStream) -> TokenStream { let ast = syn::parse::(input).unwrap(); - let gen = derive_enum::impl_enum(&ast, true); + let gen = derive_enum::impl_enum(ast, true); gen.into() } diff --git a/juniper_codegen/src/util/mod.rs b/juniper_codegen/src/util/mod.rs index b55a3ceae..eac3e6559 100644 --- a/juniper_codegen/src/util/mod.rs +++ b/juniper_codegen/src/util/mod.rs @@ -1329,7 +1329,16 @@ impl GraphQLTypeDefiniton { quote! { if type_name == (<#var_ty as #juniper_crate_name::GraphQLType<#scalar>>::name(&())).unwrap() { - return executor.resolve(&(), &{ #expr }); + return #juniper_crate_name::IntoResolvable::into( + { #expr }, + executor.context() + ) + .and_then(|res| { + match res { + Some((ctx, r)) => executor.replaced_context(ctx).resolve_with_ctx(&(), &r), + None => Ok(#juniper_crate_name::Value::null()), + } + }); } } }); @@ -1339,8 +1348,20 @@ impl GraphQLTypeDefiniton { quote! { if type_name == (<#var_ty as #juniper_crate_name::GraphQLType<#scalar>>::name(&())).unwrap() { + let inner_res = #juniper_crate_name::IntoResolvable::into( + { #expr }, + executor.context() + ); + let f = async move { - executor.resolve_async(&(), &{ #expr }).await + match inner_res { + Ok(Some((ctx, r))) => { + let subexec = executor.replaced_context(ctx); + subexec.resolve_with_ctx_async(&(), &r).await + }, + Ok(None) => Ok(#juniper_crate_name::Value::null()), + Err(e) => Err(e), + } }; use futures::future; return future::FutureExt::boxed(f); @@ -1460,6 +1481,201 @@ impl GraphQLTypeDefiniton { type_impl } + + pub fn into_enum_tokens(self, juniper_crate_name: &str) -> proc_macro2::TokenStream { + let juniper_crate_name = syn::parse_str::(juniper_crate_name).unwrap(); + + let name = &self.name; + let ty = &self._type; + let context = self + .context + .as_ref() + .map(|ctx| quote!( #ctx )) + .unwrap_or_else(|| quote!(())); + + let scalar = self + .scalar + .as_ref() + .map(|s| quote!( #s )) + .unwrap_or_else(|| { + if self.generic_scalar { + // If generic_scalar is true, we always insert a generic scalar. + // See more comments below. + quote!(__S) + } else { + quote!(#juniper_crate_name::DefaultScalarValue) + } + }); + + let description = self + .description + .as_ref() + .map(|description| quote!( .description(#description) )); + + let values = self.fields.iter().map(|variant| { + let variant_name = &variant.name; + + let descr = variant + .description + .as_ref() + .map(|description| quote!(Some(#description.to_string()))) + .unwrap_or_else(|| quote!(None)); + + let depr = variant + .deprecation + .as_ref() + .map(|deprecation| match deprecation.reason.as_ref() { + Some(reason) => quote!( #juniper_crate_name::meta::DeprecationStatus::Deprecated(Some(#reason.to_string())) ), + None => quote!( #juniper_crate_name::meta::DeprecationStatus::Deprecated(None) ), + }) + .unwrap_or_else(|| quote!(#juniper_crate_name::meta::DeprecationStatus::Current)); + + quote!( + #juniper_crate_name::meta::EnumValue { + name: #variant_name.to_string(), + description: #descr, + deprecation_status: #depr, + }, + ) + }); + + let resolves = self.fields.iter().map(|variant| { + let variant_name = &variant.name; + let resolver_code = &variant.resolver_code; + + quote!( + &#resolver_code => #juniper_crate_name::Value::scalar(String::from(#variant_name)), + ) + }); + + let from_inputs = self.fields.iter().map(|variant| { + let variant_name = &variant.name; + let resolver_code = &variant.resolver_code; + + quote!( + Some(#variant_name) => Some(#resolver_code), + ) + }); + + let to_inputs = self.fields.iter().map(|variant| { + let variant_name = &variant.name; + let resolver_code = &variant.resolver_code; + + quote!( + &#resolver_code => + #juniper_crate_name::InputValue::scalar(#variant_name.to_string()), + ) + }); + + let mut generics = self.generics.clone(); + + if self.scalar.is_none() && self.generic_scalar { + // No custom scalar specified, but always generic specified. + // Therefore we inject the generic scalar. + + generics.params.push(parse_quote!(__S)); + + let where_clause = generics.where_clause.get_or_insert(parse_quote!(where)); + // Insert ScalarValue constraint. + where_clause + .predicates + .push(parse_quote!(__S: #juniper_crate_name::ScalarValue)); + } + + let (impl_generics, _, where_clause) = generics.split_for_impl(); + + let mut where_async = where_clause.cloned().unwrap_or_else(|| parse_quote!(where)); + where_async + .predicates + .push(parse_quote!( #scalar: Send + Sync )); + where_async.predicates.push(parse_quote!(Self: Send + Sync)); + + let _async = quote!( + impl#impl_generics #juniper_crate_name::GraphQLTypeAsync<#scalar> for #ty + #where_async + { + fn resolve_async<'a>( + &'a self, + info: &'a Self::TypeInfo, + selection_set: Option<&'a [#juniper_crate_name::Selection<#scalar>]>, + executor: &'a #juniper_crate_name::Executor, + ) -> #juniper_crate_name::BoxFuture<'a, #juniper_crate_name::ExecutionResult<#scalar>> { + use #juniper_crate_name::GraphQLType; + use futures::future; + let v = self.resolve(info, selection_set, executor); + future::FutureExt::boxed(future::ready(v)) + } + } + ); + + let mut body = quote!( + impl#impl_generics #juniper_crate_name::GraphQLType<#scalar> for #ty + #where_clause + { + type Context = #context; + type TypeInfo = (); + + fn name(_: &()) -> Option<&'static str> { + Some(#name) + } + + fn meta<'r>( + _: &(), + registry: &mut #juniper_crate_name::Registry<'r, #scalar> + ) -> #juniper_crate_name::meta::MetaType<'r, #scalar> + where #scalar: 'r, + { + registry.build_enum_type::<#ty>(&(), &[ + #( #values )* + ]) + #description + .into_meta() + } + + fn resolve( + &self, + _: &(), + _: Option<&[#juniper_crate_name::Selection<#scalar>]>, + _: &#juniper_crate_name::Executor + ) -> #juniper_crate_name::ExecutionResult<#scalar> { + let v = match self { + #( #resolves )* + }; + Ok(v) + } + } + + impl#impl_generics #juniper_crate_name::FromInputValue<#scalar> for #ty + #where_clause + { + fn from_input_value(v: &#juniper_crate_name::InputValue<#scalar>) -> Option<#ty> + { + match v.as_enum_value().or_else(|| { + v.as_string_value() + }) { + #( #from_inputs )* + _ => None, + } + } + } + + impl#impl_generics #juniper_crate_name::ToInputValue<#scalar> for #ty + #where_clause + { + fn to_input_value(&self) -> #juniper_crate_name::InputValue<#scalar> { + match self { + #( #to_inputs )* + } + } + } + ); + + if !self.no_async { + body.extend(_async) + } + + body + } } #[cfg(test)]