Skip to content

Commit

Permalink
Merge 1bfa521 into 87bcf4d
Browse files Browse the repository at this point in the history
  • Loading branch information
obi1kenobi committed Oct 1, 2018
2 parents 87bcf4d + 1bfa521 commit 97f48cc
Show file tree
Hide file tree
Showing 3 changed files with 188 additions and 1 deletion.
2 changes: 1 addition & 1 deletion graphql_compiler/compiler/compiler_frontend.py
Original file line number Diff line number Diff line change
Expand Up @@ -969,7 +969,7 @@ def graphql_to_ir(schema, graphql_string, type_equivalence_hints=None):
- ir_blocks: a list of IR basic block objects
- input_metadata: a dict of expected input parameters (string) -> inferred GraphQL type
- output_metadata: a dict of output name (string) -> OutputMetadata object
- location_types: a dict of location objects -> GraphQL type objects at that location
- query_metadata_table: a QueryMetadataTable object containing location metadata
Raises flavors of GraphQLError in the following cases:
- if the query is invalid GraphQL (GraphQLParsingError);
Expand Down
49 changes: 49 additions & 0 deletions graphql_compiler/compiler/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,18 @@ def __init__(self, root_location, root_location_info):
self._inputs = dict() # dict, input name -> input info namedtuple
self._outputs = dict() # dict, output name -> output info namedtuple
self._tags = dict() # dict, tag name -> tag info namedtuple

# dict, revisiting Location -> revisit origin, i.e. the first Location with that query path
self._revisit_origins = dict()

# dict, revisit origin Location -> set of Locations for which
# that Location is the revisit origin
self._revisits = dict()

# dict, Location/FoldScopeLocation -> set of Location and FoldScopeLocation objects
# that are directly descended from it
self._child_locations = dict()

self.register_location(root_location, root_location_info)

@property
Expand Down Expand Up @@ -72,6 +84,8 @@ def register_location(self, location, location_info):
raise AssertionError(u'All locations other than the root location and its revisits '
u'must have a parent location, but received a location with '
u'no parent: {} {}'.format(location, location_info))
else:
self._child_locations.setdefault(location_info.parent_location, set()).add(location)

self._locations[location] = location_info

Expand All @@ -84,6 +98,13 @@ def revisit_location(self, location):
# might still be holding on to the original info object, therefore registering stale data.
# This function ensures that the latest metadata on the location is always used instead.
revisited_location = location.revisit()

# If "location" is itself a revisit, then we point "revisited_location" to "location"'s
# revisit origin. If "location" is not a revisit, then it itself is the revisit origin.
revisit_origin = self._revisit_origins.get(location, location)
self._revisit_origins[revisited_location] = revisit_origin
self._revisits.setdefault(revisit_origin, set()).add(revisited_location)

self.register_location(revisited_location, self.get_location_info(location))
return revisited_location

Expand Down Expand Up @@ -111,6 +132,34 @@ def get_location_info(self, location):
u'{}'.format(location))
return location_info

def get_child_locations(self, location):
"""Yield an iterable of child locations for a given Location/FoldScopeLocation object."""
self.get_location_info(location) # purely to check for location validity

for child_location in self._child_locations.get(location, []):
yield child_location

def get_all_revisits(self, location):
"""Yield an iterable of locations that revisit that location or another of its revisits."""
self.get_location_info(location) # purely to check for location validity

for revisit_location in self._revisits.get(location, []):
yield revisit_location

def get_revisit_origin(self, location):
"""Return the original location that this location revisits, or None if it isn't a revisit.
Args:
location: Location/FoldScopeLocation object whose revisit origin to get
Returns:
Location object representing the first location with the same query path as the given
location. Returns the given location itself if that location is the first one with
that query path. Guaranteed to return the input location if it is a FoldScopeLocation.
"""
self.get_location_info(location) # purely to check for location validity
return self._revisit_origins.get(location, location)

@property
def registered_locations(self):
"""Return an iterable of (location, location_info) tuples for all registered locations."""
Expand Down
138 changes: 138 additions & 0 deletions graphql_compiler/tests/test_ir_generation.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,19 @@ def check_test_data(test_case, test_data, expected_blocks, expected_location_typ
expected_location_types,
get_comparable_location_types(compilation_results.query_metadata_table))

all_child_locations, revisits = compute_child_and_revisit_locations(expected_blocks)
for parent_location, child_locations in six.iteritems(all_child_locations):
for child_location in child_locations:
child_info = compilation_results.query_metadata_table.get_location_info(child_location)
test_case.assertEqual(parent_location, child_info.parent_location)

