In [1]:
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from itertools import chain, repeat
from pprint import pprint
from typing import (
    AbstractSet, Any, Callable, Collection, Dict, Generator, Generic, Iterable, List, Mapping,
    NamedTuple, Optional, Tuple, TypeVar, Union
)

from graphql import GraphQLList, GraphQLString, parse
from graphql.utils.build_ast_schema import build_ast_schema
from graphql_compiler.compiler.compiler_frontend import graphql_to_ir
from graphql_compiler.compiler.blocks import (
    Backtrack, CoerceType, ConstructResult, Filter, GlobalOperationsStart, MarkLocation, 
    QueryRoot, Traverse
)
from graphql_compiler.compiler.compiler_entities import BasicBlock, Expression
from graphql_compiler.compiler.expressions import (
    BinaryComposition, ContextField, Literal, LocalField, OutputContextField, Variable
)
from graphql_compiler.compiler.compiler_frontend import IrAndMetadata, graphql_to_ir
from graphql_compiler.compiler.helpers import Location, get_only_element_from_collection
from graphql_compiler.schema import GraphQLDate, GraphQLDateTime, GraphQLDecimal
from graphql_compiler.tests.test_helpers import SCHEMA_TEXT

In [2]:
# Lightweight record describing a single filter: the field it applies to, the name of
# the filter operation, and the value it compares against.
# NOTE(review): not referenced anywhere else in the visible notebook cells.
FilterInfo = namedtuple(
    'FilterInfo',
    ('field_name', 'op_name', 'value'),
)
# Type variable standing for whatever opaque "token" type an adapter uses to represent
# a vertex; DataContext and InterpreterAdapter are both generic over it.
DataToken = TypeVar('DataToken')


class DataContext(Generic[DataToken]):
    """Interpreter state carried along with one in-progress query result.

    Attributes:
        is_inactive: boolean flag stored as given; every construction site visible in
                     this notebook passes False.
        current_token: the token the context currently points at, or None.
        token_at_location: mapping from each marked Location to the token (or None)
                           that was current when that location was marked.
        expression_stack: scratch stack used while evaluating expressions and while
                          assembling the output row.
    """

    __slots__ = (
        'is_inactive',
        'current_token',
        'token_at_location',
        'expression_stack',
    )

    def __init__(
        self,
        is_inactive: bool,
        current_token: Optional[DataToken],
        token_at_location: Dict[Location, Optional[DataToken]],
        expression_stack: List[Any],
    ):
        """Store the given values as-is; no copies are made here."""
        self.is_inactive = is_inactive
        self.current_token = current_token
        self.token_at_location = token_at_location
        self.expression_stack = expression_stack

    @staticmethod
    def make_empty_context_from_token(token: DataToken) -> 'DataContext':
        """Create an active context at `token` with no marked locations and an empty stack."""
        return DataContext(
            is_inactive=False,
            current_token=token,
            token_at_location={},
            expression_stack=[],
        )

    def push_value_onto_stack(self, value: Any) -> 'DataContext':
        """Push `value` onto the expression stack and return this same context, for chaining."""
        self.expression_stack.append(value)
        return self

    def peek_value_on_stack(self) -> Any:
        """Return the top value of the expression stack without removing it."""
        top_index = -1
        return self.expression_stack[top_index]

    def pop_value_from_stack(self) -> Any:
        """Remove and return the top value of the expression stack."""
        return self.expression_stack.pop()

    def get_context_for_location(self, location: Location) -> 'DataContext':
        """Return a new context positioned at the token recorded for `location`.

        The new context shares this context's `token_at_location` mapping (same object)
        but receives its own shallow copy of the expression stack.
        """
        token_there = self.token_at_location[location]
        stack_copy = list(self.expression_stack)
        return DataContext(False, token_there, self.token_at_location, stack_copy)
        

