Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fixes #26: Support Mutations for Relationships #32

Merged
merged 1 commit into from
Jun 1, 2019
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
22 changes: 21 additions & 1 deletion src/main/kotlin/org/neo4j/graphql/CrudOperations.kt
Original file line number Diff line number Diff line change
@@ -1,12 +1,13 @@
package org.neo4j.graphql

import graphql.language.FieldDefinition
import graphql.language.ListType
import graphql.language.ObjectTypeDefinition

data class Augmentation(val create: String = "", val merge: String = "", val update: String = "", val delete: String = "",
val inputType: String = "", val ordering: String = "", val filterType: String = "", val query: String = "")

fun augmentedSchema(ctx: Translator.Context, type: ObjectTypeDefinition): Augmentation {
fun createNodeMutation(ctx: Translator.Context, type: ObjectTypeDefinition): Augmentation {
val typeName = type.name
val idField = type.fieldDefinitions.find { it.type.name() == "ID" }
val scalarFields = type.fieldDefinitions.filter { it.type.isScalar() }.sortedByDescending { it == idField }
Expand Down Expand Up @@ -34,6 +35,25 @@ fun augmentedSchema(ctx: Translator.Context, type: ObjectTypeDefinition): Augmen
} else result
}

fun createRelationshipMutation(ctx: Translator.Context, source: ObjectTypeDefinition, target: ObjectTypeDefinition): Augmentation? {
val sourceTypeName = source.name
return if (!ctx.mutation.enabled || ctx.mutation.exclude.contains(sourceTypeName)) {
null
} else {
val targetField = source.getFieldByType(target.name) ?: return null
val sourceIdField = source.fieldDefinitions.find { it.isID() }
val targetIdField = target.fieldDefinitions.find { it.isID() }
if (sourceIdField == null || targetIdField == null) {
return null
}
val targetFieldName = targetField.name.capitalize()
val targetIDStr = if (targetField.isList()) "[ID!]!" else "ID!"
Augmentation(
create = "add$sourceTypeName$targetFieldName(${sourceIdField.name}:ID!, ${targetField.name}:$targetIDStr) : $sourceTypeName",
delete = "delete$sourceTypeName$targetFieldName(${sourceIdField.name}:ID!, ${targetField.name}:$targetIDStr) : $sourceTypeName")
}
}

