Skip to content

Commit

Permalink
Merge a4caba9 into 3281eed
Browse files Browse the repository at this point in the history
  • Loading branch information
bojanserafimov committed Mar 14, 2019
2 parents 3281eed + a4caba9 commit 7c72c3f
Show file tree
Hide file tree
Showing 3 changed files with 334 additions and 26 deletions.
267 changes: 264 additions & 3 deletions graphql_compiler/macros/macro_edge/helpers.py
@@ -1,10 +1,14 @@
# Copyright 2019-present Kensho Technologies, LLC.
from copy import copy

from graphql.language.ast import Field, InlineFragment, OperationDefinition, SelectionSet
from graphql.language.ast import (
Argument, Field, InlineFragment, ListValue, Name, OperationDefinition, SelectionSet,
StringValue
)

from ...ast_manipulation import get_ast_field_name
from ...compiler.helpers import get_field_type_from_schema, get_vertex_field_type
from ...schema import FilterDirective, TagDirective
from ..macro_edge.directives import MacroEdgeTargetDirective


Expand Down Expand Up @@ -53,6 +57,262 @@ def get_directives_for_ast(ast):
return result


def get_all_tag_names(ast):
"""Return a set of strings containing tag names that appear in the query.
Args:
ast: GraphQL query AST object
"""
return {
# Schema validation has ensured this exists
directive.arguments[0].value.value
for ast, directive in _yield_ast_nodes_with_directives(ast)
if directive.name.value == TagDirective.name
}


def _remove_colocated_tags_at_ast_node(non_macro_names, ast):
"""Return an AST node with at most one tag per field by removing tags.
Args:
non_macro_names: set of tag names that the user wrote. We prefer keeping these names.
ast: GraphQL query AST object that potentially has multiple colocated tags. This AST
should not contain any duplicate tags (different tags with the same name).
Returns:
tuple(made_changes, new_ast, name_change_map).
made_changes is a bool describing whether there were any tags to remove. new_ast is
the ast with at most one tag per field. name_change_map is a dict (string -> string)
that contains the new name for each tag name. Names of removed tags are mapped to
the name of the colocated tag that was not removed.
"""
new_directives = ast.directives
tag_name_list = [
directive.arguments[0].value.value
for directive in ast.directives
if directive.name.value == TagDirective.name
]
tag_names = set(tag_name_list)
if len(tag_name_list) != len(tag_names):
raise AssertionError(u'The ast should not contain multiple tags with '
u'the same name. {}'.format(tag_name_list))
made_changes = len(tag_names) > 1
name_change_map = dict()
if tag_names:
# Find which tag to keep
name_to_use = None
user_specified_names = tag_names & non_macro_names
if len(user_specified_names) == 0:
name_to_use = min(tag_names, key=len)
elif len(user_specified_names) == 1:
name_to_use = next(iter(user_specified_names))
else:
raise AssertionError(u'Multiple tags on the same field are not allowed: {}'
.format(user_specified_names))

# Remove all tags but the one we decided to keep
name_change_map = {
name: name_to_use
for name in tag_names
}
new_directives = [
directive
for directive in ast.directives
if (directive.name.value != TagDirective.name or
directive.arguments[0].value.value == name_to_use)
]

return made_changes, new_directives, name_change_map


def _remove_colocated_tags(non_macro_names, ast):
"""Return an AST with at most one tag per field by removing tags.
See merge_colocated_tags in this module for an example usage.
Args:
non_macro_names: set of tag names that the user wrote. We prefer keeping these names.
ast: GraphQL query AST object that potentially has multiple colocated tags. This AST
should not contain any duplicate tags (different tags with the same name).
Returns:
tuple(new_ast, name_change_map). new_ast is the ast with at most one tag per field.
name_change_map is a dict (string -> string) that contains the new name for each
tag name. Names of removed tags are mapped to the name of the colocated tag that was
not removed.
"""
if not isinstance(ast, (Field, InlineFragment, OperationDefinition)):
return ast

made_changes = False
name_change_map = dict()

# Recurse
new_selection_set = None
if ast.selection_set is not None:
new_selections = []
for selection_ast in ast.selection_set.selections:
new_selection_ast, inner_name_change_map = _remove_colocated_tags(
non_macro_names, selection_ast)

name_collisions = set(name_change_map.keys()) & set(inner_name_change_map.keys())
if name_collisions:
raise AssertionError(u'The ast should not contain multiple tags with '
u'the same name. Found duplicate on names: {}'
.format(name_collisions))
name_change_map.update(inner_name_change_map)

if selection_ast is not new_selection_ast:
# Since we did not get the exact same object as the input, changes were made.
# That means this call will also need to make changes and return a new object.
made_changes = True

new_selections.append(new_selection_ast)
new_selection_set = SelectionSet(new_selections)

# Process the current node
made_changes_at_this_node, new_directives, new_name_change_map = \
_remove_colocated_tags_at_ast_node(non_macro_names, ast)
made_changes = made_changes or made_changes_at_this_node
name_change_map.update(new_name_change_map)

if not made_changes:
# We didn't change anything, return the original input object.
return ast, name_change_map

new_ast = copy(ast)
new_ast.selection_set = new_selection_set
new_ast.directives = new_directives
return new_ast, name_change_map