class InterpreterAdapter(Generic[DataToken], metaclass=ABCMeta):
    """Abstract interface the interpreter uses to read data from a backing store.

    Every method accepts arbitrary keyword `hints`. The three `project_*` / `can_coerce_*`
    methods take an iterable of DataContext objects and return an iterable of
    (DataContext, result) pairs.
    """

    @abstractmethod
    def get_tokens_of_type(self, type_name: str, **hints) -> Iterable[DataToken]:
        """Return the tokens for all data of the type named `type_name`."""

    @abstractmethod
    def project_property(
        self, data_contexts: Iterable[DataContext], field_name: str, **hints
    ) -> Iterable[Tuple[DataContext, Any]]:
        """Pair each data context with the value of property `field_name`."""

    @abstractmethod
    def project_neighbors(
        self, data_contexts: Iterable[DataContext], direction: str, edge_name: str, **hints
    ) -> Iterable[Tuple[DataContext, Iterable[DataToken]]]:
        """Pair each data context with the tokens reached via `edge_name` in `direction`.

        Implementation warning: if the Iterable[DataToken] part is produced by a generator
        rather than a list, be careful -- generators are not closures! Any outside state
        pulled into the generator must not change afterwards, or the resulting bug will be
        hard to find. It is always safer to use a function to produce the generator, since
        that explicitly preserves all the external values passed into it.
        """

    @abstractmethod
    def can_coerce_to_type(
        self, data_contexts: Iterable[DataContext], type_name: str, **hints
    ) -> Iterable[Tuple[DataContext, bool]]:
        """Pair each data context with whether it can be coerced to `type_name`."""

In [3]:
def _apply_operator(operator: str, left_value: Any, right_value: Any) -> Any:
    if operator == '=':
        return left_value == right_value
    elif operator == 'contains':
        return right_value in left_value
    else:
        raise NotImplementedError()

In [4]:
def _push_values_onto_data_context_stack(
    contexts_and_values: Iterable[Tuple[DataContext, Any]]
) -> Iterable[DataContext]:
    """Lazily push each paired value onto its context's stack, yielding the contexts.

    The push happens only as the returned iterable is consumed.
    """
    return (
        context.push_value_onto_stack(paired_value)
        for context, paired_value in contexts_and_values
    )


def _evaluate_binary_composition(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    expression: Expression,
    data_contexts: Iterable[DataContext],
) -> Iterable[Tuple[DataContext, Any]]:
    """Evaluate a binary expression for each context, yielding (context, result) pairs.

    Both operand values are parked on the context's expression stack while they are being
    computed, then popped again before the operator is applied, so the stack is left as
    it was found.
    """
    contexts_with_left = _push_values_onto_data_context_stack(
        _evaluate_expression(adapter, query_arguments, expression.left, data_contexts)
    )
    contexts_with_both = _push_values_onto_data_context_stack(
        _evaluate_expression(adapter, query_arguments, expression.right, contexts_with_left)
    )

    for context in contexts_with_both:
        # N.B.: The left operand was pushed first, so it sits *below* the right operand.
        #       The pops must stay as two separate statements in this order; inlining
        #       them as call arguments would pop the values in the wrong order.
        rhs = context.pop_value_from_stack()
        lhs = context.pop_value_from_stack()
        yield (context, _apply_operator(expression.operator, lhs, rhs))

        
def _evaluate_local_field(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    expression: LocalField,
    data_contexts: Iterable[DataContext],
) -> Iterable[Tuple[DataContext, Any]]:
    """Ask the adapter for the expression's field on each context's current token.

    `query_arguments` is accepted only to match the common evaluator signature.
    """
    return adapter.project_property(data_contexts, expression.field_name)


def _evaluate_context_field(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    expression: Union[ContextField, OutputContextField],
    data_contexts: Iterable[DataContext],
) -> Iterable[Tuple[DataContext, Any]]:
    """Evaluate a field that lives at a previously-marked location.

    For each context, a temporary context positioned at the expression's vertex location
    is created, with the original context parked on top of its stack. After the adapter
    projects the field from the temporary context, the original context is popped back
    off and paired with the projected value.
    """
    vertex_location = expression.location.at_vertex()
    property_name = expression.location.field

    relocated_contexts = (
        original.get_context_for_location(vertex_location).push_value_onto_stack(original)
        for original in data_contexts
    )
    projected = adapter.project_property(relocated_contexts, property_name)

    return (
        (relocated.pop_value_from_stack(), field_value)
        for relocated, field_value in projected
    )

    