private fun filterType(name: String?, fieldArgs: List<FieldDefinition>) : String {
val fName = """_${name}Filter"""
val fields = (listOf("AND","OR","NOT").map { "$it:[$fName!]" } +
Expand Down
5 changes: 4 additions & 1 deletion src/main/kotlin/org/neo4j/graphql/ExtensionFunctions.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
package org.neo4j.graphql

import graphql.Scalars
import graphql.language.FieldDefinition
import graphql.schema.GraphQLFieldDefinition
import java.io.PrintWriter
import java.io.StringWriter

Expand All @@ -11,4 +14,4 @@ fun Throwable.stackTraceAsString(): String {

fun <T> Iterable<T>.joinNonEmpty(separator: CharSequence = ", ", prefix: CharSequence = "", postfix: CharSequence = "", limit: Int = -1, truncated: CharSequence = "...", transform: ((T) -> CharSequence)? = null): String {
return if (iterator().hasNext()) joinTo(StringBuilder(), separator, prefix, postfix, limit, truncated, transform).toString() else ""
}
}
12 changes: 12 additions & 0 deletions src/main/kotlin/org/neo4j/graphql/GraphQLExtensions.kt
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
package org.neo4j.graphql

import graphql.Scalars
import graphql.language.*
import graphql.schema.*

Expand Down Expand Up @@ -85,3 +86,14 @@ fun paramName(variable: String, argName: String, value: Any?) = when (value) {
is VariableReference -> value.name
else -> "$variable${argName.capitalize()}"
}

fun FieldDefinition.isID(): Boolean = this.type.name() == "ID"
fun FieldDefinition.isList(): Boolean = this.type is ListType
fun GraphQLFieldDefinition.isID(): Boolean = this.type.inner() == Scalars.GraphQLID
fun ObjectTypeDefinition.getFieldByType(typeName: String): FieldDefinition? = this.fieldDefinitions
.filter { it.type.inner().name() == typeName }
.firstOrNull()

fun GraphQLDirective.getRelationshipType(): String = this.getArgument("name").value.toString()
fun GraphQLDirective.getRelationshipDirection(): String = this.getArgument("direction")?.value?.toString() ?: "OUT"

93 changes: 86 additions & 7 deletions src/main/kotlin/org/neo4j/graphql/Translator.kt
Original file line number Diff line number Diff line change
Expand Up @@ -18,11 +18,36 @@ class Translator(val schema: GraphQLSchema) {
val mutation: CRUDConfig = CRUDConfig())
data class CRUDConfig(val enabled:Boolean = true, val exclude: List<String> = emptyList())
data class Cypher( val query: String, val params : Map<String,Any?> = emptyMap()) {
fun with(p: Map<String,Any?>) = this.copy(params = this.params + p)
fun escapedQuery() = query.replace("\"","\\\"").replace("'","\\'")
companion object {
val EMPTY = Cypher("")

private fun findRelNodeId(objectType: GraphQLObjectType) = objectType.fieldDefinitions.find { it.isID() }!!

private fun createRelStatement(source: GraphQLType, target: GraphQLFieldDefinition,
keyword: String = "MERGE"): String {
val innerTarget = target.type.inner()
val relationshipDirective = target.getDirective("relation")
?: throw IllegalArgumentException("Missing @relation directive for relation ${target.name}")
val targetFilterType = if (target.type.isList()) "IN" else "="
val sourceId = findRelNodeId(source as GraphQLObjectType)
val targetId = findRelNodeId(innerTarget as GraphQLObjectType)
val (left, right) = if (relationshipDirective.getRelationshipDirection() == "OUT") ("" to ">") else ("<" to "")
return "MATCH (from:${source.name.quote()} {${sourceId.name.quote()}:$${sourceId.name}}) " +
"MATCH (to:${innerTarget.name.quote()}) WHERE to.${targetId.name.quote()} $targetFilterType $${target.name} " +
"$keyword (from)$left-[r:${relationshipDirective.getRelationshipType().quote()}]-$right(to) "
}

fun createRelationship(source: GraphQLType, target: GraphQLFieldDefinition): Cypher {
return Cypher(createRelStatement(source, target))
}

fun deleteRelationship(source: GraphQLType, target: GraphQLFieldDefinition): Cypher {
return Cypher(createRelStatement(source, target, "MATCH") +
"DELETE r ")
}
}
fun with(p: Map<String,Any?>) = this.copy(params = this.params + p)
fun escapedQuery() = query.replace("\"","\\\"").replace("'","\\'")
}

@JvmOverloads fun translate(query: String, params: Map<String, Any> = emptyMap(), context: Context = Context()) : List<Cypher> {
Expand All @@ -38,8 +63,8 @@ class Translator(val schema: GraphQLSchema) {

private fun toQuery(field: Field, ctx:Context = Context()): Cypher {
val name = field.name
val queryType = schema.queryType.fieldDefinitions.filter { it.name == name }.firstOrNull()
val mutationType = schema.mutationType.fieldDefinitions.filter { it.name == name }.firstOrNull()
val queryType = schema.queryType.getFieldDefinition(name)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

good change

val mutationType = schema.mutationType.getFieldDefinition(name)
val fieldDefinition = queryType ?: mutationType
?: throw IllegalArgumentException("Unknown Query $name available queries: " + (schema.queryType.fieldDefinitions + schema.mutationType.fieldDefinitions).map { it.name }.joinToString())
val isQuery = queryType != null
Expand All @@ -63,7 +88,7 @@ class Translator(val schema: GraphQLSchema) {
return Cypher("MATCH ($variable:$label${properties.query})${where.query} RETURN ${mapProjection.query} AS $variable$ordering$skipLimit",
(mapProjection.params + properties.params + where.params))
} else {
// todo extract method or better object
// TODO add into Cypher companion object as did for the relationships
val properties = properties(variable, fieldDefinition, propertyArguments(field))
val idProperty = fieldDefinition.arguments.find { it.type.inner() == Scalars.GraphQLID }
val returnStatement = "WITH $variable RETURN ${mapProjection.query} AS $variable$ordering$skipLimit";
Expand All @@ -86,12 +111,48 @@ class Translator(val schema: GraphQLSchema) {
"WITH $variable as toDelete, ${mapProjection.query} AS $variable $ordering$skipLimit DETACH DELETE toDelete RETURN $variable",
(mapProjection.params + mapOf(paramName to properties.params[paramName])))
}
else -> throw IllegalArgumentException("Unknown Mutation "+name)
else -> checkRelationships(fieldDefinition, field, ordering, skipLimit, ctx)
}
}
}
}

private fun checkRelationships(sourceFieldDefinition: GraphQLFieldDefinition, field: Field, ordering: String, skipLimit: String, ctx: Context): Cypher {
val source = sourceFieldDefinition.type as GraphQLObjectType
val targetFieldDefinition = filterTarget(source, field, sourceFieldDefinition)

val sourceVariable = "from"
val mapProjection = projectFields(sourceVariable, field, source, ctx, null)
val returnStatement = "WITH DISTINCT $sourceVariable RETURN ${mapProjection.query} AS ${source.name.decapitalize().quote()}$ordering$skipLimit"
val properties = properties("", sourceFieldDefinition, propertyArguments(field)).params
.mapKeys { it.key.decapitalize() }
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

do we have fields that are capitalized?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

in this case yes because the ´properties´ method gives the name to the fields starting from the first argument which is, in this case, an empty string


val targetFieldName = targetFieldDefinition.name
val addMutationName = "add${source.name}${targetFieldName.capitalize()}"
val deleteMutationName = "delete${source.name}${targetFieldName.capitalize()}"
return when (field.name) {
addMutationName -> {
Cypher.createRelationship(source, targetFieldDefinition)
}
deleteMutationName -> {
Cypher.deleteRelationship(source, targetFieldDefinition)
}
else -> throw IllegalArgumentException("Unknown Mutation ${sourceFieldDefinition.name}")
}.let {
it.copy(query = it.query + returnStatement, params = properties)
}
}

private fun filterTarget(source: GraphQLObjectType, field: Field, graphQLFieldDefinition: GraphQLFieldDefinition): GraphQLFieldDefinition {
return source.fieldDefinitions
.filter {
it.name.isNotBlank() && (field.name == "add${source.name}${it.name.capitalize()}"
|| field.name == "delete${source.name}${it.name.capitalize()}")
}
.map { it }
.firstOrNull() ?: throw IllegalArgumentException("Unknown Mutation ${graphQLFieldDefinition.name}")
}

private fun cypherQueryOrMutation(variable: String, fieldDefinition: GraphQLFieldDefinition, field: Field, cypherDirective: Cypher, mapProjection: Cypher, ordering: String, skipLimit: String, isQuery: Boolean) =
if (isQuery) {
val (query, params) = cypherDirective(variable, fieldDefinition, field, cypherDirective, emptyList())
Expand Down Expand Up @@ -403,7 +464,12 @@ object SchemaBuilder {
"""
typeDefinitionRegistry.merge(schemaParser.parse(directivesSdl))

val augmentations = typeDefinitionRegistry.types().values.filterIsInstance<ObjectTypeDefinition>().map { augmentedSchema(ctx, it) }
val objectTypeDefinitions = typeDefinitionRegistry.types().values.filterIsInstance<ObjectTypeDefinition>()
val nodeMutations = objectTypeDefinitions.map { createNodeMutation(ctx, it) }
val relMutations = objectTypeDefinitions.flatMap { source ->
createRelationshipMutations(source, objectTypeDefinitions, ctx)
}
val augmentations = nodeMutations + relMutations

val augmentedTypesSdl = augmentations.flatMap { listOf(it.filterType, it.ordering, it.inputType).filter { it.isNotBlank() } }.joinToString("\n")
typeDefinitionRegistry.merge(schemaParser.parse(augmentedTypesSdl))
Expand Down Expand Up @@ -454,4 +520,17 @@ object SchemaBuilder {

return typeDefinitionRegistry
}

private fun createRelationshipMutations(source: ObjectTypeDefinition,
objectTypeDefinitions: List<ObjectTypeDefinition>,
ctx: Translator.Context): List<Augmentation> = source.fieldDefinitions
.filter { !it.type.inner().isScalar() && it.getDirective("relation") != null }
.mapNotNull { targetField ->
objectTypeDefinitions
.filter { it.name == targetField.type.inner().name() }
.firstOrNull()
?.let { target ->
createRelationshipMutation(ctx, source, target)
}
}
}