test_case.assertEqual(
all_child_locations,
get_comparable_child_locations(compilation_results.query_metadata_table))
test_case.assertEqual(
revisits,
get_comparable_revisits(compilation_results.query_metadata_table))


def get_comparable_location_types(query_metadata_table):
"""Return the dict of location -> GraphQL type name for each location in the query."""
Expand All @@ -44,6 +57,131 @@ def get_comparable_location_types(query_metadata_table):
}


def get_comparable_child_locations(query_metadata_table):
"""Return the dict of location -> set of child locations for each location in the query."""
all_locations_with_possible_children = {
location: set(query_metadata_table.get_child_locations(location))
for location, _ in query_metadata_table.registered_locations
}
return {
location: child_locations
for location, child_locations in six.iteritems(all_locations_with_possible_children)
if child_locations
}


def get_comparable_revisits(query_metadata_table):
"""Return a dict location -> set of revisit locations for that starting location."""
revisit_origins = {
query_metadata_table.get_revisit_origin(location)
for location, _ in query_metadata_table.registered_locations
}

intermediate_result = {
location: set(query_metadata_table.get_all_revisits(location))
for location in revisit_origins
}

return {
location: revisits
for location, revisits in six.iteritems(intermediate_result)
if revisits
}


def compute_child_and_revisit_locations(ir_blocks):
"""Return dicts describing the parent-child and revisit relationships for all query locations.
Args:
ir_blocks: list of IR blocks describing the given query
Returns:
tuple of:
dict mapping parent location -> set of child locations (guaranteed to be non-empty)
dict mapping revisit origin -> set of revisits (possibly empty)
"""
if not ir_blocks:
raise AssertionError(u'Unexpectedly received empty ir_blocks: {}'.format(ir_blocks))

first_block = ir_blocks[0]
if not isinstance(first_block, blocks.QueryRoot):
raise AssertionError(u'Unexpectedly, the first IR block was not a QueryRoot: {} {}'
.format(first_block, ir_blocks))

# These block types do not affect the computed location structure.
no_op_block_types = (
blocks.Filter,
blocks.ConstructResult,
blocks.EndOptional,
blocks.OutputSource,
blocks.CoerceType,
)

current_location = None
traversed_or_recursed_or_folded = False
fold_started_at = None

top_level_locations = set()
parent_location = dict() # location -> parent location
child_locations = dict() # location -> set of child locations
revisits = dict() # location -> set of revisit locations
query_path_to_revisit_origin = dict() # location query path -> its revisit origin

# Walk the IR blocks and reconstruct the query's location structure.
for block in ir_blocks[1:]:
if isinstance(block, (blocks.Traverse, blocks.Fold, blocks.Recurse)):
traversed_or_recursed_or_folded = True
if isinstance(block, blocks.Fold):
fold_started_at = current_location
elif isinstance(block, blocks.Unfold):
current_location = fold_started_at
elif isinstance(block, blocks.MarkLocation):
# Handle optional traversals and backtracks, due to the fact that
# they might drop MarkLocations before and after themselves.
if traversed_or_recursed_or_folded:
block_parent_location = current_location
else:
block_parent_location = parent_location.get(current_location, None)

if block_parent_location is not None:
parent_location[block.location] = block_parent_location
child_locations.setdefault(block_parent_location, set()).add(block.location)
else:
top_level_locations.add(current_location)

current_location = block.location

if isinstance(current_location, helpers.FoldScopeLocation):
revisit_origin = None
elif isinstance(current_location, helpers.Location):
if current_location.query_path not in query_path_to_revisit_origin:
query_path_to_revisit_origin[current_location.query_path] = current_location
revisit_origin = None
else:
revisit_origin = query_path_to_revisit_origin[current_location.query_path]
else:
raise AssertionError(u'Unreachable state reached: {} {}'
.format(current_location, ir_blocks))

if revisit_origin is not None:
revisits.setdefault(revisit_origin, set()).add(current_location)

traversed_or_recursed_or_folded = False
elif isinstance(block, blocks.Backtrack):
current_location = block.location
elif isinstance(block, no_op_block_types):
# These blocks do not affect the computed location structure.
pass
elif isinstance(block, blocks.QueryRoot):
raise AssertionError(u'Unexpectedly encountered a second QueryRoot after the first '
u'IR block: {} {}'.format(block, ir_blocks))
else:
raise AssertionError(u'Unexpected block type encountered: {} {}'
.format(block, ir_blocks))

return child_locations, revisits


class IrGenerationTests(unittest.TestCase):
"""Ensure valid inputs produce correct IR."""

Expand Down

0 comments on commit 97f48cc

Please sign in to comment.