def _evaluate_variable(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    expression: Variable,
    data_contexts: Iterable[DataContext],
) -> Iterable[Tuple[DataContext, Any]]:
    """Pair every data context with the value of the query argument the variable names.

    Args:
        adapter: unused; accepted to match the common evaluator signature
        query_arguments: mapping from argument name to its value
        expression: the Variable expression; the first character of its `variable_name`
                    is dropped before the lookup in `query_arguments`
        data_contexts: contexts to pair with the argument value

    Returns:
        lazy iterable of (data_context, argument_value) tuples

    Raises:
        KeyError: immediately (not lazily) if the argument is missing from `query_arguments`
    """
    # Drop the leading sigil character from the variable name to get the argument name.
    argument_name = expression.variable_name[1:]
    variable_value = query_arguments[argument_name]
    return (
        (data_context, variable_value)
        for data_context in data_contexts
    )


def _evaluate_expression(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    expression: Expression,
    data_contexts: Iterable[DataContext],
) -> Iterable[Tuple[DataContext, Any]]:
    """Dispatch to the evaluator registered for the exact type of `expression`.

    The lookup is by exact type (not isinstance); an expression type with no registered
    evaluator raises KeyError.
    """
    evaluators_by_type = {
        BinaryComposition: _evaluate_binary_composition,
        ContextField: _evaluate_context_field,
        OutputContextField: _evaluate_context_field,
        LocalField: _evaluate_local_field,
        Variable: _evaluate_variable,
    }
    evaluator = evaluators_by_type[type(expression)]
    return evaluator(adapter, query_arguments, expression, data_contexts)

In [5]:
def _handle_filter(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    block: Filter,
    data_contexts: Iterable[DataContext],
) -> Iterable[DataContext]:
    """Yield only the contexts for which the block's predicate evaluates to a truthy value."""
    # TODO(predrag): Handle the "filters depending on missing optional values pass" rule.
    evaluated_contexts = _evaluate_expression(
        adapter, query_arguments, block.predicate, data_contexts
    )
    for context, predicate_value in evaluated_contexts:
        if predicate_value:
            yield context
    

def _handle_traverse(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    block: Traverse,
    data_contexts: Iterable[DataContext],
) -> Iterable[DataContext]:
    """Expand each context into one new context per neighbor along the block's edge.

    Each new context points at a neighbor token, shares the source context's
    `token_at_location` mapping, and gets its own copy of the expression stack.
    Optional traversals are not supported and raise NotImplementedError.
    """
    if block.optional:
        raise NotImplementedError()

    for source_context, neighbor_tokens in adapter.project_neighbors(
        data_contexts, block.direction, block.edge_name
    ):
        for neighbor_token in neighbor_tokens:
            # TODO(predrag): Make a helper staticmethod on DataContext for this.
            yield DataContext(
                False,
                neighbor_token,
                source_context.token_at_location,
                list(source_context.expression_stack),
            )

In [6]:
def _produce_output(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    output_name: str,
    output_expression: Expression,
    data_contexts: Iterable[DataContext],
) -> Iterable[DataContext]:
    """Evaluate one output expression and record its value in each context's output row.

    The output row is expected to be the mapping currently on top of the context's
    expression stack; the value is stored there under `output_name`.
    """
    tapped_contexts = _print_tap('outputting ' + output_name, data_contexts)
    evaluated = _evaluate_expression(
        adapter, query_arguments, output_expression, tapped_contexts
    )

    for context, output_value in evaluated:
        output_row = context.peek_value_on_stack()
        output_row[output_name] = output_value
        yield context
    

def _handle_construct_result(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    block: ConstructResult,
    data_contexts: Iterable[DataContext],
) -> Iterable[Dict[str, Any]]:
    """Turn each surviving context into its output row.

    A fresh dict is pushed onto every context's stack, each output field of the block
    fills in one key of that dict, and finally the dict is popped off and returned.
    The whole pipeline is lazy: nothing runs until the result is iterated.
    """
    contexts_with_rows = (
        context.push_value_onto_stack({})
        for context in data_contexts
    )

    # Chain one lazy output stage per output field.
    for name, expression in block.fields.items():
        contexts_with_rows = _produce_output(
            adapter, query_arguments, name, expression, contexts_with_rows
        )

    return (
        finished_context.pop_value_from_stack()
        for finished_context in contexts_with_rows
    )

