Skip to content

Commit

Permalink
sdk: python: Fix inconsistencies in passing IDs to fields
Browse files Browse the repository at this point in the history
Signed-off-by: Helder Correia <174525+helderco@users.noreply.github.com>
  • Loading branch information
helderco committed Dec 14, 2022
1 parent 91accd9 commit d6107e7
Show file tree
Hide file tree
Showing 3 changed files with 65 additions and 42 deletions.
30 changes: 14 additions & 16 deletions sdk/python/src/dagger/api/gen.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

30 changes: 14 additions & 16 deletions sdk/python/src/dagger/api/gen_sync.py

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

47 changes: 37 additions & 10 deletions sdk/python/src/dagger/codegen.py
Expand Up @@ -10,7 +10,16 @@
from itertools import chain, groupby
from keyword import iskeyword
from operator import attrgetter
from typing import Any, ClassVar, Generic, Iterator, Protocol, TypeGuard, TypeVar
from typing import (
Any,
ClassVar,
Generic,
Iterator,
Protocol,
TypeAlias,
TypeGuard,
TypeVar,
)

from attrs import Factory, define
from graphql import (
Expand Down Expand Up @@ -48,6 +57,11 @@
wrap_indent = partial(wrap, initial_indent=" " * 4, subsequent_indent=" " * 4)


IDName: TypeAlias = str
TypeName: TypeAlias = str
IDMap: TypeAlias = dict[IDName, TypeName]


class Scalars(Enum):
ID = str
Int = int
Expand Down Expand Up @@ -94,7 +108,7 @@ def generate(schema: GraphQLSchema, sync: bool = False) -> Iterator[str]:

# collect object types for all id return types
# used to replace custom scalars by objects in inputs
id_map: dict[str, str] = {}
id_map: IDMap = {}
for type_name, t in schema.type_map.items():
if is_wrapping_type(t):
t = t.of_type
Expand Down Expand Up @@ -189,7 +203,7 @@ def format_name(s: str) -> str:
return s


def format_input_type(t: GraphQLInputType, id_map: dict[str, str]) -> str:
def format_input_type(t: GraphQLInputType, id_map: IDMap) -> str:
"""This may be used in an input object field or an object field parameter."""

if is_required_type(t):
Expand Down Expand Up @@ -245,17 +259,30 @@ def __init__(
self,
name: str,
graphql: GraphQLInputField | GraphQLArgument,
id_map: dict[str, str],
id_map: IDMap,
parent: "_ObjectField | None" = None,
) -> None:
self.graphql_name = name
self.graphql = graphql

self.name = format_name(name)
named_type = get_named_type(graphql.type)

# On object type fields, don't replace ID scalar with object
# only if field name is `id` and the corresponding type is different
# from the output type (e.g., `file(id: FileID) -> File`, but also
# `with_rootfs(id: Directory) -> Container`).
if (
name == "id"
and is_custom_scalar_type(graphql.type)
and named_type.name in id_map
and parent
and get_named_type(parent.graphql.type).name == id_map[named_type.name]
):
id_map = {}

self.type = format_input_type(graphql.type, id_map)
if name == "id" and is_custom_scalar_type(graphql.type):
self.type = f"{get_named_type(graphql.type)} | {self.type}"
self.description = graphql.description

self.has_default = graphql.default_value is not Undefined
self.default_value = graphql.default_value

Expand Down Expand Up @@ -301,7 +328,7 @@ def __init__(
self,
name: str,
field: GraphQLField,
id_map: dict[str, str],
id_map: IDMap,
sync: bool,
) -> None:
self.graphql_name = name
Expand All @@ -310,7 +337,7 @@ def __init__(

self.name = format_name(name)
self.args = sorted(
(_InputField(*args, id_map) for args in field.args.items()),
(_InputField(*args, id_map, parent=self) for args in field.args.items()),
key=attrgetter("has_default"),
)
self.description = field.description
Expand Down Expand Up @@ -422,7 +449,7 @@ class Handler(ABC, Generic[_H]):
sync: bool = False
"""Sync or async."""

id_map: dict[str, str] = Factory(dict)
id_map: IDMap = Factory(dict)
"""Map to convert ids (custom scalars) to corresponding types."""

predicate: ClassVar[Predicate] = staticmethod(lambda _: True)
Expand Down

0 comments on commit d6107e7

Please sign in to comment.