In [205]:
from abc import ABCMeta, abstractmethod
from collections import namedtuple
from itertools import chain, repeat
from pprint import pprint
from typing import (
    AbstractSet, Any, Callable, Collection, Dict, Generic, Iterable, Mapping, NamedTuple, 
    Optional, Tuple, TypeVar, Union
)

from funcy.py3 import chunks
from graphql import GraphQLList, GraphQLString
from graphql_compiler.compiler.blocks import (
    Backtrack, CoerceType, ConstructResult, Filter, GlobalOperationsStart, MarkLocation, 
    QueryRoot, Traverse
)
from graphql_compiler.compiler.compiler_entities import BasicBlock, Expression
from graphql_compiler.compiler.expressions import (
    BinaryComposition, ContextField, Literal, LocalField, OutputContextField, Variable
)
from graphql_compiler.compiler.compiler_frontend import IrAndMetadata, graphql_to_ir
from graphql_compiler.compiler.helpers import Location, get_only_element_from_collection
from graphql_compiler.schema import GraphQLDate, GraphQLDateTime, GraphQLDecimal

In [116]:
class FilterInfo(NamedTuple):
    """A filter hint that the interpreter may hand to an adapter.

    The fields presumably describe a predicate of the form
    ``<field_name> <op_name> <value>``; the interpreter in this notebook always
    passes an empty collection of hints, so the exact semantics are up to adapters.
    """

    field_name: Any
    op_name: Any
    value: Any


# The adapter-specific object representing a single vertex (for example, the
# InMemoryAdapter in this notebook uses plain vertex dicts).
DataToken = TypeVar('DataToken')


class LineageToken(NamedTuple):
    """The interpreter's unit of work: a current token plus the tokens seen so far.

    ``token`` is the vertex currently being processed. ``lineage_by_location``
    maps each already-marked query ``Location`` to the token that was current
    when that location was marked, so later blocks can backtrack or read
    context fields from it.
    """

    # Vertex the pipeline is currently positioned at.
    token: DataToken
    # Location -> token recorded by MarkLocation blocks along this lineage.
    lineage_by_location: Dict[Location, DataToken]

        
def make_empty_lineage_token(token: DataToken):
    """Start a new lineage at ``token`` with no locations marked yet."""
    no_marked_locations = {}
    return LineageToken(token, no_marked_locations)
        

class InterpreterAdapter(Generic[DataToken], metaclass=ABCMeta):
    """Abstract data source that the IR interpreter pulls vertices and properties from.

    ``DataToken`` is whatever object an implementation uses to represent one vertex
    (e.g. InMemoryAdapter uses the vertex dicts themselves). Every method takes and
    returns iterables, which leaves implementations free to stream or batch work.
    """

    @abstractmethod
    def get_tokens_of_type(
        self,
        type_name: str,
        filter_hints: Collection[FilterInfo],
        required_connection_hints: AbstractSet[str],
    ) -> Iterable[DataToken]:
        """Return tokens for the vertices of type ``type_name``.

        ``filter_hints`` and ``required_connection_hints`` are optimization hints; the
        interpreter in this notebook always passes them empty.
        """

    @abstractmethod
    def project_property(
        self,
        tokens: Iterable[DataToken],
        field_name: str,
    ) -> Iterable[Any]:
        """Return the value of ``field_name`` for each token.

        Callers zip the result back against their inputs, so implementations should
        produce exactly one value per token, in input order.
        """

    @abstractmethod
    def project_neighbors(
        self,
        tokens: Iterable[LineageToken],
        direction: str,
        edge_name: str,
        neighbor_type_hint: Optional[str],
        filter_hints: Collection[FilterInfo],
        required_connection_hints: AbstractSet[str],
    ) -> Iterable[Tuple[LineageToken, DataToken]]:
        """Yield ``(lineage, neighbor_token)`` pairs for each lineage's neighbors.

        Neighbors are found across ``edge_name`` in ``direction`` (InMemoryAdapter accepts
        'out' and 'in'). A lineage with several neighbors appears in several pairs; one with
        none appears in no pair. The remaining arguments are hints the interpreter
        currently passes as ``None`` / empty.
        """

    @abstractmethod
    def coerce_to_type(
        self,
        tokens: Iterable[LineageToken],
        type_name: str,
    ) -> Iterable[LineageToken]:
        """Return the lineages whose current token can be treated as ``type_name``.

        NOTE(review): presumably this filters out tokens of other types; InMemoryAdapter
        implements it as a no-op, so the exact contract is not pinned down here.
        """

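Several handlers need to read the same stream of lineages more than once. For example, `_handle_binary_composition` evaluates both operands against the same lineages, and `_handle_filter` zips the lineages with their predicate values. We materialize the stream into lists of at most 1000 lineages with `funcy.chunks` instead of splitting it with `itertools.tee`. This bounds how far the different consumers of a given iterable can drift apart (the "lag"). With `itertools.tee`, we would likely exhaust one copy of the iterable before moving on to the next, which forces `tee` to buffer the entire stream. The `itertools.tee` documentation warns against exactly this usage and recommends `list()` instead in that case.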

In [203]:
def _handle_local_field(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    expression: LocalField,
    lineages: Iterable[LineageToken],
) -> Iterable[Any]:
    """Evaluate a LocalField: read the named property off each lineage's current token.

    ``query_arguments`` is unused; it is accepted so every expression handler shares
    one signature.
    """
    property_name = expression.field_name
    return adapter.project_property(
        (lineage.token for lineage in lineages),
        property_name,
    )


def _handle_context_field(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    expression: Union[ContextField, OutputContextField],
    lineages: Iterable[LineageToken],
) -> Iterable[Any]:
    """Evaluate a (Output)ContextField: read a property of a previously marked vertex.

    The vertex is looked up in each lineage's ``lineage_by_location`` using the vertex
    part of ``expression.location``; the property name is the location's field.
    """
    if isinstance(expression, OutputContextField):
        lineages = _print_tap('output context field', lineages)

    marked_vertex_location = expression.location.at_vertex()
    property_name = expression.location.field
    marked_tokens = (
        lineage.lineage_by_location[marked_vertex_location] for lineage in lineages
    )
    return adapter.project_property(marked_tokens, property_name)


def _handle_variable(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    expression: Variable,
    lineages: Iterable[LineageToken],
) -> Iterable[Any]:
    """Evaluate a Variable: the same query argument value, repeated without end.

    The result is unbounded; consumers must zip it against a finite sequence.
    """
    # Variable names carry a leading '$' that the query_arguments keys omit.
    argument_name = expression.variable_name[1:]
    return repeat(query_arguments[argument_name])


def _handle_binary_composition(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    expression: BinaryComposition,
    lineages: Iterable[LineageToken],
) -> Iterable[Any]:
    """Lazily evaluate a BinaryComposition, yielding exactly one value per lineage.

    Lineages are materialized in lists of up to 1000 so that both operands can be
    evaluated against the same chunk. Each result is ``_apply_operator`` applied to
    the pair of operand values for that lineage.
    """
    for lineages_chunk in chunks(1000, lineages):
        left_values = _handle_expression(adapter, query_arguments, expression.left, lineages_chunk)
        right_values = _handle_expression(adapter, query_arguments, expression.right, lineages_chunk)

        # Zip against the chunk itself: operand handlers such as Variable return an
        # unbounded itertools.repeat, so zipping only the two operands (e.g. for
        # `$a = $b`) would yield without end instead of one value per lineage.
        for _lineage, left_value, right_value in zip(lineages_chunk, left_values, right_values):
            yield _apply_operator(expression.operator, left_value, right_value)

        
def _apply_operator(operator: str, left_value: Any, right_value: Any) -> Any:
    if operator == '=':
        return left_value == right_value
    elif operator == 'contains':
        return right_value in left_value
    else:
        raise NotImplementedError()

In [198]:
def _handle_literal(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    expression: Literal,
    lineages: Iterable[LineageToken],
) -> Iterable[Any]:
    """Evaluate a Literal: its constant value, repeated without end (zip to bound it)."""
    return repeat(expression.value)


def _handle_expression(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    expression: Expression,
    lineages: Iterable[LineageToken],
) -> Iterable[Any]:
    """Evaluate ``expression`` against each lineage by dispatching on its exact type.

    Returns:
        An iterable of values aligned with ``lineages``. Constant-valued expressions
        (Variable, Literal) return unbounded iterables.

    Raises:
        NotImplementedError: if the expression type has no handler.
    """
    # Built per call rather than at module level so that the handler functions only
    # need to exist by the time an expression is evaluated, not when this cell runs.
    type_to_handler = {
        BinaryComposition: _handle_binary_composition,
        ContextField: _handle_context_field,
        Literal: _handle_literal,
        OutputContextField: _handle_context_field,
        LocalField: _handle_local_field,
        Variable: _handle_variable,
    }
    expression_type = type(expression)
    handler = type_to_handler.get(expression_type)
    if handler is None:
        raise NotImplementedError(
            'Expression type {} is not supported by the interpreter: {}'.format(
                expression_type.__name__, expression))
    return handler(adapter, query_arguments, expression, lineages)

In [147]:
def _zip_dict_iterables(dict_of_iterables: Dict[str, Iterable[Any]]) -> Iterable[Dict[str, Any]]:
    """Zip a dict of iterables into an iterable of dicts containing the values of the iterables."""
    key_value_iterables = [
        zip(repeat(key), values)
        for key, values in dict_of_iterables.items()
    ]
    return (
        dict(set_of_values)
        for set_of_values in zip(*key_value_iterables)
    )

In [186]:
def _handle_construct_result(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    block: ConstructResult,
    lineages: Iterable[LineageToken],
) -> Iterable[Dict[str, Any]]:
    """Lazily build one output row per lineage from ``block.fields``.

    Each row maps every output name to the value of its expression for that lineage.
    Lineages are processed in lists of up to 1000 so that every field expression can
    be evaluated against the same chunk.
    """
    output_fields = block.fields

    for lineages_chunk in chunks(1000, lineages):
        output_name_to_expression_values = {
            output_name: _handle_expression(adapter, query_arguments, expression, lineages_chunk)
            for output_name, expression in output_fields.items()
        }
        rows = _zip_dict_iterables(output_name_to_expression_values)
        # Zip against the chunk so that each chunk yields at most one row per lineage:
        # if every field were constant-valued (e.g. a Variable, whose handler returns
        # an unbounded itertools.repeat), the row stream alone would never end.
        for _lineage, row in zip(lineages_chunk, rows):
            yield row
    

def _handle_filter(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    block: Filter,
    lineages: Iterable[LineageToken],
) -> Iterable[LineageToken]:
    """Lazily keep only the lineages for which ``block.predicate`` is truthy.

    Lineages are materialized in lists of up to 1000 so the predicate can be evaluated
    over a chunk while that same chunk is still available to pair with its results.
    """
    predicate_expression = block.predicate
    for chunk in chunks(1000, lineages):
        predicate_values = _handle_expression(adapter, query_arguments, predicate_expression, chunk)
        for lineage, passes in zip(chunk, predicate_values):
            if passes:
                yield lineage
    
    
def _handle_coerce_type(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    block: CoerceType,
    lineages: Iterable[LineageToken],
) -> Iterable[LineageToken]:
    """Delegate a CoerceType block to ``adapter.coerce_to_type``.

    Raises:
        Whatever ``get_only_element_from_collection`` raises if ``block.target_class``
        does not contain exactly one type name.
    """
    # Fixed: previously called the undefined name `get_only_element_of_collection`,
    # which raised NameError for every CoerceType block.
    coercion_type = get_only_element_from_collection(block.target_class)
    return adapter.coerce_to_type(lineages, coercion_type)
    

def _handle_mark_location(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    block: MarkLocation,
    lineages: Iterable[LineageToken],
) -> Iterable[LineageToken]:
    """Record each lineage's current token under ``block.location``.

    Each lineage gets a fresh copy of its location map, so lineages that share a map
    (e.g. siblings produced by one traversal) are not affected by each other's marks.
    """
    marked_location = block.location
    for lineage in lineages:
        extended_lineage = {**lineage.lineage_by_location, marked_location: lineage.token}
        yield LineageToken(token=lineage.token, lineage_by_location=extended_lineage)
        

def _handle_backtrack(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    block: Backtrack,
    lineages: Iterable[LineageToken],
) -> Iterable[LineageToken]:
    """Make the token previously marked at ``block.location`` current again.

    The location map is passed through unchanged (the same dict object is reused).
    """
    target_location = block.location
    for lineage in lineages:
        marked_locations = lineage.lineage_by_location
        yield LineageToken(marked_locations[target_location], marked_locations)
        
        
def _handle_traverse(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    block: Traverse,
    lineages: Iterable[LineageToken],
) -> Iterable[LineageToken]:
    """Move each lineage across ``block.edge_name`` in ``block.direction``.

    A lineage with several neighbors fans out into several lineages; one with no
    neighbors disappears. The neighbor becomes the current token, and the location map
    is carried over as-is. Optional traversals are not supported yet.
    """
    if block.optional:
        raise NotImplementedError()

    # No type hint, filter hints or connection hints are passed to the adapter yet.
    neighbor_pairs = adapter.project_neighbors(
        lineages, block.direction, block.edge_name, None, [], set())
    for origin_lineage, neighbor in neighbor_pairs:
        yield LineageToken(
            token=neighbor,
            lineage_by_location=origin_lineage.lineage_by_location,
        )

In [172]:
def _handle_block(
    adapter: InterpreterAdapter[DataToken],
    query_arguments: Dict[str, Any],
    block: BasicBlock,
    lineages: Iterable[LineageToken],
) -> Iterable[LineageToken]:
    """Apply one middle IR block (between QueryRoot and ConstructResult) to the lineages.

    Raises:
        NotImplementedError: if the block type has no handler (e.g. folds, recursion or
            other IR blocks this interpreter does not implement yet).
    """
    no_op_types = (GlobalOperationsStart,)
    if isinstance(block, no_op_types):
        return lineages

    lineages = _print_tap('pre: ' + str(block), lineages)

    handler_functions = {
        CoerceType: _handle_coerce_type,
        Filter: _handle_filter,
        MarkLocation: _handle_mark_location,
        Traverse: _handle_traverse,
        Backtrack: _handle_backtrack,
    }
    handler = handler_functions.get(type(block))
    if handler is None:
        # A bare KeyError with a class as the key was hard to diagnose.
        raise NotImplementedError(
            'IR block type {} is not supported by the interpreter: {}'.format(
                type(block).__name__, block))
    return handler(adapter, query_arguments, block, lineages)


In [162]:
def _print_tap(
    info: str, lineages: Iterable[LineageToken], enabled: bool = False
) -> Iterable[LineageToken]:
    """Debugging hook: optionally print every lineage flowing through a pipeline stage.

    Args:
        info: label for the stage; printed once, along with its hash as a short id.
        lineages: the stream to observe.
        enabled: when False (the default), ``lineages`` is returned untouched.

    Returns:
        ``lineages`` itself when disabled, otherwise a lazy pass-through stream that
        pretty-prints each lineage as it is consumed.
    """
    # The printing path lives in a separate generator function. Putting a `yield` in this
    # function would make the early `return lineages` end an empty generator, which would
    # silently drop every lineage.
    if not enabled:
        return lineages
    return _print_tap_generator(info, lineages)


def _print_tap_generator(info: str, lineages: Iterable[LineageToken]) -> Iterable[LineageToken]:
    """Yield ``lineages`` unchanged, pretty-printing each one tagged with a stage id."""
    unique_id = hash(info)
    print()
    print(unique_id, info)
    for lineage in lineages:
        pprint((unique_id, lineage))
        yield lineage
        

In [184]:
def interpret_ir(
    adapter: InterpreterAdapter[DataToken],
    ir_and_metadata: IrAndMetadata,
    query_arguments: Dict[str, Any]
) -> Iterable[Dict[str, Any]]:
    """Execute compiled IR against ``adapter``, lazily yielding one dict per result row.

    Args:
        adapter: data source to pull vertices, properties and neighbors from.
        ir_and_metadata: compiler output; only ``ir_blocks`` is used here.
        query_arguments: values for the query's variables, keyed without the '$'.

    Returns:
        A lazy iterable of ``{output_name: value}`` dicts. The adapter is queried only
        as the result is consumed.

    Raises:
        AssertionError: (eagerly) if the IR does not start with QueryRoot and end with
            ConstructResult.
    """
    ir_blocks = ir_and_metadata.ir_blocks

    if not ir_blocks:
        raise AssertionError('Received empty IR blocks: {}'.format(ir_blocks))

    first_block = ir_blocks[0]
    if not isinstance(first_block, QueryRoot):
        raise AssertionError(
            'Expected the first IR block to be QueryRoot, got: {}'.format(first_block))

    last_block = ir_blocks[-1]
    if not isinstance(last_block, ConstructResult):
        raise AssertionError(
            'Expected the last IR block to be ConstructResult, got: {}'.format(last_block))

    middle_blocks = ir_blocks[1:-1]

    start_class = get_only_element_from_collection(first_block.start_class)

    current_lineages = (
        make_empty_lineage_token(token)
        for token in adapter.get_tokens_of_type(start_class, [], set())
    )

    current_lineages = _print_tap('starting lineages', current_lineages)

    # Each block wraps the previous stream; nothing runs until the result is consumed.
    for block in middle_blocks:
        current_lineages = _handle_block(adapter, query_arguments, block, current_lineages)

    current_lineages = _print_tap('ending lineages', current_lineages)

    return _handle_construct_result(
        adapter, query_arguments, last_block, current_lineages)
    

In [121]:
# Toy graph: vertices grouped by type name, each a dict with a name and a string uuid.
vertices = {
    'Animal': [
        {'name': 'Scooby Doo', 'uuid': '1001'},
        {'name': 'Hedwig', 'uuid': '1002'},
        {'name': 'Beethoven', 'uuid': '1003'},
        {'name': 'Pongo', 'uuid': '1004'},
        {'name': 'Perdy', 'uuid': '1005'},
        {'name': 'Dipstick', 'uuid': '1006'},
        {'name': 'Dottie', 'uuid': '1007'},
        {'name': 'Domino', 'uuid': '1008'},
        {'name': 'Little Dipper', 'uuid': '1009'},
        {'name': 'Oddball', 'uuid': '1010'},
    ],
}
# Edges grouped by edge name, as (source_uuid, destination_uuid) pairs:
# Pongo and Perdy -> Dipstick; Dipstick and Dottie -> Domino, Little Dipper, Oddball.
edges = {
    'Animal_ParentOf': [
        ('1004', '1006'),
        ('1005', '1006'),
        ('1006', '1008'),
        ('1006', '1009'),
        ('1006', '1010'),
        ('1007', '1008'),
        ('1007', '1009'),
        ('1007', '1010'),
    ],
}

# uuid -> vertex dict across all types; the values are the same dict objects as above.
vertices_by_uuid = dict(
    (vertex['uuid'], vertex)
    for vertices_of_type in vertices.values()
    for vertex in vertices_of_type
)


class InMemoryAdapter(InterpreterAdapter[dict]):
    """Toy adapter over the module-level ``vertices``, ``edges`` and ``vertices_by_uuid``.

    Tokens are the vertex dicts themselves. All hint arguments are ignored.
    """

    def get_tokens_of_type(
        self,
        type_name: str,
        filter_hints: Collection[FilterInfo],
        required_connection_hints: AbstractSet[str]
    ) -> Iterable[dict]:
        """Return every vertex of ``type_name`` (KeyError if the type is unknown)."""
        return vertices[type_name]

    def project_property(
        self,
        tokens: Iterable[dict],
        field_name: str
    ) -> Iterable[Any]:
        """Lazily yield ``token[field_name]`` for each token, in order."""
        return (
            token[field_name]
            for token in tokens
        )

    def project_neighbors(
        self,
        tokens: Iterable[LineageToken],
        direction: str,
        edge_name: str,
        neighbor_type_hint: Optional[str],
        filter_hints: Collection[FilterInfo],
        required_connection_hints: AbstractSet[str]
    ) -> Iterable[Tuple[LineageToken, dict]]:
        """Yield ``(lineage, neighbor_vertex)`` for every neighbor of each lineage's token.

        ``direction`` is 'out' (follow edges from source to destination) or 'in' (the
        reverse). Neighbors of one token are yielded in the order their edges appear in
        ``edges[edge_name]``.

        Raises:
            AssertionError: (on first iteration) if ``direction`` is not 'out' or 'in'.
            KeyError: if ``edge_name`` is not a known edge.
        """
        if direction not in ('out', 'in'):
            raise AssertionError('Invalid edge direction: {!r}'.format(direction))

        # Index the edge list once per call (uuid -> neighbor uuids, in edge order)
        # instead of rescanning every edge for every token.
        neighbor_uuids_by_uuid = {}
        for source_uuid, destination_uuid in edges[edge_name]:
            if direction == 'out':
                from_uuid, to_uuid = source_uuid, destination_uuid
            else:
                from_uuid, to_uuid = destination_uuid, source_uuid
            neighbor_uuids_by_uuid.setdefault(from_uuid, []).append(to_uuid)

        for lineage in tokens:
            for neighbor_uuid in neighbor_uuids_by_uuid.get(lineage.token['uuid'], ()):
                yield lineage, vertices_by_uuid[neighbor_uuid]

    def coerce_to_type(
        self,
        tokens: Iterable[LineageToken],
        type_name: str
    ) -> Iterable[LineageToken]:
        """No-op: the toy graph only contains 'Animal' vertices."""
        return tokens

In [192]:
# Example query 1: output the name of every Animal.
base_location = Location(('Animal',))
blocks = [
    QueryRoot({'Animal'}),
    MarkLocation(base_location),
    GlobalOperationsStart(),
    ConstructResult(
        {'animal_name': OutputContextField(base_location.navigate_to_field('name'), GraphQLString)}
    ),
]
# No parameters and no metadata are needed by this query.
query_arguments = dict()
input_metadata = dict()
output_metadata = dict()
query_metadata_table = None

In [191]:
# Example query 2: output every (parent, child) pair along out_Animal_ParentOf.
base_location = Location(('Animal',))
child_location = base_location.navigate_to_subpath('out_Animal_ParentOf')

blocks = [
    QueryRoot({'Animal'}),
    MarkLocation(base_location),
    Traverse('out', 'Animal_ParentOf'),
    MarkLocation(child_location),
    # Make the parent vertex current again before building the output.
    Backtrack(base_location),
    GlobalOperationsStart(),
    ConstructResult({
        'parent_name': OutputContextField(base_location.navigate_to_field('name'), GraphQLString),
        'child_name': OutputContextField(child_location.navigate_to_field('name'), GraphQLString),
    }),
]
query_arguments = dict()
input_metadata = dict()
output_metadata = dict()
query_metadata_table = None

In [201]:
# Example query 3: like query 2, but keep only children whose name equals $child.
base_location = Location(('Animal',))
child_location = base_location.navigate_to_subpath('out_Animal_ParentOf')

blocks = [
    QueryRoot({'Animal'}),
    MarkLocation(base_location),
    Traverse('out', 'Animal_ParentOf'),
    Filter(BinaryComposition(
        u'=',
        LocalField('name', GraphQLString),
        Variable('$child', GraphQLString),
    )),
    MarkLocation(child_location),
    Backtrack(base_location),
    GlobalOperationsStart(),
    ConstructResult({
        'parent_name': OutputContextField(base_location.navigate_to_field('name'), GraphQLString),
        'child_name': OutputContextField(child_location.navigate_to_field('name'), GraphQLString),
    }),
]
# Argument keys omit the '$' sigil used in the Variable name.
query_arguments = dict(child='Domino')
input_metadata = dict(child=GraphQLString)
output_metadata = dict()
query_metadata_table = None

In [208]:
# Example query 4: like query 2, but keep only children whose name is in $child_names.
base_location = Location(('Animal',))
child_location = base_location.navigate_to_subpath('out_Animal_ParentOf')

blocks = [
    QueryRoot({'Animal'}),
    MarkLocation(base_location),
    Traverse('out', 'Animal_ParentOf'),
    Filter(
        BinaryComposition(
            u'contains',
            # The variable holds a list of names, so its type must be a list type,
            # matching input_metadata below (it was previously declared GraphQLString).
            Variable('$child_names', GraphQLList(GraphQLString)),
            LocalField('name', GraphQLString),
        )
    ),
    MarkLocation(child_location),
    Backtrack(base_location),
    GlobalOperationsStart(),
    ConstructResult({
        'parent_name': OutputContextField(
            base_location.navigate_to_field('name'), GraphQLString),
        'child_name': OutputContextField(
            child_location.navigate_to_field('name'), GraphQLString),
    }),
]
query_arguments = {
    'child_names': ['Domino', 'Dipstick', 'Oddball'],
}
input_metadata = {
    'child_names': GraphQLList(GraphQLString)
}
output_metadata = {}
query_metadata_table = None

In [209]:
# Run whichever example query was defined most recently and display all result rows.
ir_and_metadata = IrAndMetadata(blocks, input_metadata, output_metadata, query_metadata_table)
result = [row for row in interpret_ir(InMemoryAdapter(), ir_and_metadata, query_arguments)]
result

[{'parent_name': 'Pongo', 'child_name': 'Dipstick'},
 {'parent_name': 'Perdy', 'child_name': 'Dipstick'},
 {'parent_name': 'Dipstick', 'child_name': 'Domino'},
 {'parent_name': 'Dipstick', 'child_name': 'Oddball'},
 {'parent_name': 'Dottie', 'child_name': 'Domino'},
 {'parent_name': 'Dottie', 'child_name': 'Oddball'}]