In [7]:
def _handle_coerce_type(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    block: CoerceType,
    data_contexts: Iterable[DataContext],
) -> Iterable[DataContext]:
    """Keep only the contexts whose current token can be coerced to the block's target type.

    Args:
        adapter: adapter used to answer the coercion question for each context
        query_arguments: unused; accepted to match the common block-handler signature
        block: the CoerceType block; its `target_class` must hold exactly one type name
        data_contexts: contexts to filter

    Returns:
        lazy iterable of the contexts that passed; contexts whose `current_token` is None
        are always passed through, regardless of the adapter's answer
    """
    # Use the helper that is actually imported at the top of the notebook; the previous
    # name (`get_only_element_of_collection`) did not exist and raised NameError.
    coercion_type = get_only_element_from_collection(block.target_class)
    return (
        data_context
        for data_context, can_coerce in adapter.can_coerce_to_type(data_contexts, coercion_type)
        if can_coerce or data_context.current_token is None
    )
    

def _handle_mark_location(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    block: MarkLocation,
    data_contexts: Iterable[DataContext],
) -> Iterable[DataContext]:
    """Record each context's current token under the block's location.

    Yields a new context per input context; the location mapping and expression stack
    are both copied, so the input context is left unmodified.
    """
    marked_location = block.location
    for context in data_contexts:
        updated_tokens = {
            **context.token_at_location,
            marked_location: context.current_token,
        }
        # TODO(predrag): Passing False for is_inactive is almost certainly wrong,
        #                revisit and extract into staticmethod.
        yield DataContext(
            False,
            context.current_token,
            updated_tokens,
            list(context.expression_stack),
        )
        

def _handle_backtrack(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    block: Backtrack,
    data_contexts: Iterable[DataContext],
) -> Iterable[DataContext]:
    """Move each context back to the token recorded at the block's location.

    Yields a new context per input context; it shares the input's `token_at_location`
    mapping and gets its own copy of the expression stack.
    """
    target_location = block.location
    for context in data_contexts:
        recorded_tokens = context.token_at_location
        # TODO(predrag): Passing False for is_inactive is almost certainly wrong,
        #                revisit and extract into staticmethod.
        yield DataContext(
            False,
            recorded_tokens[target_location],
            recorded_tokens,
            list(context.expression_stack),
        )

In [8]:
def _handle_block(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    block: BasicBlock,
    data_contexts: Iterable[DataContext],
) -> Iterable[DataContext]:
    """Apply one IR block to the stream of contexts and return the resulting stream.

    GlobalOperationsStart blocks are no-ops and return the input unchanged. Any other
    block is dispatched by its exact type; a type with no registered handler raises KeyError.
    """
    if isinstance(block, (GlobalOperationsStart,)):
        return data_contexts

    tapped_contexts = _print_tap('pre: ' + str(block), data_contexts)

    handlers_by_block_type = {
        CoerceType: _handle_coerce_type,
        Filter: _handle_filter,
        MarkLocation: _handle_mark_location,
        Traverse: _handle_traverse,
        Backtrack: _handle_backtrack,
    }
    handler = handlers_by_block_type[type(block)]
    return handler(adapter, query_arguments, block, tapped_contexts)


In [9]:
def _print_tap(
    info: str,
    data_contexts: Iterable[DataContext],
    verbose: bool = False,
) -> Iterable[DataContext]:
    """Optionally print every context flowing through a point in the pipeline.

    Args:
        info: label describing where in the pipeline this tap sits
        data_contexts: the stream of contexts to pass through
        verbose: when False (the default), the input iterable is returned untouched and
                 nothing is printed; when True, a header is printed when the stream is
                 first consumed and every context is pretty-printed as it passes

    Returns:
        `data_contexts` itself when not verbose, otherwise a generator yielding the
        same contexts in the same order
    """
    if not verbose:
        return data_contexts

    def _tapped() -> Iterable[DataContext]:
        # Tag every printed line with an id derived from the label, so output from
        # several interleaved taps can be told apart.
        unique_id = hash(info)
        print('\n')
        print(unique_id, info)
        for context in data_contexts:
            pprint((unique_id, context))
            yield context

    return _tapped()
        

