Skip to content

Commit

Permalink
Capture child and parent location relationships in metadata.
Browse files Browse the repository at this point in the history
  • Loading branch information
obi1kenobi committed Oct 1, 2018
1 parent 1b2ae07 commit 20a2b3c
Show file tree
Hide file tree
Showing 3 changed files with 107 additions and 1 deletion.
2 changes: 1 addition & 1 deletion graphql_compiler/compiler/compiler_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -964,7 +964,7 @@ def graphql_to_ir(schema, graphql_string, type_equivalence_hints=None):
- ir_blocks: a list of IR basic block objects
- input_metadata: a dict of expected input parameters (string) -> inferred GraphQL type
- output_metadata: a dict of output name (string) -> OutputMetadata object
- location_types: a dict of location objects -> GraphQL type objects at that location
- query_metadata_table: a QueryMetadataTable object containing location metadata
Raises flavors of GraphQLError in the following cases:
- if the query is invalid GraphQL (GraphQLParsingError);
Expand Down
23 changes: 23 additions & 0 deletions graphql_compiler/compiler/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,6 +41,13 @@ def __init__(self, root_location, root_location_info):

self._root_location = root_location # Location, the root location of the entire query
self._locations = dict() # dict, Location/FoldScopeLocation -> LocationInfo
self._revisit_origins = dict() # dict, revisiting Location -> revisit origin, i.e.
# the first Location with that query path
self._revisits = dict() # dict, revisit origin Location -> set of Locations
# for which that Location is the revisit origin
self._child_locations = dict() # dict, Location/FoldScopeLocation -> set of Location
# and FoldScopeLocation objects that are directly
# descended from it
self._inputs = dict() # dict, input name -> input info namedtuple
self._outputs = dict() # dict, output name -> output info namedtuple
self._tags = dict() # dict, tag name -> tag info namedtuple
Expand Down Expand Up @@ -72,6 +79,8 @@ def register_location(self, location, location_info):
raise AssertionError(u'All locations other than the root location and its revisits '
u'must have a parent location, but received a location with '
u'no parent: {} {}'.format(location, location_info))
else:
self._child_locations.setdefault(location_info.parent_location, set()).add(location)

self._locations[location] = location_info

Expand All @@ -84,6 +93,13 @@ def revisit_location(self, location):
# might still be holding on to the original info object, therefore registering stale data.
# This function ensures that the latest metadata on the location is always used instead.
revisited_location = location.revisit()

# If "location" is itself a revisit, then we point "revisited_location" to "location"'s
# revisit origin. If "location" is not a revisit, then it itself is the revisit origin.
revisit_origin = self._revisit_origins.get(location, location)
self._revisit_origins[revisited_location] = revisit_origin
self._revisits.setdefault(revisit_origin, set()).add(revisited_location)

self.register_location(revisited_location, self.get_location_info(location))
return revisited_location

Expand Down Expand Up @@ -111,6 +127,13 @@ def get_location_info(self, location):
u'{}'.format(location))
return location_info

def get_child_locations(self, location):
"""Yield an iterable of child locations for a given Location/FoldScopeLocation object."""
self.get_location_info(location) # purely to check for location validity

for child_location in self._child_locations.get(location, []):
yield child_location

@property
def registered_locations(self):
"""Return an iterable of (location, location_info) tuples for all registered locations."""
Expand Down
83 changes: 83 additions & 0 deletions graphql_compiler/tests/test_ir_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,9 @@ def check_test_data(test_case, test_data, expected_blocks, expected_location_typ
test_case.assertEqual(
expected_location_types,
get_comparable_location_types(compilation_results.query_metadata_table))
test_case.assertEqual(
compute_child_locations(expected_blocks),
get_comparable_child_locations(compilation_results.query_metadata_table))


def get_comparable_location_types(query_metadata_table):
Expand All @@ -44,6 +47,86 @@ def get_comparable_location_types(query_metadata_table):
}


def get_comparable_child_locations(query_metadata_table):
"""Return the dict of location -> set of child locations for each location in the query."""
all_locations = {
location: set(query_metadata_table.get_child_locations(location))
for location, _ in query_metadata_table.registered_locations
}
return {
location: child_locations
for location, child_locations in six.iteritems(all_locations)
if child_locations
}


def compute_child_locations(ir_blocks):
"""Return a dict mapping parent location -> set of child locations, based on the IR blocks."""
if not ir_blocks:
raise AssertionError(u'Unexpectedly received empty ir_blocks: {}'.format(ir_blocks))

first_block = ir_blocks[0]
if not isinstance(first_block, blocks.QueryRoot):
raise AssertionError(u'Unexpectedly, the first IR block was not a QueryRoot: {} {}'
.format(first_block, ir_blocks))

# These block types do not affect the computed location structure.
no_op_block_types = (
blocks.Filter,
blocks.ConstructResult,
blocks.EndOptional,
blocks.OutputSource,
blocks.CoerceType,
)

current_location = None
traversed_or_recursed_or_folded = False
fold_started_at = None

top_level_locations = set()
parent_location = dict() # location -> parent location
child_locations = dict() # location -> set of child locations

# Walk the IR blocks and reconstruct the query's location structure.
for block in ir_blocks[1:]:
if isinstance(block, (blocks.Traverse, blocks.Fold, blocks.Recurse)):
traversed_or_recursed_or_folded = True
if isinstance(block, blocks.Fold):
fold_started_at = current_location
elif isinstance(block, blocks.Unfold):
current_location = fold_started_at
elif isinstance(block, blocks.MarkLocation):
# Handle optional traversals and backtracks, due to the fact that
# they might drop MarkLocations before and after themselves.
if traversed_or_recursed_or_folded:
block_parent_location = current_location
else:
block_parent_location = parent_location.get(current_location, None)

if block_parent_location is not None:
parent_location[block.location] = block_parent_location
child_locations.setdefault(block_parent_location, set()).add(block.location)
else:
top_level_locations.add(current_location)

current_location = block.location
traversed_or_recursed_or_folded = False
elif isinstance(block, blocks.Backtrack):
current_location = block.location
traversed_or_recursed_or_folded = False
elif isinstance(block, blocks.QueryRoot):
raise AssertionError(u'Unexpectedly encountered a second QueryRoot after the first '
u'IR block: {} {}'.format(block, ir_blocks))
elif isinstance(block, no_op_block_types):
# These blocks do not affect the computed location structure.
pass
else:
raise AssertionError(u'Unexpected block type encountered: {} {}'
.format(block, ir_blocks))

return child_locations


class IrGenerationTests(unittest.TestCase):
"""Ensure valid inputs produce correct IR."""

Expand Down

0 comments on commit 20a2b3c

Please sign in to comment.