diff --git a/graphql_compiler/compiler/compiler_frontend.py b/graphql_compiler/compiler/compiler_frontend.py index 2928e540a..9a7883d36 100644 --- a/graphql_compiler/compiler/compiler_frontend.py +++ b/graphql_compiler/compiler/compiler_frontend.py @@ -964,7 +964,7 @@ def graphql_to_ir(schema, graphql_string, type_equivalence_hints=None): - ir_blocks: a list of IR basic block objects - input_metadata: a dict of expected input parameters (string) -> inferred GraphQL type - output_metadata: a dict of output name (string) -> OutputMetadata object - - location_types: a dict of location objects -> GraphQL type objects at that location + - query_metadata_table: a QueryMetadataTable object containing location metadata Raises flavors of GraphQLError in the following cases: - if the query is invalid GraphQL (GraphQLParsingError); diff --git a/graphql_compiler/compiler/metadata.py b/graphql_compiler/compiler/metadata.py index ecc761fc0..2c24d6137 100644 --- a/graphql_compiler/compiler/metadata.py +++ b/graphql_compiler/compiler/metadata.py @@ -44,6 +44,18 @@ def __init__(self, root_location, root_location_info): self._inputs = dict() # dict, input name -> input info namedtuple self._outputs = dict() # dict, output name -> output info namedtuple self._tags = dict() # dict, tag name -> tag info namedtuple + + # dict, revisiting Location -> revisit origin, i.e. the first Location with that query path + self._revisit_origins = dict() + + # dict, revisit origin Location -> set of Locations for which + # that Location is the revisit origin + self._revisits = dict() + + # dict, Location/FoldScopeLocation -> set of Location and FoldScopeLocation objects + # that are directly descended from it + self._child_locations = dict() + self.register_location(root_location, root_location_info) @property @@ -72,6 +84,8 @@ def register_location(self, location, location_info): raise AssertionError(u'All locations other than the root location and its revisits ' u'must have a parent location, but received a location with ' u'no parent: {} {}'.format(location, location_info)) + else: + self._child_locations.setdefault(location_info.parent_location, set()).add(location) self._locations[location] = location_info @@ -84,6 +98,13 @@ def revisit_location(self, location): # might still be holding on to the original info object, therefore registering stale data. # This function ensures that the latest metadata on the location is always used instead. revisited_location = location.revisit() + + # If "location" is itself a revisit, then we point "revisited_location" to "location"'s + # revisit origin. If "location" is not a revisit, then it itself is the revisit origin. + revisit_origin = self._revisit_origins.get(location, location) + self._revisit_origins[revisited_location] = revisit_origin + self._revisits.setdefault(revisit_origin, set()).add(revisited_location) + self.register_location(revisited_location, self.get_location_info(location)) return revisited_location @@ -111,6 +132,34 @@ def get_location_info(self, location): u'{}'.format(location)) return location_info + def get_child_locations(self, location): + """Yield an iterable of child locations for a given Location/FoldScopeLocation object.""" + self.get_location_info(location) # purely to check for location validity + + for child_location in self._child_locations.get(location, []): + yield child_location + + def get_all_revisits(self, location): + """Yield an iterable of locations that revisit that location or another of its revisits.""" + self.get_location_info(location) # purely to check for location validity + + for revisit_location in self._revisits.get(location, []): + yield revisit_location + + def get_revisit_origin(self, location): + """Return the original location that this location revisits, or None if it isn't a revisit. + + Args: + location: Location/FoldScopeLocation object whose revisit origin to get + + Returns: + Location object representing the first location with the same query path as the given + location. Returns the given location itself if that location is the first one with + that query path. Guaranteed to return the input location if it is a FoldScopeLocation. + """ + self.get_location_info(location) # purely to check for location validity + return self._revisit_origins.get(location, location) + @property def registered_locations(self): """Return an iterable of (location, location_info) tuples for all registered locations.""" diff --git a/graphql_compiler/tests/test_ir_generation.py b/graphql_compiler/tests/test_ir_generation.py index 575032676..f45acd7f7 100644 --- a/graphql_compiler/tests/test_ir_generation.py +++ b/graphql_compiler/tests/test_ir_generation.py @@ -35,6 +35,19 @@ def check_test_data(test_case, test_data, expected_blocks, expected_location_typ expected_location_types, get_comparable_location_types(compilation_results.query_metadata_table)) + all_child_locations, revisits = compute_child_and_revisit_locations(expected_blocks) + for parent_location, child_locations in six.iteritems(all_child_locations): + for child_location in child_locations: + child_info = compilation_results.query_metadata_table.get_location_info(child_location) + test_case.assertEqual(parent_location, child_info.parent_location) + + test_case.assertEqual( + all_child_locations, + get_comparable_child_locations(compilation_results.query_metadata_table)) + test_case.assertEqual( + revisits, + get_comparable_revisits(compilation_results.query_metadata_table)) + def get_comparable_location_types(query_metadata_table): """Return the dict of location -> GraphQL type name for each location in the query.""" @@ -44,6 +57,131 @@ def get_comparable_location_types(query_metadata_table): } +def get_comparable_child_locations(query_metadata_table): + """Return the dict of location -> set of child locations for each location in the query.""" + all_locations_with_possible_children = { + location: set(query_metadata_table.get_child_locations(location)) + for location, _ in query_metadata_table.registered_locations + } + return { + location: child_locations + for location, child_locations in six.iteritems(all_locations_with_possible_children) + if child_locations + } + + +def get_comparable_revisits(query_metadata_table): + """Return a dict location -> set of revisit locations for that starting location.""" + revisit_origins = { + query_metadata_table.get_revisit_origin(location) + for location, _ in query_metadata_table.registered_locations + } + + intermediate_result = { + location: set(query_metadata_table.get_all_revisits(location)) + for location in revisit_origins + } + + return { + location: revisits + for location, revisits in six.iteritems(intermediate_result) + if revisits + } + + +def compute_child_and_revisit_locations(ir_blocks): + """Return dicts describing the parent-child and revisit relationships for all query locations. + + Args: + ir_blocks: list of IR blocks describing the given query + + Returns: + tuple of: + dict mapping parent location -> set of child locations (guaranteed to be non-empty) + dict mapping revisit origin -> set of revisits (possibly empty) + """ + if not ir_blocks: + raise AssertionError(u'Unexpectedly received empty ir_blocks: {}'.format(ir_blocks)) + + first_block = ir_blocks[0] + if not isinstance(first_block, blocks.QueryRoot): + raise AssertionError(u'Unexpectedly, the first IR block was not a QueryRoot: {} {}' + .format(first_block, ir_blocks)) + + # These block types do not affect the computed location structure. + no_op_block_types = ( + blocks.Filter, + blocks.ConstructResult, + blocks.EndOptional, + blocks.OutputSource, + blocks.CoerceType, + ) + + current_location = None + traversed_or_recursed_or_folded = False + fold_started_at = None + + top_level_locations = set() + parent_location = dict() # location -> parent location + child_locations = dict() # location -> set of child locations + revisits = dict() # location -> set of revisit locations + query_path_to_revisit_origin = dict() # location query path -> its revisit origin + + # Walk the IR blocks and reconstruct the query's location structure. + for block in ir_blocks[1:]: + if isinstance(block, (blocks.Traverse, blocks.Fold, blocks.Recurse)): + traversed_or_recursed_or_folded = True + if isinstance(block, blocks.Fold): + fold_started_at = current_location + elif isinstance(block, blocks.Unfold): + current_location = fold_started_at + elif isinstance(block, blocks.MarkLocation): + # Handle optional traversals and backtracks, due to the fact that + # they might drop MarkLocations before and after themselves. + if traversed_or_recursed_or_folded: + block_parent_location = current_location + else: + block_parent_location = parent_location.get(current_location, None) + + if block_parent_location is not None: + parent_location[block.location] = block_parent_location + child_locations.setdefault(block_parent_location, set()).add(block.location) + else: + top_level_locations.add(current_location) + + current_location = block.location + + if isinstance(current_location, helpers.FoldScopeLocation): + revisit_origin = None + elif isinstance(current_location, helpers.Location): + if current_location.query_path not in query_path_to_revisit_origin: + query_path_to_revisit_origin[current_location.query_path] = current_location + revisit_origin = None + else: + revisit_origin = query_path_to_revisit_origin[current_location.query_path] + else: + raise AssertionError(u'Unreachable state reached: {} {}' + .format(current_location, ir_blocks)) + + if revisit_origin is not None: + revisits.setdefault(revisit_origin, set()).add(current_location) + + traversed_or_recursed_or_folded = False + elif isinstance(block, blocks.Backtrack): + current_location = block.location + elif isinstance(block, no_op_block_types): + # These blocks do not affect the computed location structure. + pass + elif isinstance(block, blocks.QueryRoot): + raise AssertionError(u'Unexpectedly encountered a second QueryRoot after the first ' + u'IR block: {} {}'.format(block, ir_blocks)) + else: + raise AssertionError(u'Unexpected block type encountered: {} {}' + .format(block, ir_blocks)) + + return child_locations, revisits + + class IrGenerationTests(unittest.TestCase): """Ensure valid inputs produce correct IR."""