In [10]:
def interpret_ir(
    adapter: InterpreterAdapter[DataToken],
    ir_and_metadata: IrAndMetadata,
    query_arguments: Dict[str, Any]
) -> Iterable[Dict[str, Any]]:
    """Execute compiled IR against the adapter, producing one dict per query result.

    Args:
        adapter: data source the interpreter reads tokens, properties and neighbors from
        ir_and_metadata: compiled query; only its `ir_blocks` are used here
        query_arguments: mapping from argument name to value, for variables in the query

    Returns:
        lazy iterable of output rows, each mapping output name to value; no data is read
        from the adapter's pipeline until the result is iterated

    Raises:
        AssertionError: if the IR is empty, does not start with a QueryRoot block,
                        or does not end with a ConstructResult block
    """
    ir_blocks = ir_and_metadata.ir_blocks

    if not ir_blocks:
        raise AssertionError('Received no IR blocks to interpret: {}'.format(ir_and_metadata))

    first_block = ir_blocks[0]
    if not isinstance(first_block, QueryRoot):
        raise AssertionError(
            'Expected the first IR block to be a QueryRoot, but got: {}'.format(first_block))

    last_block = ir_blocks[-1]
    if not isinstance(last_block, ConstructResult):
        raise AssertionError(
            'Expected the last IR block to be a ConstructResult, but got: {}'.format(last_block))

    middle_blocks = ir_blocks[1:-1]

    start_class = get_only_element_from_collection(first_block.start_class)

    current_data_contexts = (
        DataContext.make_empty_context_from_token(token)
        for token in adapter.get_tokens_of_type(start_class)
    )

    current_data_contexts = _print_tap('starting contexts', current_data_contexts)

    # Each handler wraps the previous stream, building up a lazy pipeline block by block.
    for block in middle_blocks:
        current_data_contexts = _handle_block(
            adapter, query_arguments, block, current_data_contexts)

    current_data_contexts = _print_tap('ending contexts', current_data_contexts)

    return _handle_construct_result(
        adapter, query_arguments, last_block, current_data_contexts)
    

In [11]:
# In-memory vertex data, keyed by type name. Each vertex is a dict of property name to value.
vertices = {
    'Animal': [
        {'name': 'Scooby Doo', 'uuid': '1001'},
        {'name': 'Hedwig', 'uuid': '1002'},
        {'name': 'Beethoven', 'uuid': '1003'},
        {'name': 'Pongo', 'uuid': '1004'},
        {'name': 'Perdy', 'uuid': '1005'},
        {'name': 'Dipstick', 'uuid': '1006'},
        {'name': 'Dottie', 'uuid': '1007'},
        {'name': 'Domino', 'uuid': '1008'},
        {'name': 'Little Dipper', 'uuid': '1009'},
        {'name': 'Oddball', 'uuid': '1010'},
    ],
}
# In-memory edge data, keyed by edge name. Each edge is a (source uuid, destination uuid) pair.
edges = {
    'Animal_ParentOf': [
        ('1004', '1006'),
        ('1005', '1006'),
        ('1006', '1008'),
        ('1006', '1009'),
        ('1006', '1010'),
        ('1007', '1008'),
        ('1007', '1009'),
        ('1007', '1010'),
    ],
}

# Lookup table from uuid to vertex dict, across all vertex types.
vertices_by_uuid = {
    vertex['uuid']: vertex
    for vertex in chain.from_iterable(vertices.values())
}


