In [33]:
from copy import copy, deepcopy

from graphql import build_ast_schema, parse, print_ast, validate
from graphql.language.ast import (
    Document, Field, FragmentDefinition, FragmentSpread, InlineFragment, OperationDefinition, 
    SelectionSet
)
from graphql.language.visitor import Visitor, visit
from pprint import pprint

from graphql_compiler.tests.test_helpers import SCHEMA_TEXT

In [3]:
# Build a schema object from the test-suite schema text; the query below is validated against it.
schema = build_ast_schema(parse(SCHEMA_TEXT))

In [44]:
# Example query: the `commonFields` fragment is spread between two plain fields of
# `out_Animal_ParentOf`, and the fragment itself contains a field with its own selection set
# (`out_Entity_Related`), so expansion must both inline the fragment and reorder selections.
query = '''
{
    Animal {
        name @output(out_name: "parent_name")
        
        out_Animal_ParentOf {
            uuid @output(out_name: "child_id")
            ...commonFields
            color @output(out_name: "child_color")
        }
    }
}

fragment commonFields on Animal {
    name @output(out_name: "animal_name")
    birthday @output(out_name: "animal_birthday")
    
    out_Entity_Related {
        name @output(out_name: "related_entity")
    }
}
'''
parsed_ast = parse(query)
validation_errors = validate(schema, parsed_ast)
# Stop here on an invalid query: fragment expansion below assumes every spread names a defined
# fragment, so running it on an invalid document would fail confusingly or silently.
assert not validation_errors, validation_errors
validation_errors

[]

In [45]:
# Dump the raw definition nodes (the query operation and the fragment definition) to inspect the
# AST structure that the fragment expansion below operates on.
pprint(parsed_ast.definitions)

