Skip to content
Permalink
Browse files

Add support for batched resolver methods (fixes #12)

  • Loading branch information...
apottere committed Nov 23, 2017
1 parent 516da8c commit 73d42e746ffdb55575b9e4d839ffc41fed70d99a
@@ -7,9 +7,8 @@ import graphql.schema.DataFetchingEnvironment
import org.apache.commons.lang3.ClassUtils
import org.apache.commons.lang3.reflect.FieldUtils
import org.slf4j.LoggerFactory
import java.lang.reflect.Field
import java.lang.reflect.Method
import java.lang.reflect.Modifier
import java.lang.reflect.ParameterizedType

/**
* @author Andrew Potter
@@ -29,15 +28,15 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
val scanProperties = field.inputValueDefinitions.isEmpty()
val found = searches.mapNotNull { search -> findFieldResolver(field, search, scanProperties) }

if(resolverInfo is RootResolverInfo && found.size > 1) {
if (resolverInfo is RootResolverInfo && found.size > 1) {
throw FieldResolverError("Found more than one matching resolver for field '$field': $found")
}

return found.firstOrNull() ?: missingFieldResolver(field, searches, scanProperties)
return found.firstOrNull() ?: missingFieldResolver(field, searches, scanProperties)
}

private fun missingFieldResolver(field: FieldDefinition, searches: List<Search>, scanProperties: Boolean): FieldResolver {
return if(options.allowUnimplementedResolvers) {
return if (options.allowUnimplementedResolvers) {
log.warn("Missing resolver for field: $field")

MissingFieldResolver(field, options)
@@ -48,13 +47,13 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {

private fun findFieldResolver(field: FieldDefinition, search: Search, scanProperties: Boolean): FieldResolver? {
val method = findResolverMethod(field, search)
if(method != null) {
if (method != null) {
return MethodFieldResolver(field, search, options, method.apply { isAccessible = true })
}

if(scanProperties) {
if (scanProperties) {
val property = findResolverProperty(field, search)
if(property != null) {
if (property != null) {
return PropertyFieldResolver(field, search, options, property.apply { isAccessible = true })
}
}
@@ -67,7 +66,7 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
private fun findResolverMethod(field: FieldDefinition, search: Search): java.lang.reflect.Method? {

val methods = getAllMethods(search.type)
val argumentCount = field.inputValueDefinitions.size + if(search.requiredFirstParameterType != null) 1 else 0
val argumentCount = field.inputValueDefinitions.size + if (search.requiredFirstParameterType != null) 1 else 0
val name = field.name

val isBoolean = isBoolean(field.type)
@@ -86,11 +85,38 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
}

private fun verifyMethodArguments(method: java.lang.reflect.Method, requiredCount: Int, search: Search): Boolean {
val appropriateFirstParameter = if (search.requiredFirstParameterType != null) {
if(MethodFieldResolver.isBatched(method, search)) {
verifyBatchedMethodFirstArgument(method.genericParameterTypes.firstOrNull(), search.requiredFirstParameterType)
} else {
method.parameterTypes.firstOrNull() == search.requiredFirstParameterType
}
} else {
true
}

val correctParameterCount = method.parameterCount == requiredCount || (method.parameterCount == (requiredCount + 1) && method.parameterTypes.last() == DataFetchingEnvironment::class.java)
val appropriateFirstParameter = if(search.requiredFirstParameterType != null) method.parameterTypes.firstOrNull() == search.requiredFirstParameterType else true
return correctParameterCount && appropriateFirstParameter
}

private fun verifyBatchedMethodFirstArgument(firstType: JavaType?, requiredFirstParameterType: Class<*>?): Boolean {
if(firstType == null) {
return false
}

if(firstType !is ParameterizedType) {
return false
}

if(!TypeClassMatcher.isListType(firstType, GenericType(firstType, options))) {
return false
}

val typeArgument = firstType.actualTypeArguments.first() as? Class<*> ?: return false

return typeArgument == requiredFirstParameterType
}

private fun findResolverProperty(field: FieldDefinition, search: Search) =
FieldUtils.getAllFields(search.type).find { it.name == field.name }

@@ -102,16 +128,16 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
signatures.addAll(getMissingMethodSignatures(field, search, isBoolean, scannedProperties))
}

return "No method${if(scannedProperties) " or field" else ""} found with any of the following signatures (with or without ${DataFetchingEnvironment::class.java.name} as the last argument), in priority order:\n${signatures.joinToString("\n ")}"
return "No method${if (scannedProperties) " or field" else ""} found with any of the following signatures (with or without ${DataFetchingEnvironment::class.java.name} as the last argument), in priority order:\n${signatures.joinToString("\n ")}"
}

private fun getMissingMethodSignatures(field: FieldDefinition, search: Search, isBoolean: Boolean, scannedProperties: Boolean): List<String> {
private fun getMissingMethodSignatures(field: FieldDefinition, search: Search, isBoolean: Boolean, scannedProperties: Boolean): List<String> {
val baseType = search.type
val signatures = mutableListOf<String>()
val args = mutableListOf<String>()
val sep = ", "

if(search.requiredFirstParameterType != null) {
if (search.requiredFirstParameterType != null) {
args.add(search.requiredFirstParameterType.name)
}

@@ -120,18 +146,18 @@ internal class FieldResolverScanner(val options: SchemaParserOptions) {
val argString = args.joinToString(sep)

signatures.add("${baseType.name}.${field.name}($argString)")
if(isBoolean) {
if (isBoolean) {
signatures.add("${baseType.name}.is${field.name.capitalize()}($argString)")
}
signatures.add("${baseType.name}.get${field.name.capitalize()}($argString)")
if(scannedProperties) {
if (scannedProperties) {
signatures.add("${baseType.name}.${field.name}")
}

return signatures
}

data class Search(val type: Class<*>, val resolverInfo: ResolverInfo, val source: Any?, val requiredFirstParameterType: Class<*>? = null)
data class Search(val type: Class<*>, val resolverInfo: ResolverInfo, val source: Any?, val requiredFirstParameterType: Class<*>? = null, val allowBatched: Boolean = false)
}

class FieldResolverError(msg: String): RuntimeException(msg)
@@ -19,20 +19,31 @@ import java.util.Optional
internal class MethodFieldResolver(field: FieldDefinition, search: FieldResolverScanner.Search, options: SchemaParserOptions, val method: Method): FieldResolver(field, search, options, method.declaringClass) {

companion object {
internal fun isBatched(method: Method) = method.getAnnotation(Batched::class.java) != null
fun isBatched(method: Method, search: FieldResolverScanner.Search): Boolean {
if(method.getAnnotation(Batched::class.java) != null) {
if(!search.allowBatched) {
throw ResolverError("The @Batched annotation is only allowed on non-root resolver methods, but it was found on ${search.type.name}#${method.name}!")
}

return true
}

return false
}
}

private val dataFetchingEnvironment = method.parameterCount == (field.inputValueDefinitions.size + getIndexOffset() + 1)

override fun createDataFetcher(): DataFetcher<*> {
val batched = isBatched(method, search)
val args = mutableListOf<ArgumentPlaceholder>()
val mapper = ObjectMapper().apply {
options.objectMapperConfigurer.configure(this, ObjectMapperConfigurerContext(field))
}.registerModule(Jdk8Module()).registerKotlinModule()

// Add source argument if this is a resolver (but not a root resolver)
if(this.search.requiredFirstParameterType != null) {
val expectedType = this.search.requiredFirstParameterType
val expectedType = if(batched) Iterable::class.java else this.search.requiredFirstParameterType

args.add({ environment ->
val source = environment.getSource<Any>()
@@ -76,15 +87,15 @@ internal class MethodFieldResolver(field: FieldDefinition, search: FieldResolver
args.add({ environment -> environment })
}

if(isBatched(method)) {
return BatchedMethodFieldResolverDataFetcher(getSourceResolver(), this.method, args)
return if(batched) {
BatchedMethodFieldResolverDataFetcher(getSourceResolver(), this.method, args)
} else {
return MethodFieldResolverDataFetcher(getSourceResolver(), this.method, args)
MethodFieldResolverDataFetcher(getSourceResolver(), this.method, args)
}
}

override fun scanForMatches(): List<TypeClassMatcher.PotentialMatch> {
val batched = MethodFieldResolver.isBatched(method)
val batched = isBatched(method, search)
val returnValueMatch = TypeClassMatcher.PotentialMatch.returnValue(field.type, method.genericReturnType, genericType, SchemaClassScanner.ReturnValueReference(method), batched)

return field.inputValueDefinitions.mapIndexed { i, inputDefinition ->
@@ -6,9 +6,8 @@ import java.lang.reflect.ParameterizedType
internal abstract class ResolverInfo {
abstract fun getFieldSearches(): List<FieldResolverScanner.Search>

protected fun getRealResolverClass(resolver: GraphQLResolver<*>, options: SchemaParserOptions): Class<*> {
return options.proxyHandlers.find { it.canHandle(resolver) }?.getTargetClass(resolver) ?: resolver.javaClass
}
protected fun getRealResolverClass(resolver: GraphQLResolver<*>, options: SchemaParserOptions) =
options.proxyHandlers.find { it.canHandle(resolver) }?.getTargetClass(resolver) ?: resolver.javaClass
}

internal class NormalResolverInfo(val resolver: GraphQLResolver<*>, private val options: SchemaParserOptions): ResolverInfo() {
@@ -37,22 +36,20 @@ internal class NormalResolverInfo(val resolver: GraphQLResolver<*>, private val

override fun getFieldSearches(): List<FieldResolverScanner.Search> {
return listOf(
FieldResolverScanner.Search(resolverType, this, resolver, dataClassType),
FieldResolverScanner.Search(resolverType, this, resolver, dataClassType, true),
FieldResolverScanner.Search(dataClassType, this, null)
)
}
}

internal class RootResolverInfo(val resolvers: List<GraphQLRootResolver>, private val options: SchemaParserOptions): ResolverInfo() {
override fun getFieldSearches(): List<FieldResolverScanner.Search> {
return resolvers.map { FieldResolverScanner.Search(getRealResolverClass(it, options), this, it) }
}
override fun getFieldSearches() =
resolvers.map { FieldResolverScanner.Search(getRealResolverClass(it, options), this, it) }
}

internal class DataClassResolverInfo(private val dataClass: Class<*>): ResolverInfo() {
override fun getFieldSearches(): List<FieldResolverScanner.Search> {
return listOf(FieldResolverScanner.Search(dataClass, this, null))
}
override fun getFieldSearches() =
listOf(FieldResolverScanner.Search(dataClass, this, null))
}

internal class MissingResolverInfo: ResolverInfo() {
@@ -15,6 +15,10 @@ import java.util.Optional
*/
internal class TypeClassMatcher(private val definitionsByName: Map<String, TypeDefinition>) {

companion object {
fun isListType(realType: ParameterizedType, generic: GenericType) = generic.isTypeAssignableFromRawClass(realType, Iterable::class.java)
}

private fun error(potentialMatch: PotentialMatch, msg: String) = SchemaClassScannerError("Unable to match type definition (${potentialMatch.graphQLType}) with java type (${potentialMatch.javaType}): $msg")

fun match(potentialMatch: PotentialMatch): Match {
@@ -76,7 +80,7 @@ internal class TypeClassMatcher(private val definitionsByName: Map<String, TypeD
}
}

private fun isListType(realType: ParameterizedType, potentialMatch: PotentialMatch) = potentialMatch.generic.isTypeAssignableFromRawClass(realType, List::class.java)
private fun isListType(realType: ParameterizedType, potentialMatch: PotentialMatch) = isListType(realType, potentialMatch.generic)

private fun requireRawClass(type: JavaType): Class<*> {
if(type !is Class<*>) {
@@ -106,12 +110,11 @@ internal class TypeClassMatcher(private val definitionsByName: Map<String, TypeD

internal data class PotentialMatch(val graphQLType: GraphQLLangType, val javaType: JavaType, val generic: GenericType.RelativeTo, val reference: SchemaClassScanner.Reference, val location: Location, val batched: Boolean) {
companion object {
fun returnValue(graphQLType: GraphQLLangType, javaType: JavaType, generic: GenericType.RelativeTo, reference: SchemaClassScanner.Reference, batched: Boolean): PotentialMatch {
return PotentialMatch(graphQLType, javaType, generic, reference, Location.RETURN_TYPE, batched)
}
fun parameterType(graphQLType: GraphQLLangType, javaType: JavaType, generic: GenericType.RelativeTo, reference: SchemaClassScanner.Reference, batched: Boolean): PotentialMatch {
return PotentialMatch(graphQLType, javaType, generic, reference, Location.PARAMETER_TYPE, batched)
}
fun returnValue(graphQLType: GraphQLLangType, javaType: JavaType, generic: GenericType.RelativeTo, reference: SchemaClassScanner.Reference, batched: Boolean) =
PotentialMatch(graphQLType, javaType, generic, reference, Location.RETURN_TYPE, batched)

fun parameterType(graphQLType: GraphQLLangType, javaType: JavaType, generic: GenericType.RelativeTo, reference: SchemaClassScanner.Reference, batched: Boolean) =
PotentialMatch(graphQLType, javaType, generic, reference, Location.PARAMETER_TYPE, batched)
}
}
class RawClassRequiredForGraphQLMappingException(msg: String): RuntimeException(msg)
@@ -365,15 +365,14 @@ class EndToEndSpec extends Specification {
def data = Utils.assertNoGraphQlErrors(gql) {
'''
{
batched1: batchedEcho(msg: "hello")
batched2: batchedEcho(msg: ", ")
batched3: batchedEcho(msg: "world")
batched4: batchedEcho(msg: "!")
allBaseItems {
name: batchedName
}
}
'''
}

then:
data.batched1 == "hello"
data.allBaseItems.collect { it.name } == ['item1', 'item2']
}
}
@@ -1,5 +1,6 @@
package com.coxautodev.graphql.tools

import graphql.execution.batched.Batched
import graphql.language.InputObjectTypeDefinition
import graphql.language.InterfaceTypeDefinition
import graphql.language.ObjectTypeDefinition
@@ -370,4 +371,48 @@ class SchemaClassScannerSpec extends Specification {
String id
}
}

def "scanner throws if @Batched is used on root resolver"() {
when:
SchemaParser.newParser()
.schemaString('''
type Query {
test: String
}
''')
.resolvers(new GraphQLQueryResolver() {
@Batched List<String> test() { null }
})
.scan()

then:
def e = thrown(ResolverError)
e.message.contains('The @Batched annotation is only allowed on non-root resolver methods')
}

def "scanner throws if @Batched is used on data class"() {
when:
SchemaParser.newParser()
.schemaString('''
type Query {
test: DataClass
}
type DataClass {
test: String
}
''')
.resolvers(new GraphQLQueryResolver() {
DataClass test() { null }

class DataClass {
@Batched List<String> test() { null }
}
})
.scan()

then:
def e = thrown(ResolverError)
e.message.contains('The @Batched annotation is only allowed on non-root resolver methods')
}
}

0 comments on commit 73d42e7

Please sign in to comment.
You can’t perform that action at this time.