Skip to content

Commit

Permalink
Merge pull request #53 from kindermax/add-federation-support
Browse files Browse the repository at this point in the history
Add federation support
  • Loading branch information
kindermax committed Jul 21, 2021
2 parents edcd3c1 + 61f7d8f commit fb8e313
Show file tree
Hide file tree
Showing 21 changed files with 1,899 additions and 158 deletions.
159 changes: 159 additions & 0 deletions examples/graphql_federation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
import logging

from typing import (
TypedDict,
)

from flask import Flask, request, jsonify

from hiku.federation.directive import (
Key,
External,
Extends,
)
from hiku.federation.endpoint import FederatedGraphQLEndpoint
from hiku.federation.engine import Engine
from hiku.graph import (
Root,
Field,
Option,
Node,
Link,
Graph,
)
from hiku.types import (
Integer,
TypeRef,
String,
Optional,
Sequence,
)
from hiku.executors.sync import SyncExecutor

log = logging.getLogger(__name__)


class Cart(TypedDict):
id: int
status: str


class CartItem(TypedDict):
id: int
cart_id: int
name: str


def get_by_id(id_, collection):
for item in collection:
if item['id'] == id_:
return item


def find_all_by_id(id_, collection, key='id'):
for item in collection:
if item[key] == id_:
yield item


data = {
'carts': [
Cart(id=1, status='NEW'),
Cart(id=2, status='ORDERED'),
],
'cart_items': [
CartItem(id=10, cart_id=1, name='Ipad'),
CartItem(id=20, cart_id=2, name='Book'),
CartItem(id=21, cart_id=2, name='Pen'),
]
}


def cart_resolver(fields, ids):
for cart_id in ids:
cart = get_by_id(cart_id, data['carts'])
yield [cart[f.name] for f in fields]


def cart_item_resolver(fields, ids):
for item_id in ids:
item = get_by_id(item_id, data['cart_items'])
yield [item[f.name] for f in fields]


def link_cart_items(cart_ids):
for cart_id in cart_ids:
yield [item['id'] for item
in find_all_by_id(cart_id, data['cart_items'], key='cart_id')]


def direct_link_id(opts):
return opts['id']


def ids_resolver(fields, ids):
return [[id_] for id_ in ids]


def direct_link(ids):
return ids


QUERY_GRAPH = Graph([
Node('Order', [
Field('cartId', Integer, ids_resolver,
directives=[External()]),
Link('cart', TypeRef['Cart'], direct_link, requires='cartId'),
], directives=[Key('cartId'), Extends()]),
Node('Cart', [
Field('id', Integer, cart_resolver),
Field('status', String, cart_resolver),
Link('items', Sequence[TypeRef['CartItem']], link_cart_items,
requires='id')
], directives=[Key('id')]),
Node('CartItem', [
Field('id', Integer, cart_item_resolver),
Field('cart_id', Integer, cart_item_resolver),
Field('name', String, cart_item_resolver),
Field('photo', Optional[String], lambda: None, options=[
Option('width', Integer),
Option('height', Integer),
]),
]),
Root([
Link(
'cart',
Optional[TypeRef['Cart']],
direct_link_id,
requires=None,
options=[
Option('id', Integer)
],
),
]),
])


app = Flask(__name__)

graphql_endpoint = FederatedGraphQLEndpoint(
Engine(SyncExecutor()),
QUERY_GRAPH,
)


@app.route('/graphql', methods={'POST'})
def handle_graphql():
data = request.get_json()
result = graphql_endpoint.dispatch(data)
resp = jsonify(result)
return resp


def main():
logging.basicConfig()
app.run(port=5000)


if __name__ == '__main__':
main()
Empty file added hiku/federation/__init__.py
Empty file.
9 changes: 9 additions & 0 deletions hiku/federation/denormalize.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
from collections import deque

from hiku.denormalize.graphql import DenormalizeGraphQL


class DenormalizeEntityGraphQL(DenormalizeGraphQL):
def __init__(self, graph, result, root_type_name):
super().__init__(graph, result, root_type_name)
self._type = deque([graph.__types__[root_type_name]])
53 changes: 53 additions & 0 deletions hiku/federation/directive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,53 @@
class _DirectiveBase:
pass


class Key(_DirectiveBase):
"""
https://www.apollographql.com/docs/federation/federation-spec/#key
"""
def __init__(self, fields):
self.fields = fields

def accept(self, visitor):
return visitor.visit_key_directive(self)


class Provides(_DirectiveBase):
"""
https://www.apollographql.com/docs/federation/federation-spec/#provides
"""
def __init__(self, fields):
self.fields = fields

def accept(self, visitor):
return visitor.visit_provides_directive(self)


