diff --git a/gql/dsl.py b/gql/dsl.py index 57129c54..aa4f1559 100644 --- a/gql/dsl.py +++ b/gql/dsl.py @@ -5,9 +5,44 @@ import six from graphql.language import ast from graphql.language.printer import print_ast -from graphql.type import (GraphQLField, GraphQLFieldDefinition, GraphQLList, +from graphql.type import (GraphQLField, GraphQLList, GraphQLNonNull, GraphQLEnumType) +from .utils import to_camel_case + + +class DSLSchema(object): + def __init__(self, client): + self.client = client + + @property + def schema(self): + return self.client.schema + + def __getattr__(self, name): + type_def = self.schema.get_type(name) + return DSLType(type_def) + + +class DSLType(object): + def __init__(self, type): + self.type = type + + def __getattr__(self, name): + formatted_name, field_def = self.get_field(name) + return DSLField(formatted_name, field_def) + + def get_field(self, name): + camel_cased_name = to_camel_case(name) + + if name in self.type.fields: + return name, self.type.fields[name] + + if camel_cased_name in self.type.fields: + return camel_cased_name, self.type.fields[camel_cased_name] + + raise KeyError('Field {} doesnt exist in type {}.'.format(name, self.type.name)) + def selections(*fields): for _field in fields: @@ -30,9 +65,9 @@ def get_ast_value(value): class DSLField(object): - def __init__(self, field): + def __init__(self, name, field): self.field = field - self.ast_field = ast.Field(name=ast.Name(value=field.name), arguments=[]) + self.ast_field = ast.Field(name=ast.Name(value=name), arguments=[]) self.selection_set = None def get(self, *fields): @@ -41,21 +76,16 @@ def get(self, *fields): self.ast_field.selection_set.selections.extend(selections(*fields)) return self + def __call__(self, *args, **kwargs): + return self.get(*args, **kwargs) + def alias(self, alias): self.ast_field.alias = ast.Name(value=alias) return self - def get_field_args(self): - if isinstance(self.field, GraphQLFieldDefinition): - # The args will be an array - return { - arg.name: arg for arg in self.field.args - } - return self.field.args - def args(self, **args): for name, value in args.items(): - arg = self.get_field_args().get(name) + arg = self.field.args.get(name) arg_type_serializer = get_arg_serializer(arg.type) value = arg_type_serializer(value) self.ast_field.arguments.append( @@ -75,7 +105,7 @@ def __str__(self): def field(field, **args): - if isinstance(field, (GraphQLField, GraphQLFieldDefinition)): + if isinstance(field, GraphQLField): return DSLField(field).args(**args) elif isinstance(field, DSLField): return field diff --git a/gql/utils.py b/gql/utils.py new file mode 100644 index 00000000..ae8ceffe --- /dev/null +++ b/gql/utils.py @@ -0,0 +1,21 @@ +import re + + +# From this response in Stackoverflow +# http://stackoverflow.com/a/19053800/1072990 +def to_camel_case(snake_str): + components = snake_str.split('_') + # We capitalize the first letter of each component except the first one + # with the 'title' method and join them together. + return components[0] + "".join(x.title() if x else '_' for x in components[1:]) + + +# From this response in Stackoverflow +# http://stackoverflow.com/a/1176023/1072990 +def to_snake_case(name): + s1 = re.sub('(.)([A-Z][a-z]+)', r'\1_\2', name) + return re.sub('([a-z0-9])([A-Z])', r'\1_\2', s1).lower() + + +def to_const(string): + return re.sub('[\W|^]+', '_', string).upper() diff --git a/tests/starwars/test_dsl.py b/tests/starwars/test_dsl.py index 5a88ba80..d19e19b3 100644 --- a/tests/starwars/test_dsl.py +++ b/tests/starwars/test_dsl.py @@ -1,39 +1,52 @@ -from gql import dsl +import pytest + +from gql import Client +from gql.dsl import DSLSchema from .schema import characterInterface, humanType, queryType # We construct a Simple DSL objects for easy field referencing -class Query(object): - hero = queryType.get_fields()['hero'] - human = queryType.get_fields()['human'] +# class Query(object): +# hero = queryType.fields['hero'] +# human = queryType.fields['human'] + + +# class Character(object): +# id = characterInterface.fields['id'] +# name = characterInterface.fields['name'] +# friends = characterInterface.fields['friends'] +# appears_in = characterInterface.fields['appearsIn'] -class Character(object): - id = characterInterface.get_fields()['id'] - name = characterInterface.get_fields()['name'] - friends = characterInterface.get_fields()['friends'] - appears_in = characterInterface.get_fields()['appearsIn'] +# class Human(object): +# name = humanType.fields['name'] -class Human(object): - name = humanType.get_fields()['name'] +from .schema import StarWarsSchema -def test_hero_name_query(): +@pytest.fixture +def ds(): + client = Client(schema=StarWarsSchema) + ds = DSLSchema(client) + return ds + + +def test_hero_name_query(ds): query = ''' hero { name } '''.strip() - query_dsl = dsl.field(Query.hero).get( - Character.name + query_dsl = ds.Query.hero( + ds.Character.name ) assert query == str(query_dsl) -def test_hero_name_and_friends_query(): +def test_hero_name_and_friends_query(ds): query = ''' hero { id @@ -43,17 +56,17 @@ def test_hero_name_and_friends_query(): } } '''.strip() - query_dsl = dsl.field(Query.hero).get( - Character.id, - Character.name, - dsl.field(Character.friends).get( - Character.name, + query_dsl = ds.Query.hero( + ds.Character.id, + ds.Character.name, + ds.Character.friends( + ds.Character.name, ) ) assert query == str(query_dsl) -def test_nested_query(): +def test_nested_query(ds): query = ''' hero { name @@ -66,27 +79,27 @@ def test_nested_query(): } } '''.strip() - query_dsl = dsl.field(Query.hero).get( - Character.name, - dsl.field(Character.friends).get( - Character.name, - Character.appears_in, - dsl.field(Character.friends).get( - Character.name + query_dsl = ds.Query.hero( + ds.Character.name, + ds.Character.friends( + ds.Character.name, + ds.Character.appears_in, + ds.Character.friends( + ds.Character.name ) ) ) assert query == str(query_dsl) -def test_fetch_luke_query(): +def test_fetch_luke_query(ds): query = ''' human(id: "1000") { name } '''.strip() - query_dsl = dsl.field(Query.human, id="1000").get( - Human.name, + query_dsl = ds.Query.human.args(id="1000").get( + ds.Human.name, ) assert query == str(query_dsl) @@ -153,19 +166,14 @@ def test_fetch_luke_query(): # assert result.data == expected -def test_fetch_luke_aliased(): +def test_fetch_luke_aliased(ds): query = ''' luke: human(id: "1000") { name } '''.strip() - expected = { - 'luke': { - 'name': 'Luke Skywalker', - } - } - query_dsl = dsl.field(Query.human, id=1000).alias('luke').get( - Character.name, + query_dsl = ds.Query.human.args(id=1000).alias('luke').get( + ds.Character.name, ) assert query == str(query_dsl)