Skip to content

Commit

Permalink
Improved GraphQL DSL 😊. Fixed #12
Browse files Browse the repository at this point in the history
  • Loading branch information
syrusakbary committed Dec 19, 2016
1 parent 8257777 commit 269d2e9
Show file tree
Hide file tree
Showing 3 changed files with 111 additions and 52 deletions.
56 changes: 43 additions & 13 deletions gql/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,44 @@
import six
from graphql.language import ast
from graphql.language.printer import print_ast
from graphql.type import (GraphQLField, GraphQLFieldDefinition, GraphQLList,
from graphql.type import (GraphQLField, GraphQLList,
GraphQLNonNull, GraphQLEnumType)

from .utils import to_camel_case


class DSLSchema(object):
def __init__(self, client):
self.client = client

@property
def schema(self):
return self.client.schema

def __getattr__(self, name):
type_def = self.schema.get_type(name)
return DSLType(type_def)


class DSLType(object):
def __init__(self, type):
self.type = type

def __getattr__(self, name):
formatted_name, field_def = self.get_field(name)
return DSLField(formatted_name, field_def)

def get_field(self, name):
camel_cased_name = to_camel_case(name)

if name in self.type.fields:
return name, self.type.fields[name]

if camel_cased_name in self.type.fields:
return camel_cased_name, self.type.fields[camel_cased_name]

raise KeyError('Field {} doesnt exist in type {}.'.format(name, self.type.name))


def selections(*fields):
for _field in fields:
Expand All @@ -30,9 +65,9 @@ def get_ast_value(value):

class DSLField(object):

def __init__(self, field):
def __init__(self, name, field):
self.field = field
self.ast_field = ast.Field(name=ast.Name(value=field.name), arguments=[])
self.ast_field = ast.Field(name=ast.Name(value=name), arguments=[])
self.selection_set = None

def get(self, *fields):
Expand All @@ -41,21 +76,16 @@ def get(self, *fields):
self.ast_field.selection_set.selections.extend(selections(*fields))
return self

def __call__(self, *args, **kwargs):
return self.get(*args, **kwargs)

def alias(self, alias):
self.ast_field.alias = ast.Name(value=alias)
return self

def get_field_args(self):
if isinstance(self.field, GraphQLFieldDefinition):
# The args will be an array
return {
arg.name: arg for arg in self.field.args
}
return self.field.args

def args(self, **args):
for name, value in args.items():
arg = self.get_field_args().get(name)
arg = self.field.args.get(name)
arg_type_serializer = get_arg_serializer(arg.type)
value = arg_type_serializer(value)
self.ast_field.arguments.append(
Expand All @@ -75,7 +105,7 @@ def __str__(self):


def field(field, **args):
if isinstance(field, (GraphQLField, GraphQLFieldDefinition)):
if isinstance(field, GraphQLField):
return DSLField(field).args(**args)
elif isinstance(field, DSLField):
return field
Expand Down
21 changes: 21 additions & 0 deletions gql/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
import re


# From this response in Stackoverflow
# http://stackoverflow.com/a/19053800/1072990
def to_camel_case(snake_str):
components = snake_str.split('_')
# We capitalize the first letter of each component except the first one
# with the 'title' method and join them together.
return components[0] + "".join(x.title() if x else '_' for x in components[1:])


# From this response in Stackoverflow
# http://stackoverflow.com/a/1176023/1072990
def to_snake_case(name):
s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name)
return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower()


def to_const(string):
return re.sub('[\W|^]+', '_', string).upper()
86 changes: 47 additions & 39 deletions tests/starwars/test_dsl.py
Original file line number Diff line number Diff line change
@@ -1,39 +1,52 @@
from gql import dsl
import pytest

from gql import Client
from gql.dsl import DSLSchema

from .schema import characterInterface, humanType, queryType


# We construct a Simple DSL objects for easy field referencing

class Query(object):
hero = queryType.get_fields()['hero']
human = queryType.get_fields()['human']
# class Query(object):
# hero = queryType.fields['hero']
# human = queryType.fields['human']


# class Character(object):
# id = characterInterface.fields['id']
# name = characterInterface.fields['name']
# friends = characterInterface.fields['friends']
# appears_in = characterInterface.fields['appearsIn']


class Character(object):
id = characterInterface.get_fields()['id']
name = characterInterface.get_fields()['name']
friends = characterInterface.get_fields()['friends']
appears_in = characterInterface.get_fields()['appearsIn']
# class Human(object):
# name = humanType.fields['name']


class Human(object):
name = humanType.get_fields()['name']
from .schema import StarWarsSchema


def test_hero_name_query():
@pytest.fixture
def ds():
client = Client(schema=StarWarsSchema)
ds = DSLSchema(client)
return ds


def test_hero_name_query(ds):
query = '''
hero {
name
}
'''.strip()
query_dsl = dsl.field(Query.hero).get(
Character.name
query_dsl = ds.Query.hero(
ds.Character.name
)
assert query == str(query_dsl)


def test_hero_name_and_friends_query():
def test_hero_name_and_friends_query(ds):
query = '''
hero {
id
Expand All @@ -43,17 +56,17 @@ def test_hero_name_and_friends_query():
}
}
'''.strip()
query_dsl = dsl.field(Query.hero).get(
Character.id,
Character.name,
dsl.field(Character.friends).get(
Character.name,
query_dsl = ds.Query.hero(
ds.Character.id,
ds.Character.name,
ds.Character.friends(
ds.Character.name,
)
)
assert query == str(query_dsl)


def test_nested_query():
def test_nested_query(ds):
query = '''
hero {
name
Expand All @@ -66,27 +79,27 @@ def test_nested_query():
}
}
'''.strip()
query_dsl = dsl.field(Query.hero).get(
Character.name,
dsl.field(Character.friends).get(
Character.name,
Character.appears_in,
dsl.field(Character.friends).get(
Character.name
query_dsl = ds.Query.hero(
ds.Character.name,
ds.Character.friends(
ds.Character.name,
ds.Character.appears_in,
ds.Character.friends(
ds.Character.name
)
)
)
assert query == str(query_dsl)


def test_fetch_luke_query():
def test_fetch_luke_query(ds):
query = '''
human(id: "1000") {
name
}
'''.strip()
query_dsl = dsl.field(Query.human, id="1000").get(
Human.name,
query_dsl = ds.Query.human.args(id="1000").get(
ds.Human.name,
)

assert query == str(query_dsl)
Expand Down Expand Up @@ -153,19 +166,14 @@ def test_fetch_luke_query():
# assert result.data == expected


def test_fetch_luke_aliased():
def test_fetch_luke_aliased(ds):
query = '''
luke: human(id: "1000") {
name
}
'''.strip()
expected = {
'luke': {
'name': 'Luke Skywalker',
}
}
query_dsl = dsl.field(Query.human, id=1000).alias('luke').get(
Character.name,
query_dsl = ds.Query.human.args(id=1000).alias('luke').get(
ds.Character.name,
)
assert query == str(query_dsl)

Expand Down

0 comments on commit 269d2e9

Please sign in to comment.