class Requires(_DirectiveBase):
"""
https://www.apollographql.com/docs/federation/federation-spec/#requires
"""
def __init__(self, fields):
self.fields = fields

def accept(self, visitor):
return visitor.visit_requires_directive(self)


class External(_DirectiveBase):
"""
https://www.apollographql.com/docs/federation/federation-spec/#external
"""
def accept(self, visitor):
return visitor.visit_external_directive(self)


class Extends(_DirectiveBase):
"""
Apollo Federation supports using an @extends directive in place of extend
type to annotate type references
https://www.apollographql.com/docs/federation/federation-spec/
"""
def accept(self, visitor):
return visitor.visit_extends_directive(self)
154 changes: 154 additions & 0 deletions hiku/federation/endpoint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,154 @@
from abc import abstractmethod
from asyncio import gather
from contextlib import contextmanager
from typing import (
List,
Dict,
Any,
)

from .utils import get_keys
from .introspection import (
FederatedGraphQLIntrospection,
AsyncFederatedGraphQLIntrospection,
is_introspection_query,
extend_with_federation,
)
from .validate import validate

from hiku.denormalize.graphql import DenormalizeGraphQL
from hiku.federation.denormalize import DenormalizeEntityGraphQL
from hiku.endpoint.graphql import (
BaseGraphQLEndpoint,
_type_names,
_switch_graph,
GraphQLError,
_StripQuery,
)
from hiku.graph import Graph
from hiku.query import Node
from hiku.result import Proxy, Reference


def _process_query(graph, query):
stripped_query = _StripQuery().visit(query)
errors = validate(graph, stripped_query)
if errors:
raise GraphQLError(errors=errors)
else:
return stripped_query


def denormalize_entities(
graph: Graph,
query: Node,
result: Proxy,
) -> List[Dict[str, Any]]:

entities_link = query.fields_map['_entities']
node = entities_link.node
representations = entities_link.options['representations']

entities = []
for r in representations:
typename = r['__typename']
for key in get_keys(graph, typename):
if key not in r:
continue
ident = r[key]
result.__ref__ = Reference(typename, ident)
data = DenormalizeEntityGraphQL(
graph, result, typename
).process(node)
entities.append(data)

return entities


class BaseFederatedGraphEndpoint(BaseGraphQLEndpoint):
@abstractmethod
def execute(self, graph, op, ctx):
pass

@abstractmethod
def dispatch(self, data):
pass

@contextmanager
def context(self, op):
yield {}

@staticmethod
def postprocess_result(result: Proxy, graph, op):
if '_service' in op.query.fields_map:
return {'_service': {'sdl': result['sdl']}}
elif '_entities' in op.query.fields_map:
return {
'_entities': denormalize_entities(
graph, op.query, result
)
}

type_name = _type_names[op.type]

data = DenormalizeGraphQL(graph, result, type_name).process(op.query)
if is_introspection_query(op.query):
extend_with_federation(graph, data)
return data


class FederatedGraphQLEndpoint(BaseFederatedGraphEndpoint):
"""Can execute either regular or federated queries.
Handles following fields of federated query:
- _service
- _entities
"""
introspection_cls = FederatedGraphQLIntrospection

def execute(self, graph: Graph, op, ctx):
stripped_query = _process_query(graph, op.query)
result = self.engine.execute(graph, stripped_query, ctx)
return self.postprocess_result(result, graph, op)

def dispatch(self, data):
try:
graph, op = _switch_graph(
data, self.query_graph, self.mutation_graph,
)
with self.context(op) as ctx:
result = self.execute(graph, op, ctx)
return {'data': result}
except GraphQLError as e:
return {'errors': [{'message': e} for e in e.errors]}


class AsyncFederatedGraphQLEndpoint(BaseFederatedGraphEndpoint):
introspection_cls = AsyncFederatedGraphQLIntrospection

async def execute(self, graph: Graph, op, ctx):
stripped_query = _process_query(graph, op.query)
result = await self.engine.execute_async(graph, stripped_query, ctx)
return self.postprocess_result(result, graph, op)

async def dispatch(self, data):
try:
graph, op = _switch_graph(
data, self.query_graph, self.mutation_graph,
)

with self.context(op) as ctx:
result = await self.execute(graph, op, ctx)
return {'data': result}
except GraphQLError as e:
return {'errors': [{'message': e} for e in e.errors]}


class AsyncBatchFederatedGraphQLEndpoint(AsyncFederatedGraphQLEndpoint):
async def dispatch(self, data):
if isinstance(data, list):
return await gather(*(
super().dispatch(item)
for item in data
))

return await super().dispatch(data)

0 comments on commit fb8e313

Please sign in to comment.