def _replace_tag_names_at_current_node(name_change_map, ast):
"""Replace tag names that are already in use at the root of the AST."""
# Rename tag names in @tag and @filter directives, and record if we made changes
made_changes = False
new_directives = []
for directive in ast.directives:
if directive.name.value == TagDirective.name:
current_name = directive.arguments[0].value.value
new_name = name_change_map[current_name]
if new_name != current_name:
made_changes = True
renamed_tag_directive = copy(directive)
renamed_tag_directive.arguments = [Argument(Name('tag_name'), StringValue(new_name))]
new_directives.append(renamed_tag_directive)
elif directive.name.value == FilterDirective.name:
filter_with_renamed_args = copy(directive)
filter_with_renamed_args.arguments = []
for argument in directive.arguments:
if argument.name.value == 'op_name':
filter_with_renamed_args.arguments.append(argument)
elif argument.name.value == 'value':
new_value_list = []
for value in argument.value.values:
if value.value.startswith('%'):
current_name = value.value[1:]
new_name = name_change_map[current_name]
if new_name != current_name:
made_changes = True
new_value_list.append(StringValue('%' + new_name))
else:
new_value_list.append(value)
filter_with_renamed_args.arguments.append(
Argument(Name('value'), value=ListValue(new_value_list)))
else:
raise AssertionError(u'Unknown argument name {} in filter'
.format(argument.name.value))
new_directives.append(filter_with_renamed_args)
else:
new_directives.append(directive)
return made_changes, new_directives


def replace_tag_names(name_change_map, ast):
"""Replace tag names that are already in use in the whole AST."""
if not isinstance(ast, (Field, InlineFragment, OperationDefinition)):
return ast

made_changes = False

# Recurse
new_selection_set = None
if ast.selection_set is not None:
new_selections = []
for selection_ast in ast.selection_set.selections:
new_selection_ast = replace_tag_names(name_change_map, selection_ast)

if selection_ast is not new_selection_ast:
# Since we did not get the exact same object as the input, changes were made.
# That means this call will also need to make changes and return a new object.
made_changes = True

new_selections.append(new_selection_ast)
new_selection_set = SelectionSet(new_selections)

# Process the current node
made_changes_at_this_node, new_directives = _replace_tag_names_at_current_node(
name_change_map, ast)
made_changes = made_changes or made_changes_at_this_node

if not made_changes:
# We didn't change anything, return the original input object.
return ast

new_ast = copy(ast)
new_ast.selection_set = new_selection_set
new_ast.directives = new_directives
return new_ast


def merge_colocated_tags(non_macro_names, ast):
"""Return an AST with at most one tag per field by removing tags and renaming their uses.
If both the macro and the user of a macro tagged the same field, the resulting query after
expansion will have two tags on the same field. We use this function to simplify such
queries.
Filters that use the values of the removed tags will instead use the value of a different
tag that was on the same field and not removed.
Args:
non_macro_names: set of tag names that the user wrote. We prefer keeping these names.
ast: GraphQL query AST object that potentially has multiple colocated tags. This AST
should not contain any duplicate tags (different tags with the same name).
Returns:
tuple(new_ast, name_change_map). new_ast is the ast with at most one tag per field.
name_change_map is a dict (string -> string) that contains the new name for each
tag name. Names of removed tags are mapped to the name of the colocated tag that was
not removed.
"""
deduplicated_ast, name_change_map = _remove_colocated_tags(non_macro_names, ast)
return replace_tag_names(name_change_map, deduplicated_ast)


def generate_disambiguations(existing_names, new_names):
"""Return a dict mapping the new names to similar names not conflicting with existing names.
Args:
existing_names: set of strings, the names that are already taken
new_names: set of strings, the names that might coincide with exisitng names
Returns:
dict mapping the new names to other unique names not present in existing_names
"""
name_mapping = dict()
for name in new_names:
# We try adding different suffixes to disambiguate from the existing names. There will
# be no collisions among the disambiguations because they will all have unique prefixes.
disambiguation = name
index = 0
while disambiguation in existing_names:
disambiguation = disambiguation + '_copy_' + str(index)
index += 1
name_mapping[name] = disambiguation
return name_mapping


def remove_directives_from_ast(ast, directive_names_to_omit):
"""Return an equivalent AST to the input, but with instances of the named directives omitted.
Expand All @@ -70,7 +330,7 @@ def remove_directives_from_ast(ast, directive_names_to_omit):

made_changes = False

new_selections = None
new_selection_set = None
if ast.selection_set is not None:
new_selections = []
for selection_ast in ast.selection_set.selections:
Expand All @@ -82,6 +342,7 @@ def remove_directives_from_ast(ast, directive_names_to_omit):
made_changes = True

new_selections.append(new_selection_ast)
new_selection_set = SelectionSet(new_selections)

directives_to_keep = [
directive
Expand All @@ -96,7 +357,7 @@ def remove_directives_from_ast(ast, directive_names_to_omit):
return ast

new_ast = copy(ast)
new_ast.selection_set = SelectionSet(new_selections)
new_ast.selection_set = new_selection_set
new_ast.directives = directives_to_keep
return new_ast

Expand Down

0 comments on commit 7c72c3f

Please sign in to comment.