[OperationDefinition(operation='query', name=None, variable_definitions=None, directives=[], selection_set=SelectionSet(selections=[Field(alias=None, name=Name(value='Animal'), arguments=[], directives=[], selection_set=SelectionSet(selections=[Field(alias=None, name=Name(value='name'), arguments=[], directives=[Directive(name=Name(value='output'), arguments=[Argument(name=Name(value='out_name'), value=StringValue(value='parent_name'))])], selection_set=None), Field(alias=None, name=Name(value='out_Animal_ParentOf'), arguments=[], directives=[], selection_set=SelectionSet(selections=[Field(alias=None, name=Name(value='uuid'), arguments=[], directives=[Directive(name=Name(value='output'), arguments=[Argument(name=Name(value='out_name'), value=StringValue(value='child_id'))])], selection_set=None), FragmentSpread(name=Name(value='commonFields'), directives=[]), Field(alias=None, name=Name(value='color'), arguments=[], directives=[Directive(name=Name(value='output'), arguments=[Argument(n

In [46]:
def _get_query_and_fragment_definitions(query_ast):
    """Split a parsed GraphQL document into its single query operation and its fragments.

    Args:
        query_ast: Document AST (as returned by `parse`) expected to contain exactly one query
                   operation plus any number of fragment definitions.

    Returns:
        tuple (query_operation_ast, fragment_selections):
            - query_operation_ast: the OperationDefinition node of the query.
            - fragment_selections: dict mapping each fragment's name to the list of selection
              ASTs in that fragment's selection set.

    Raises:
        AssertionError: if the document contains a non-query operation, zero or more than one
                        query operation, or a definition that is neither an operation nor a
                        fragment.
    """
    query_operation_ast = None
    fragment_selections = {}

    for definition_ast in query_ast.definitions:
        if isinstance(definition_ast, OperationDefinition):
            # Explicit raises rather than `assert` so these checks are not stripped by `python -O`.
            if definition_ast.operation != 'query':
                raise AssertionError('Expected only query operations, got a {} operation: {}'
                                     .format(definition_ast.operation, definition_ast))
            if query_operation_ast is not None:
                # Previously the last query silently won; make the ambiguity an error instead.
                raise AssertionError('Expected exactly one query operation, found more than one.')
            query_operation_ast = definition_ast
        elif isinstance(definition_ast, FragmentDefinition):
            fragment_name = definition_ast.name.value
            fragment_selections[fragment_name] = definition_ast.selection_set.selections
        else:
            raise AssertionError('Unexpected definition in query document: {}'
                                 .format(definition_ast))

    if query_operation_ast is None:
        # Without this, callers would pass None on to `visit` and fail far from the cause.
        raise AssertionError('Expected exactly one query operation, found none.')

    return query_operation_ast, fragment_selections


class FragmentSpreadExpansionVisitor(Visitor):
    """Visitor that inlines named fragment spreads into the fields that contain them.

    For every Field whose selection set contains a FragmentSpread, each spread is replaced by
    deep copies of that fragment's selections. This is done recursively, so a fragment that
    directly spreads another fragment is fully inlined. The resulting selections are then stably
    reordered so that selections without their own selection set come before those with one.

    Spreads that are not directly inside a Field's selection set (e.g. directly inside an inline
    fragment, or at the operation's top level) are left untouched.
    """

    def __init__(self, fragment_selections):
        """Create the visitor.

        Args:
            fragment_selections: dict mapping fragment name -> list of selection ASTs of that
                                 fragment, as produced by _get_query_and_fragment_definitions.
        """
        # Name this class (not `Visitor`) so Visitor's own initializer is not skipped in the MRO.
        super(FragmentSpreadExpansionVisitor, self).__init__()
        self.fragment_selections = fragment_selections

    def enter_Field(self, node, *args):
        """Return a copy of `node` with its fragment spreads inlined, or None if it has none.

        Returning None leaves the node as-is; returning a node replaces it in the visited tree.
        """
        if node.selection_set is None:
            return None

        selections = node.selection_set.selections
        if not any(isinstance(selection_ast, FragmentSpread) for selection_ast in selections):
            return None

        # TODO(predrag): Handle fragment field merging. We probably want the same
        #                logic as https://github.com/graphql/graphql-spec/issues/399
        #                The current code will break if both the query definition
        #                and the fragment use the same edge.
        new_selections = self._expand_selections(selections, frozenset())

        # Stable sort: selections without their own selection set keep their relative order and
        # move ahead of the selections that have one.
        new_selections.sort(key=lambda selection_ast: selection_ast.selection_set is not None)

        new_field = copy(node)
        new_field.selection_set = SelectionSet(new_selections)
        return new_field

    def _expand_selections(self, selections, fragments_in_progress):
        """Return a new list of `selections` with every FragmentSpread replaced by its selections.

        Args:
            selections: list of selection ASTs; the list itself is not modified.
            fragments_in_progress: frozenset of fragment names currently being expanded along
                                   this chain of spreads, used to detect cycles.

        Raises:
            AssertionError: if a fragment is spread again while it is already being expanded
                            along the same chain of directly-nested spreads.
            KeyError: if a spread names a fragment not present in `self.fragment_selections`.

        NOTE: cycles that pass through a nested field (e.g. `fragment A { x { ...A } }`) are not
        caught here; they are expected to be rejected by GraphQL validation beforehand.
        """
        expanded = []
        for selection_ast in selections:
            if not isinstance(selection_ast, FragmentSpread):
                expanded.append(selection_ast)
                continue

            fragment_name = selection_ast.name.value
            if fragment_name in fragments_in_progress:
                raise AssertionError('Fragment cycle detected involving fragment {}'
                                     .format(fragment_name))

            # Deep-copy so each spread site gets its own AST nodes, independent of the fragment
            # definition and of any other site where the same fragment is spread.
            fragment_copy = deepcopy(self.fragment_selections[fragment_name])
            expanded.extend(self._expand_selections(
                fragment_copy, fragments_in_progress | frozenset([fragment_name])))
        return expanded
            

def expand_fragment_spreads(query_ast):
    """Inline every fragment spread of a parsed query document.

    Args:
        query_ast: Document AST (as returned by `parse`) holding one query operation and any
                   fragment definitions it uses.

    Returns:
        the result of visiting the query's OperationDefinition with
        FragmentSpreadExpansionVisitor, i.e. the operation with its fragment spreads inlined.
        Fragment definitions are not part of the returned AST.
    """
    operation_ast, selections_by_fragment_name = _get_query_and_fragment_definitions(query_ast)
    return visit(operation_ast, FragmentSpreadExpansionVisitor(selections_by_fragment_name))

In [47]:
# Inline the `commonFields` fragment into the query and print the result back as GraphQL text.
expanded_ast = expand_fragment_spreads(parsed_ast)
print(print_ast(expanded_ast))

{
  Animal {
    name @output(out_name: "parent_name")
    out_Animal_ParentOf {
      uuid @output(out_name: "child_id")
      name @output(out_name: "animal_name")
      birthday @output(out_name: "animal_birthday")
      color @output(out_name: "child_color")
      out_Entity_Related {
        name @output(out_name: "related_entity")
      }
    }
  }
}