class InMemoryAdapter(InterpreterAdapter[dict]):
    """Adapter backed by the module-level `vertices`, `edges` and `vertices_by_uuid` tables.

    Tokens are the vertex dicts themselves.
    """

    def get_tokens_of_type(
        self,
        type_name: str,
        **hints
    ) -> Iterable[dict]:
        """Return the list of vertex dicts stored under `type_name` in `vertices`."""
        return vertices[type_name]

    def project_property(
        self,
        data_contexts: Iterable[DataContext],
        field_name: str,
        **hints
    ) -> Iterable[Tuple[DataContext, Any]]:
        """Yield (context, value of `field_name` on the current token) for each context.

        The value is None when the context has no current token.
        """
        for data_context in data_contexts:
            current_token = data_context.current_token
            current_value = current_token[field_name] if current_token is not None else None
            yield (data_context, current_value)

    def project_neighbors(
        self,
        data_contexts: Iterable[DataContext],
        direction: str,
        edge_name: str,
        **hints
    ) -> Iterable[Tuple[DataContext, Iterable[dict]]]:
        """Yield (context, list of neighbor vertices) for each context.

        Args:
            data_contexts: contexts whose current token's neighbors are wanted
            direction: 'out' to follow edges from source to destination,
                       'in' to follow them from destination back to source
            edge_name: key into `edges` selecting which edge list to use

        Raises:
            AssertionError: if `direction` is neither 'out' nor 'in' (raised lazily, when
                            the first context with a non-None current token is processed)
        """
        edge_info = edges[edge_name]

        for data_context in data_contexts:
            # A context with no current token has no neighbors.
            neighbor_tokens = []
            current_token = data_context.current_token
            if current_token is not None:
                uuid = current_token['uuid']
                if direction == 'out':
                    neighbor_tokens = [
                        vertices_by_uuid[destination_uuid]
                        for source_uuid, destination_uuid in edge_info
                        if source_uuid == uuid
                    ]
                elif direction == 'in':
                    # The current vertex is the edge's destination, so its neighbor is the
                    # edge's *source*. Looking up the destination here would just return
                    # the current vertex itself.
                    neighbor_tokens = [
                        vertices_by_uuid[source_uuid]
                        for source_uuid, destination_uuid in edge_info
                        if destination_uuid == uuid
                    ]
                else:
                    raise AssertionError(
                        'Unexpected edge direction, expected "in" or "out": {!r}'.format(direction))

            yield (data_context, neighbor_tokens)

    def can_coerce_to_type(
        self,
        data_contexts: Iterable[DataContext],
        type_name: str,
        **hints
    ) -> Iterable[Tuple[DataContext, bool]]:
        """Report every context as coercible to `type_name`."""
        # TODO(predrag): See if a redesign can make this be a no-op again.
        return zip(data_contexts, repeat(True))

In [12]:
# Parse the schema text imported from the compiler's test helpers and build the schema
# object that the queries below are compiled against.
schema = build_ast_schema(parse(SCHEMA_TEXT))

In [13]:
# Example query: output the name and uuid of every Animal. Takes no arguments.
query = '''
{
    Animal {
        name @output(out_name: "animal_name")
        uuid @output(out_name: "animal_uuid")
    }
}
'''
query_arguments = {}

In [14]:
# Example query: output (parent name, child name) pairs by following the outbound
# Animal_ParentOf edge. Takes no arguments.
query = '''
{
    Animal {
        name @output(out_name: "parent_name")

        out_Animal_ParentOf {
            name @output(out_name: "child_name")
        }
    }
}
'''
query_arguments = {}

In [18]:
# Example query: same parent/child pairs as above, but keeping only children whose name
# is in the `child_names` argument.
query = '''
{
    Animal {
        name @output(out_name: "parent_name")

        out_Animal_ParentOf {
            name @filter(op_name: "in_collection", value: ["$child_names"])
                 @output(out_name: "child_name")
        }
    }
}
'''
query_arguments = {
    "child_names": ['Domino', 'Dipstick', 'Oddball'],
}

In [16]:
# Example query: output (grandparent, parent, child) name triples by following the
# outbound Animal_ParentOf edge twice. Takes no arguments.
query = '''
{
    Animal {
        name @output(out_name: "grandparent_name")

        out_Animal_ParentOf {
            name @output(out_name: "parent_name")
            
            out_Animal_ParentOf {
                name @output(out_name: "child_name")
            }
        }
    }
}
'''
query_arguments = {}

In [19]:
# Compile and run whichever query cell was executed most recently.
# NOTE(review): `query` / `query_arguments` are reassigned by every query cell above.
# The recorded output matches the "in_collection" query, which is not the query cell
# that appears last in file order -- TODO confirm the result under Restart & Run All.
ir_and_metadata = graphql_to_ir(schema, query)
result = list(interpret_ir(InMemoryAdapter(), ir_and_metadata, query_arguments))
result

[{'parent_name': 'Pongo', 'child_name': 'Dipstick'},
 {'parent_name': 'Perdy', 'child_name': 'Dipstick'},
 {'parent_name': 'Dipstick', 'child_name': 'Domino'},
 {'parent_name': 'Dipstick', 'child_name': 'Oddball'},
 {'parent_name': 'Dottie', 'child_name': 'Domino'},
 {'parent_name': 'Dottie', 'child_name': 'Oddball'}]