Skip to content

Commit

Permalink
Fix dsl root operation types custom names (#320)
Browse files Browse the repository at this point in the history
  • Loading branch information
leszekhanusz committed Apr 11, 2022
1 parent ea96294 commit 0926ed6
Show file tree
Hide file tree
Showing 2 changed files with 69 additions and 4 deletions.
37 changes: 33 additions & 4 deletions gql/dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -297,7 +297,7 @@ def __getattr__(self, name: str) -> "DSLType":

assert isinstance(type_def, (GraphQLObjectType, GraphQLInterfaceType))

return DSLType(type_def)
return DSLType(type_def, self)


class DSLSelector(ABC):
Expand Down Expand Up @@ -454,7 +454,27 @@ def is_valid_field(self, field: "DSLSelectable") -> bool:
return operation_name != "SUBSCRIPTION"

elif isinstance(field, DSLField):
return field.parent_type.name.upper() == operation_name

assert field.dsl_type is not None

schema = field.dsl_type._dsl_schema._schema

root_type = None

if operation_name == "QUERY":
root_type = schema.query_type
elif operation_name == "MUTATION":
root_type = schema.mutation_type
elif operation_name == "SUBSCRIPTION":
root_type = schema.subscription_type

if root_type is None:
log.error(
f"Root type of type {operation_name} not found in the schema!"
)
return False

return field.parent_type.name == root_type.name

return False

Expand Down Expand Up @@ -585,16 +605,22 @@ class DSLType:
instances of :class:`DSLField`
"""

def __init__(self, graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType]):
def __init__(
self,
graphql_type: Union[GraphQLObjectType, GraphQLInterfaceType],
dsl_schema: DSLSchema,
):
"""Initialize the DSLType with the GraphQL type.
.. warning::
Don't instantiate this class yourself.
Use attributes of the :class:`DSLSchema` instead.
:param graphql_type: the GraphQL type definition from the schema
:param dsl_schema: reference to the DSLSchema which created this type
"""
self._type: Union[GraphQLObjectType, GraphQLInterfaceType] = graphql_type
self._dsl_schema = dsl_schema
log.debug(f"Creating {self!r})")

def __getattr__(self, name: str) -> "DSLField":
Expand All @@ -611,7 +637,7 @@ def __getattr__(self, name: str) -> "DSLField":
f"Field {name} does not exist in type {self._type.name}."
)

return DSLField(formatted_name, self._type, field)
return DSLField(formatted_name, self._type, field, self)

def __repr__(self) -> str:
return f"<{self.__class__.__name__} {self._type!r}>"
Expand Down Expand Up @@ -763,6 +789,7 @@ def __init__(
name: str,
parent_type: Union[GraphQLObjectType, GraphQLInterfaceType],
field: GraphQLField,
dsl_type: Optional[DSLType] = None,
):
"""Initialize the DSLField.
Expand All @@ -774,10 +801,12 @@ def __init__(
:param parent_type: the GraphQL type definition from the schema of the
parent type of the field
:param field: the GraphQL field definition from the schema
:param dsl_type: reference of the DSLType instance which created this field
"""
self.parent_type = parent_type
self.field = field
self.ast_field = FieldNode(name=NameNode(value=name), arguments=())
self.dsl_type = dsl_type

log.debug(f"Creating {self!r}")

Expand Down
36 changes: 36 additions & 0 deletions tests/starwars/test_dsl.py
Original file line number Diff line number Diff line change
Expand Up @@ -728,6 +728,42 @@ def test_dsl_query_all_fields_should_correspond_to_the_root_type(ds):
)


def test_dsl_root_type_not_default():

from graphql import parse, build_ast_schema

schema_str = """
schema {
query: QueryNotDefault
}
type QueryNotDefault {
version: String
}
"""

type_def_ast = parse(schema_str)
schema = build_ast_schema(type_def_ast)

ds = DSLSchema(schema)

query = dsl_gql(DSLQuery(ds.QueryNotDefault.version))

expected_query = """
{
version
}
"""
assert print_ast(query) == expected_query.strip()

with pytest.raises(GraphQLError) as excinfo:
DSLSubscription(ds.QueryNotDefault.version)

assert (
"Invalid field for <DSLSubscription>: <DSLField QueryNotDefault::version>"
) in str(excinfo.value)


def test_dsl_gql_all_arguments_should_be_operations_or_fragments():
with pytest.raises(
TypeError, match="Operations should be instances of DSLExecutable "
Expand Down

0 comments on commit 0926ed6

Please sign in to comment.