diff --git a/CHANGELOG.md b/CHANGELOG.md index c5ed968..f61db31 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -4,6 +4,7 @@ - Added `ClientForwardRefsPlugin` to standard plugins. - Re-added `model_rebuild` calls for input types with forward references. +- Fixed fragments on interfaces being omitted from generated client. ## 0.13.0 (2024-03-4) diff --git a/ariadne_codegen/client_generators/result_types.py b/ariadne_codegen/client_generators/result_types.py index b43ac4a..24326d1 100644 --- a/ariadne_codegen/client_generators/result_types.py +++ b/ariadne_codegen/client_generators/result_types.py @@ -326,9 +326,12 @@ def _resolve_selection_set( fields.extend(sub_fields) fragments = fragments.union(sub_fragments) elif isinstance(selection, InlineFragmentNode): - if selection.type_condition.name.value == root_type: + root_type_value = self._get_inline_fragment_root_type( + selection.type_condition.name.value, root_type + ) + if root_type_value: sub_fields, sub_fragments = self._resolve_selection_set( - selection.selection_set, root_type + selection.selection_set, root_type_value ) fields.extend(sub_fields) fragments = fragments.union(sub_fragments) @@ -337,6 +340,23 @@ def _resolve_selection_set( ) return fields, fragments + def _get_inline_fragment_root_type( + self, selection_value: str, root_type: str + ) -> Optional[str]: + type_ = self.schema.type_map.get(root_type) + if not type_: + return None + + if isinstance(type_, GraphQLObjectType) and selection_value in { + interface.name for interface in type_.interfaces + }: + return selection_value + + if selection_value == root_type: + return root_type + + return None + def _unpack_fragment( self, fragment_def: FragmentDefinitionNode, diff --git a/tests/main/clients/interface_as_fragment/expected_client/__init__.py b/tests/main/clients/interface_as_fragment/expected_client/__init__.py new file mode 100644 index 0000000..8fba877 --- /dev/null +++ b/tests/main/clients/interface_as_fragment/expected_client/__init__.py @@ -0,0 +1,35 @@ +from .async_base_client import AsyncBaseClient +from .base_model import BaseModel, Upload +from .client import Client +from .exceptions import ( + GraphQLClientError, + GraphQLClientGraphQLError, + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidResponseError, +) +from .fragments import Item, ItemError +from .my_mutation import ( + MyMutation, + MyMutationChangeItem, + MyMutationChangeItemContacts, + MyMutationChangeItemErrorsItemServiceInternalError, +) + +__all__ = [ + "AsyncBaseClient", + "BaseModel", + "Client", + "GraphQLClientError", + "GraphQLClientGraphQLError", + "GraphQLClientGraphQLMultiError", + "GraphQLClientHttpError", + "GraphQLClientInvalidResponseError", + "Item", + "ItemError", + "MyMutation", + "MyMutationChangeItem", + "MyMutationChangeItemContacts", + "MyMutationChangeItemErrorsItemServiceInternalError", + "Upload", +] diff --git a/tests/main/clients/interface_as_fragment/expected_client/async_base_client.py b/tests/main/clients/interface_as_fragment/expected_client/async_base_client.py new file mode 100644 index 0000000..5358ced --- /dev/null +++ b/tests/main/clients/interface_as_fragment/expected_client/async_base_client.py @@ -0,0 +1,370 @@ +import enum +import json +from typing import IO, Any, AsyncIterator, Dict, List, Optional, Tuple, TypeVar, cast +from uuid import uuid4 + +import httpx +from pydantic import BaseModel +from pydantic_core import to_jsonable_python + +from .base_model import UNSET, Upload +from .exceptions import ( + GraphQLClientGraphQLMultiError, + GraphQLClientHttpError, + GraphQLClientInvalidMessageFormat, + GraphQLClientInvalidResponseError, +) + +try: + from websockets.client import ( # type: ignore[import-not-found,unused-ignore] + WebSocketClientProtocol, + connect as ws_connect, + ) + from websockets.typing import ( # type: ignore[import-not-found,unused-ignore] + Data, + Origin, + Subprotocol, + ) +except ImportError: + from contextlib import asynccontextmanager + + @asynccontextmanager # type: ignore + async def ws_connect(*args, **kwargs): # pylint: disable=unused-argument + raise NotImplementedError("Subscriptions require 'websockets' package.") + yield # pylint: disable=unreachable + + WebSocketClientProtocol = Any # type: ignore[misc,assignment,unused-ignore] + Data = Any # type: ignore[misc,assignment,unused-ignore] + Origin = Any # type: ignore[misc,assignment,unused-ignore] + + def Subprotocol(*args, **kwargs): # type: ignore # pylint: disable=invalid-name + raise NotImplementedError("Subscriptions require 'websockets' package.") + + +Self = TypeVar("Self", bound="AsyncBaseClient") + +GRAPHQL_TRANSPORT_WS = "graphql-transport-ws" + + +class GraphQLTransportWSMessageType(str, enum.Enum): + CONNECTION_INIT = "connection_init" + CONNECTION_ACK = "connection_ack" + PING = "ping" + PONG = "pong" + SUBSCRIBE = "subscribe" + NEXT = "next" + ERROR = "error" + COMPLETE = "complete" + + +class AsyncBaseClient: + def __init__( + self, + url: str = "", + headers: Optional[Dict[str, str]] = None, + http_client: Optional[httpx.AsyncClient] = None, + ws_url: str = "", + ws_headers: Optional[Dict[str, Any]] = None, + ws_origin: Optional[str] = None, + ws_connection_init_payload: Optional[Dict[str, Any]] = None, + ) -> None: + self.url = url + self.headers = headers + self.http_client = ( + http_client if http_client else httpx.AsyncClient(headers=headers) + ) + + self.ws_url = ws_url + self.ws_headers = ws_headers or {} + self.ws_origin = Origin(ws_origin) if ws_origin else None + self.ws_connection_init_payload = ws_connection_init_payload + + async def __aenter__(self: Self) -> Self: + return self + + async def __aexit__( + self, + exc_type: object, + exc_val: object, + exc_tb: object, + ) -> None: + await self.http_client.aclose() + + async def execute( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> httpx.Response: + processed_variables, files, files_map = self._process_variables(variables) + + if files and files_map: + return await self._execute_multipart( + query=query, + operation_name=operation_name, + variables=processed_variables, + files=files, + files_map=files_map, + **kwargs, + ) + + return await self._execute_json( + query=query, + operation_name=operation_name, + variables=processed_variables, + **kwargs, + ) + + def get_data(self, response: httpx.Response) -> Dict[str, Any]: + if not response.is_success: + raise GraphQLClientHttpError( + status_code=response.status_code, response=response + ) + + try: + response_json = response.json() + except ValueError as exc: + raise GraphQLClientInvalidResponseError(response=response) from exc + + if (not isinstance(response_json, dict)) or ( + "data" not in response_json and "errors" not in response_json + ): + raise GraphQLClientInvalidResponseError(response=response) + + data = response_json.get("data") + errors = response_json.get("errors") + + if errors: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=errors, data=data + ) + + return cast(Dict[str, Any], data) + + async def execute_ws( + self, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + **kwargs: Any, + ) -> AsyncIterator[Dict[str, Any]]: + headers = self.ws_headers.copy() + headers.update(kwargs.get("extra_headers", {})) + + merged_kwargs: Dict[str, Any] = {"origin": self.ws_origin} + merged_kwargs.update(kwargs) + merged_kwargs["extra_headers"] = headers + + operation_id = str(uuid4()) + async with ws_connect( + self.ws_url, + subprotocols=[Subprotocol(GRAPHQL_TRANSPORT_WS)], + **merged_kwargs, + ) as websocket: + await self._send_connection_init(websocket) + # wait for connection_ack from server + await self._handle_ws_message( + await websocket.recv(), + websocket, + expected_type=GraphQLTransportWSMessageType.CONNECTION_ACK, + ) + await self._send_subscribe( + websocket, + operation_id=operation_id, + query=query, + operation_name=operation_name, + variables=variables, + ) + + async for message in websocket: + data = await self._handle_ws_message(message, websocket) + if data: + yield data + + def _process_variables( + self, variables: Optional[Dict[str, Any]] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + if not variables: + return {}, {}, {} + + serializable_variables = self._convert_dict_to_json_serializable(variables) + return self._get_files_from_variables(serializable_variables) + + def _convert_dict_to_json_serializable( + self, dict_: Dict[str, Any] + ) -> Dict[str, Any]: + return { + key: self._convert_value(value) + for key, value in dict_.items() + if value is not UNSET + } + + def _convert_value(self, value: Any) -> Any: + if isinstance(value, BaseModel): + return value.model_dump(by_alias=True, exclude_unset=True) + if isinstance(value, list): + return [self._convert_value(item) for item in value] + return value + + def _get_files_from_variables( + self, variables: Dict[str, Any] + ) -> Tuple[ + Dict[str, Any], Dict[str, Tuple[str, IO[bytes], str]], Dict[str, List[str]] + ]: + files_map: Dict[str, List[str]] = {} + files_list: List[Upload] = [] + + def separate_files(path: str, obj: Any) -> Any: + if isinstance(obj, list): + nulled_list = [] + for index, value in enumerate(obj): + value = separate_files(f"{path}.{index}", value) + nulled_list.append(value) + return nulled_list + + if isinstance(obj, dict): + nulled_dict = {} + for key, value in obj.items(): + value = separate_files(f"{path}.{key}", value) + nulled_dict[key] = value + return nulled_dict + + if isinstance(obj, Upload): + if obj in files_list: + file_index = files_list.index(obj) + files_map[str(file_index)].append(path) + else: + file_index = len(files_list) + files_list.append(obj) + files_map[str(file_index)] = [path] + return None + + return obj + + nulled_variables = separate_files("variables", variables) + files: Dict[str, Tuple[str, IO[bytes], str]] = { + str(i): (file_.filename, cast(IO[bytes], file_.content), file_.content_type) + for i, file_ in enumerate(files_list) + } + return nulled_variables, files, files_map + + async def _execute_multipart( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + files: Dict[str, Tuple[str, IO[bytes], str]], + files_map: Dict[str, List[str]], + **kwargs: Any, + ) -> httpx.Response: + data = { + "operations": json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + "map": json.dumps(files_map, default=to_jsonable_python), + } + + return await self.http_client.post( + url=self.url, data=data, files=files, **kwargs + ) + + async def _execute_json( + self, + query: str, + operation_name: Optional[str], + variables: Dict[str, Any], + **kwargs: Any, + ) -> httpx.Response: + headers: Dict[str, str] = {"Content-Type": "application/json"} + headers.update(kwargs.get("headers", {})) + + merged_kwargs: Dict[str, Any] = kwargs.copy() + merged_kwargs["headers"] = headers + + return await self.http_client.post( + url=self.url, + content=json.dumps( + { + "query": query, + "operationName": operation_name, + "variables": variables, + }, + default=to_jsonable_python, + ), + **merged_kwargs, + ) + + async def _send_connection_init(self, websocket: WebSocketClientProtocol) -> None: + payload: Dict[str, Any] = { + "type": GraphQLTransportWSMessageType.CONNECTION_INIT.value + } + if self.ws_connection_init_payload: + payload["payload"] = self.ws_connection_init_payload + await websocket.send(json.dumps(payload)) + + async def _send_subscribe( + self, + websocket: WebSocketClientProtocol, + operation_id: str, + query: str, + operation_name: Optional[str] = None, + variables: Optional[Dict[str, Any]] = None, + ) -> None: + payload: Dict[str, Any] = { + "id": operation_id, + "type": GraphQLTransportWSMessageType.SUBSCRIBE.value, + "payload": {"query": query, "operationName": operation_name}, + } + if variables: + payload["payload"]["variables"] = self._convert_dict_to_json_serializable( + variables + ) + await websocket.send(json.dumps(payload)) + + async def _handle_ws_message( + self, + message: Data, + websocket: WebSocketClientProtocol, + expected_type: Optional[GraphQLTransportWSMessageType] = None, + ) -> Optional[Dict[str, Any]]: + try: + message_dict = json.loads(message) + except json.JSONDecodeError as exc: + raise GraphQLClientInvalidMessageFormat(message=message) from exc + + type_ = message_dict.get("type") + payload = message_dict.get("payload", {}) + + if not type_ or type_ not in {t.value for t in GraphQLTransportWSMessageType}: + raise GraphQLClientInvalidMessageFormat(message=message) + + if expected_type and expected_type != type_: + raise GraphQLClientInvalidMessageFormat( + f"Invalid message received. Expected: {expected_type.value}" + ) + + if type_ == GraphQLTransportWSMessageType.NEXT: + if "data" not in payload: + raise GraphQLClientInvalidMessageFormat(message=message) + return cast(Dict[str, Any], payload["data"]) + + if type_ == GraphQLTransportWSMessageType.COMPLETE: + await websocket.close() + elif type_ == GraphQLTransportWSMessageType.PING: + await websocket.send( + json.dumps({"type": GraphQLTransportWSMessageType.PONG.value}) + ) + elif type_ == GraphQLTransportWSMessageType.ERROR: + raise GraphQLClientGraphQLMultiError.from_errors_dicts( + errors_dicts=payload, data=message_dict + ) + + return None diff --git a/tests/main/clients/interface_as_fragment/expected_client/base_model.py b/tests/main/clients/interface_as_fragment/expected_client/base_model.py new file mode 100644 index 0000000..ccde397 --- /dev/null +++ b/tests/main/clients/interface_as_fragment/expected_client/base_model.py @@ -0,0 +1,27 @@ +from io import IOBase + +from pydantic import BaseModel as PydanticBaseModel, ConfigDict + + +class UnsetType: + def __bool__(self) -> bool: + return False + + +UNSET = UnsetType() + + +class BaseModel(PydanticBaseModel): + model_config = ConfigDict( + populate_by_name=True, + validate_assignment=True, + arbitrary_types_allowed=True, + protected_namespaces=(), + ) + + +class Upload: + def __init__(self, filename: str, content: IOBase, content_type: str): + self.filename = filename + self.content = content + self.content_type = content_type diff --git a/tests/main/clients/interface_as_fragment/expected_client/client.py b/tests/main/clients/interface_as_fragment/expected_client/client.py new file mode 100644 index 0000000..23d8298 --- /dev/null +++ b/tests/main/clients/interface_as_fragment/expected_client/client.py @@ -0,0 +1,44 @@ +from typing import Any, Dict + +from .async_base_client import AsyncBaseClient +from .my_mutation import MyMutation + + +def gql(q: str) -> str: + return q + + +class Client(AsyncBaseClient): + async def my_mutation(self, id: str, **kwargs: Any) -> MyMutation: + query = gql( + """ + mutation my_mutation($id: ID!) { + change_item(id: $id) { + contacts { + ...Item + } + errors { + __typename + ... on ItemError { + ...ItemError + } + } + } + } + + fragment Item on Item { + id + } + + fragment ItemError on ItemError { + __typename + message + } + """ + ) + variables: Dict[str, object] = {"id": id} + response = await self.execute( + query=query, operation_name="my_mutation", variables=variables, **kwargs + ) + data = self.get_data(response) + return MyMutation.model_validate(data) diff --git a/tests/main/clients/interface_as_fragment/expected_client/enums.py b/tests/main/clients/interface_as_fragment/expected_client/enums.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/main/clients/interface_as_fragment/expected_client/exceptions.py b/tests/main/clients/interface_as_fragment/expected_client/exceptions.py new file mode 100644 index 0000000..b34acfe --- /dev/null +++ b/tests/main/clients/interface_as_fragment/expected_client/exceptions.py @@ -0,0 +1,83 @@ +from typing import Any, Dict, List, Optional, Union + +import httpx + + +class GraphQLClientError(Exception): + """Base exception.""" + + +class GraphQLClientHttpError(GraphQLClientError): + def __init__(self, status_code: int, response: httpx.Response) -> None: + self.status_code = status_code + self.response = response + + def __str__(self) -> str: + return f"HTTP status code: {self.status_code}" + + +class GraphQLClientInvalidResponseError(GraphQLClientError): + def __init__(self, response: httpx.Response) -> None: + self.response = response + + def __str__(self) -> str: + return "Invalid response format." + + +class GraphQLClientGraphQLError(GraphQLClientError): + def __init__( + self, + message: str, + locations: Optional[List[Dict[str, int]]] = None, + path: Optional[List[str]] = None, + extensions: Optional[Dict[str, object]] = None, + orginal: Optional[Dict[str, object]] = None, + ): + self.message = message + self.locations = locations + self.path = path + self.extensions = extensions + self.orginal = orginal + + def __str__(self) -> str: + return self.message + + @classmethod + def from_dict(cls, error: Dict[str, Any]) -> "GraphQLClientGraphQLError": + return cls( + message=error["message"], + locations=error.get("locations"), + path=error.get("path"), + extensions=error.get("extensions"), + orginal=error, + ) + + +class GraphQLClientGraphQLMultiError(GraphQLClientError): + def __init__( + self, + errors: List[GraphQLClientGraphQLError], + data: Optional[Dict[str, Any]] = None, + ): + self.errors = errors + self.data = data + + def __str__(self) -> str: + return "; ".join(str(e) for e in self.errors) + + @classmethod + def from_errors_dicts( + cls, errors_dicts: List[Dict[str, Any]], data: Optional[Dict[str, Any]] = None + ) -> "GraphQLClientGraphQLMultiError": + return cls( + errors=[GraphQLClientGraphQLError.from_dict(e) for e in errors_dicts], + data=data, + ) + + +class GraphQLClientInvalidMessageFormat(GraphQLClientError): + def __init__(self, message: Union[str, bytes]) -> None: + self.message = message + + def __str__(self) -> str: + return "Invalid message format." diff --git a/tests/main/clients/interface_as_fragment/expected_client/fragments.py b/tests/main/clients/interface_as_fragment/expected_client/fragments.py new file mode 100644 index 0000000..a888e86 --- /dev/null +++ b/tests/main/clients/interface_as_fragment/expected_client/fragments.py @@ -0,0 +1,18 @@ +from typing import Optional + +from pydantic import Field + +from .base_model import BaseModel + + +class Item(BaseModel): + id: Optional[str] + + +class ItemError(BaseModel): + typename__: str = Field(alias="__typename") + message: str + + +Item.model_rebuild() +ItemError.model_rebuild() diff --git a/tests/main/clients/interface_as_fragment/expected_client/input_types.py b/tests/main/clients/interface_as_fragment/expected_client/input_types.py new file mode 100644 index 0000000..e69de29 diff --git a/tests/main/clients/interface_as_fragment/expected_client/my_mutation.py b/tests/main/clients/interface_as_fragment/expected_client/my_mutation.py new file mode 100644 index 0000000..c36c7f4 --- /dev/null +++ b/tests/main/clients/interface_as_fragment/expected_client/my_mutation.py @@ -0,0 +1,34 @@ +from typing import Annotated, List, Literal, Optional, Union + +from pydantic import Field + +from .base_model import BaseModel +from .fragments import Item, ItemError + + +class MyMutation(BaseModel): + change_item: Optional["MyMutationChangeItem"] + + +class MyMutationChangeItem(BaseModel): + contacts: Optional[List["MyMutationChangeItemContacts"]] + errors: Optional[ + List[ + Annotated[ + Union["MyMutationChangeItemErrorsItemServiceInternalError",], + Field(discriminator="typename__"), + ] + ] + ] + + +class MyMutationChangeItemContacts(Item): + pass + + +class MyMutationChangeItemErrorsItemServiceInternalError(ItemError): + typename__: Literal["ItemServiceInternalError"] = Field(alias="__typename") + + +MyMutation.model_rebuild() +MyMutationChangeItem.model_rebuild() diff --git a/tests/main/clients/interface_as_fragment/pyproject.toml b/tests/main/clients/interface_as_fragment/pyproject.toml new file mode 100644 index 0000000..360d591 --- /dev/null +++ b/tests/main/clients/interface_as_fragment/pyproject.toml @@ -0,0 +1,5 @@ +[tool.ariadne-codegen] +schema_path = "schema.graphql" +queries_path = "queries.graphql" +target_package_name = "interface_as_fragment" +include_comments = "none" diff --git a/tests/main/clients/interface_as_fragment/queries.graphql b/tests/main/clients/interface_as_fragment/queries.graphql new file mode 100644 index 0000000..4461728 --- /dev/null +++ b/tests/main/clients/interface_as_fragment/queries.graphql @@ -0,0 +1,21 @@ +fragment Item on Item { + id +} + +fragment ItemError on ItemError { + __typename + message +} + +mutation my_mutation($id: ID!) { + change_item(id: $id) { + contacts { + ...Item + } + errors { + ... on ItemError { + ...ItemError + } + } + } +} diff --git a/tests/main/clients/interface_as_fragment/schema.graphql b/tests/main/clients/interface_as_fragment/schema.graphql new file mode 100644 index 0000000..823e6ca --- /dev/null +++ b/tests/main/clients/interface_as_fragment/schema.graphql @@ -0,0 +1,22 @@ +type Item { + id: ID +} + +type ItemResult { + contacts: [Item!] + errors: [ItemServiceError!] +} + +interface ItemError { + message: String! +} + +union ItemServiceError = ItemServiceInternalError + +type ItemServiceInternalError implements ItemError { + message: String! +} + +type Mutation { + change_item(id: ID!): ItemResult +} diff --git a/tests/main/test_main.py b/tests/main/test_main.py index 8dc755a..9b94f1a 100644 --- a/tests/main/test_main.py +++ b/tests/main/test_main.py @@ -186,6 +186,17 @@ def test_main_shows_version(): "client_only_used_inputs_and_enums", CLIENTS_PATH / "only_used_inputs_and_enums" / "expected_client", ), + ( + ( + CLIENTS_PATH / "interface_as_fragment" / "pyproject.toml", + ( + CLIENTS_PATH / "interface_as_fragment" / "queries.graphql", + CLIENTS_PATH / "interface_as_fragment" / "schema.graphql", + ), + ), + "interface_as_fragment", + CLIENTS_PATH / "interface_as_fragment" / "expected_client", + ), ], indirect=["project_dir"], )