diff --git a/README.md b/README.md index f561b46..98ffb61 100644 --- a/README.md +++ b/README.md @@ -42,6 +42,8 @@ - [Get Toolset by Id](#get-toolset-by-id) - [Resource Permissions](#resource-permissions) - [Grant Permissions](#grant-permissions) + - [Client Channel](#client-channel) + - [Sign In to Toolsets](#sign-in-to-toolsets) - [Client Pool](#client-pool) - [Synchronous Client Pool](#synchronous-client-pool) - [Asynchronous Client Pool](#asynchronous-client-pool) @@ -854,6 +856,55 @@ await async_client.resource_permissions.grant( The method returns `None` on success and raises `DialException` on HTTP error. +### Client Channel + +DIAL Core's [client channel API](https://dialx.ai/universal_chat_api.yaml) lets a deployment ask an interactive client (e.g. the chat UI) to take some action and report the result back. The channel id is propagated to the deployment via the `X-DIAL-CLIENT-CHANNEL-ID` forwarded header on the inbound request. + +#### Sign In to Toolsets + +Use `client_channel.signin_toolsets()` to request interactive sign-in for one or more toolsets on the active client channel. The method returns a `dict[str, SigninResult]` mapping each input toolset id to its outcome — responses are correlated by the client, so the caller never has to deal with the underlying JSON-RPC ids. + +```python +from aidial_client import SigninResult + +# Sync +results = client.client_channel.signin_toolsets( + channel_id="", + toolset_ids=[ + "toolsets/public/toolset-a", + "toolsets/public/toolset-b", + ], + timeout=120.0, +) + +# Async +results = await async_client.client_channel.signin_toolsets( + channel_id="", + toolset_ids=["toolsets/public/my-toolset"], +) +``` + +Each value is a `SigninResult` enum: + +```python +{ + "toolsets/public/toolset-a": SigninResult.SUCCESS, + "toolsets/public/toolset-b": SigninResult.DENIED, +} +``` + +- `SigninResult.SUCCESS` — the user signed in. +- `SigninResult.DENIED` — the user declined. +- `SigninResult.ERROR` — the server returned a JSON-RPC error, or the response was missing/unrecognized. + +Arguments: + +- `channel_id` — required; the channel id received via the `X-DIAL-CLIENT-CHANNEL-ID` header on the inbound request. +- `toolset_ids` — sequence of toolset ids to request sign-in for; an empty sequence returns `{}` without contacting the server. +- `timeout` — optional `float` seconds or `httpx.Timeout`; defaults to the client-wide timeout. Useful for interactive flows where the user may take a while to respond. + +Raises `DialException` on HTTP errors (e.g. unauthorized, missing channel), transport failures (timeouts, network errors), or if the SSE stream closes without a response event. + ### Client Pool When you need to create multiple DIAL clients and wish to enhance performance by reusing the HTTP connection for the same DIAL instance, consider using synchronous and asynchronous **client pools**. diff --git a/aidial_client/__init__.py b/aidial_client/__init__.py index 397016c..edf195f 100644 --- a/aidial_client/__init__.py +++ b/aidial_client/__init__.py @@ -9,6 +9,7 @@ ParsingDataError, ResourceNotFoundError, ) +from aidial_client.types.client_channel import SigninResult from aidial_client.types.model import ModelInfo, ModelLimits, ModelPricing from aidial_client.types.toolset import ToolsetInfo @@ -30,4 +31,5 @@ "ModelInfo", "ModelPricing", "ModelLimits", + "SigninResult", ] diff --git a/aidial_client/_client.py b/aidial_client/_client.py index 3486c7e..5c4907d 100644 --- a/aidial_client/_client.py +++ b/aidial_client/_client.py @@ -119,6 +119,9 @@ def _init_resources(self) -> None: self.resource_permissions = resources.ResourcePermissions( http_client=self._http_client ) + self.client_channel = resources.ClientChannel( + http_client=self._http_client + ) def _create_http_client(self) -> SyncHTTPClient: return SyncHTTPClient( @@ -207,6 +210,9 @@ def _init_resources(self) -> None: self.resource_permissions = resources.AsyncResourcePermissions( http_client=self._http_client ) + self.client_channel = resources.AsyncClientChannel( + http_client=self._http_client + ) def _create_http_client(self) -> AsyncHTTPClient: return AsyncHTTPClient( diff --git a/aidial_client/_http_client/_async.py b/aidial_client/_http_client/_async.py index 8417940..5dcf37c 100644 --- a/aidial_client/_http_client/_async.py +++ b/aidial_client/_http_client/_async.py @@ -1,12 +1,23 @@ import asyncio +from contextlib import asynccontextmanager from http import HTTPStatus -from typing import Callable, Dict, Optional, Type +from typing import ( + Any, + AsyncIterator, + Callable, + Dict, + Mapping, + Optional, + Type, + Union, +) import httpx from aidial_client._auth import AsyncAuthValue, aget_combined_auth_headers from aidial_client._exception import DialException from aidial_client._http_client._base import BaseHTTPClient +from aidial_client._internal_types._defaults import NOT_GIVEN, NotGiven from aidial_client._internal_types._generic import ResponseT from aidial_client._internal_types._http_request import FinalRequestOptions from aidial_client._log import logger @@ -108,3 +119,54 @@ async def request( raise raised_error from err return process_block_response(cast_to=cast_to, response=response) + + @asynccontextmanager + async def stream_sse( + self, + *, + method: str, + url: str, + json_data: Any, + headers: Optional[Mapping[str, str]] = None, + timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN, + ) -> AsyncIterator[httpx.Response]: + """Open an SSE streaming response. Yields the open httpx.Response. + + Auth headers are merged in. On non-2xx, reads the body and raises + a DialException; transport errors (timeouts, network failures) are + also wrapped so the caller always sees DialException. Retries are + not performed for streaming requests. + + ``timeout`` defaults to the client-wide timeout; pass an explicit + ``None`` (or ``httpx.Timeout(None)``) for no timeout. + """ + merged_headers = {**(await self.auth_headers()), **(headers or {})} + effective_timeout = ( + self._timeout if isinstance(timeout, NotGiven) else timeout + ) + try: + async with self._internal_http_client.stream( + method=method, + url=self._prepare_url(url), + headers=merged_headers, + json=json_data, + timeout=effective_timeout, + ) as response: + try: + response.raise_for_status() + except httpx.HTTPStatusError as err: + try: + await response.aread() + except httpx.HTTPError: + pass + raise self._make_dial_error_from_response( + err.response + ) from err + yield response + except httpx.TimeoutException as err: + raise DialException( + message="Request timed out", + status_code=HTTPStatus.REQUEST_TIMEOUT, + ) from err + except httpx.HTTPError as err: + raise DialException(message=f"Request failed: {err}") from err diff --git a/aidial_client/_http_client/_sse.py b/aidial_client/_http_client/_sse.py new file mode 100644 index 0000000..fa776bb --- /dev/null +++ b/aidial_client/_http_client/_sse.py @@ -0,0 +1,47 @@ +from typing import AsyncIterator, Iterator, List + +from aidial_client._log import logger + +_UNCOMMITTED_BUFFER_WARNING = ( + "Uncommitted data chunks in SSE stream " + "(stream ended without a terminating blank line); discarding." +) + + +def _strip_field(line: str, prefix: str) -> str: + """Strip a single leading U+0020 SPACE after the field colon, per the SSE spec.""" + value = line[len(prefix) :] + return value[1:] if value.startswith(" ") else value + + +def iter_data_events(lines: Iterator[str]) -> Iterator[str]: + """Yield the payload of each complete ``data:`` event from an SSE line stream. + + An event is complete when a blank line follows the ``data:`` line(s). Per + the SSE dispatch rule, a buffer that has not been terminated by a blank + line is discarded (we do NOT flush partial events at end of stream). + Comment lines (``:``) and other field names are ignored. + """ + buffer: List[str] = [] + for line in lines: + if line == "": + if buffer: + yield "\n".join(buffer) + buffer = [] + elif line.startswith("data:"): + buffer.append(_strip_field(line, "data:")) + if buffer: + logger.warning(_UNCOMMITTED_BUFFER_WARNING) + + +async def aiter_data_events(lines: AsyncIterator[str]) -> AsyncIterator[str]: + buffer: List[str] = [] + async for line in lines: + if line == "": + if buffer: + yield "\n".join(buffer) + buffer = [] + elif line.startswith("data:"): + buffer.append(_strip_field(line, "data:")) + if buffer: + logger.warning(_UNCOMMITTED_BUFFER_WARNING) diff --git a/aidial_client/_http_client/_sync.py b/aidial_client/_http_client/_sync.py index b330c12..583de97 100644 --- a/aidial_client/_http_client/_sync.py +++ b/aidial_client/_http_client/_sync.py @@ -1,12 +1,14 @@ import time +from contextlib import contextmanager from http import HTTPStatus -from typing import Callable, Dict, Optional, Type +from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Type, Union import httpx from aidial_client._auth import SyncAuthValue, get_combined_auth_headers from aidial_client._exception import DialException from aidial_client._http_client._base import BaseHTTPClient +from aidial_client._internal_types._defaults import NOT_GIVEN, NotGiven from aidial_client._internal_types._generic import ResponseT from aidial_client._internal_types._http_request import FinalRequestOptions from aidial_client._log import logger @@ -108,3 +110,54 @@ def request( raise raised_error from err return process_block_response(cast_to=cast_to, response=response) + + @contextmanager + def stream_sse( + self, + *, + method: str, + url: str, + json_data: Any, + headers: Optional[Mapping[str, str]] = None, + timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN, + ) -> Iterator[httpx.Response]: + """Open an SSE streaming response. Yields the open httpx.Response. + + Auth headers are merged in. On non-2xx, reads the body and raises + a DialException; transport errors (timeouts, network failures) are + also wrapped so the caller always sees DialException. Retries are + not performed for streaming requests. + + ``timeout`` defaults to the client-wide timeout; pass an explicit + ``None`` (or ``httpx.Timeout(None)``) for no timeout. + """ + merged_headers = {**self.auth_headers(), **(headers or {})} + effective_timeout = ( + self._timeout if isinstance(timeout, NotGiven) else timeout + ) + try: + with self._internal_http_client.stream( + method=method, + url=self._prepare_url(url), + headers=merged_headers, + json=json_data, + timeout=effective_timeout, + ) as response: + try: + response.raise_for_status() + except httpx.HTTPStatusError as err: + try: + response.read() + except httpx.HTTPError: + pass + raise self._make_dial_error_from_response( + err.response + ) from err + yield response + except httpx.TimeoutException as err: + raise DialException( + message="Request timed out", + status_code=HTTPStatus.REQUEST_TIMEOUT, + ) from err + except httpx.HTTPError as err: + raise DialException(message=f"Request failed: {err}") from err diff --git a/aidial_client/_internal_types/_json_rpc.py b/aidial_client/_internal_types/_json_rpc.py new file mode 100644 index 0000000..90b9ffa --- /dev/null +++ b/aidial_client/_internal_types/_json_rpc.py @@ -0,0 +1,75 @@ +from typing import Any, Dict, List, Literal, Optional, Union + +from aidial_client._compatibility.pydantic_v1 import ( + BaseModel, + Extra, + Field, + root_validator, +) + + +class JsonRpcError(BaseModel): + code: int + message: str + data: Optional[Any] = None + + class Config: + extra = Extra.allow + + +class JsonRpcRequest(BaseModel): + jsonrpc: Literal["2.0"] = "2.0" + method: str + params: Optional[Union[List[Any], Dict[str, Any]]] = None + id: Optional[Union[int, str]] = None + + class Config: + smart_union = True + + +class JsonRpcResponse(BaseModel): + jsonrpc: Literal["2.0"] + result: Optional[Any] = None + error: Optional[JsonRpcError] = None + id: Optional[Union[int, str]] = Field(...) + + class Config: + smart_union = True + extra = Extra.allow + + @root_validator(pre=True) + def _validate_result_xor_error(cls, values): + """Per JSON-RPC 2.0 (https://www.jsonrpc.org/specification#response_object), + either ``result`` or ``error`` MUST be included (presence-wise — ``null`` + is a valid result value), and both MUST NOT be included. + """ + if not isinstance(values, dict): + return values + has_result = "result" in values + has_error = "error" in values + if has_result and has_error: + raise ValueError( + "JSON-RPC response must not contain both 'result' and 'error'" + ) + if not has_result and not has_error: + raise ValueError( + "JSON-RPC response must contain either 'result' or 'error'" + ) + return values + + +class JsonRpcResponses(BaseModel): + """Pydantic root model that accepts a single JSON-RPC response object or + a batch array, normalizing both to a list via the ``responses`` property. + """ + + __root__: Union[JsonRpcResponse, List[JsonRpcResponse]] + + class Config: + smart_union = True + + @property + def responses(self) -> List[JsonRpcResponse]: + if isinstance(self.__root__, list): + return self.__root__ + return [self.__root__] diff --git a/aidial_client/resources/__init__.py b/aidial_client/resources/__init__.py index 8e587ca..ed55b9a 100644 --- a/aidial_client/resources/__init__.py +++ b/aidial_client/resources/__init__.py @@ -1,3 +1,7 @@ +from aidial_client.resources.client_channel import ( + AsyncClientChannel, + ClientChannel, +) from aidial_client.resources.deployments import AsyncDeployments, Deployments from aidial_client.resources.metadata import AsyncMetadata, Metadata from aidial_client.resources.model import AsyncModel, Model @@ -34,4 +38,6 @@ "AsyncModel", "ResourcePermissions", "AsyncResourcePermissions", + "ClientChannel", + "AsyncClientChannel", ] diff --git a/aidial_client/resources/client_channel.py b/aidial_client/resources/client_channel.py new file mode 100644 index 0000000..9135262 --- /dev/null +++ b/aidial_client/resources/client_channel.py @@ -0,0 +1,218 @@ +from http import HTTPStatus +from typing import Any, List, Optional, Sequence, Union + +import httpx + +from aidial_client._compatibility.pydantic_v1 import ValidationError +from aidial_client._exception import ( + DialException, + InvalidRequestError, + ParsingDataError, +) +from aidial_client._http_client._sse import aiter_data_events, iter_data_events +from aidial_client._internal_types._defaults import NOT_GIVEN, NotGiven +from aidial_client._internal_types._json_rpc import ( + JsonRpcRequest, + JsonRpcResponse, + JsonRpcResponses, +) +from aidial_client.resources.base import AsyncResource, Resource +from aidial_client.types.client_channel import SigninResult + +_CLIENT_CHANNEL_HEADER = "X-DIAL-CLIENT-CHANNEL-ID" +_INTERACT_URL = "v1/ops/client-channel/interact" +_SIGNIN_METHOD = "toolset/signin" + + +def _normalize_toolset_ids(toolset_ids: Sequence[str]) -> List[str]: + """Validate ``toolset_ids`` and return a stable list. + + Catches three caller mistakes that would otherwise produce silent garbage: + a single string (str is itself a ``Sequence[str]``), a one-shot iterable + (consumed by the build step, leaving the mapping step with nothing), and + duplicate ids (the per-toolset result dict cannot represent two outcomes + for the same key). + """ + if isinstance(toolset_ids, str): + raise InvalidRequestError( + "toolset_ids must be a sequence of toolset ids, not a single str" + ) + materialized = list(toolset_ids) + if len(set(materialized)) != len(materialized): + raise InvalidRequestError("toolset_ids must not contain duplicates") + return materialized + + +def _serialize_requests(requests: Sequence[JsonRpcRequest]) -> Any: + """Serialize a sequence of JsonRpcRequest to the wire form. + + Always emits an array. DIAL Core accepts both an object and an array + body, but emitting a consistent shape avoids the "wire shape depends + on count" footgun and keeps the empty-input case safe. + """ + return [r.dict(exclude_none=True) for r in requests] + + +def _parse_responses(payload: str) -> List[JsonRpcResponse]: + try: + return JsonRpcResponses.parse_raw(payload).responses + except (ValidationError, ValueError) as err: + raise ParsingDataError( + message=( + "Invalid JSON-RPC response in client-channel interact: " + f"{err}" + ) + ) from err + + +def _no_data_error() -> DialException: + return DialException( + message="Client-channel interact stream closed without a data event", + status_code=HTTPStatus.GATEWAY_TIMEOUT, + ) + + +def _raise_if_batch_error(responses: Sequence[JsonRpcResponse]) -> None: + """Per JSON-RPC 2.0, a response with ``id=null`` indicates the server + could not associate the response with any request (parse error, invalid + batch, etc.). Surface that as a ``DialException`` instead of silently + mapping every toolset to ERROR. + """ + for r in responses: + if r.id is None and r.error is not None: + raise DialException( + message=( + f"Server-level JSON-RPC error " + f"({r.error.code}): {r.error.message}" + ), + status_code=HTTPStatus.BAD_GATEWAY, + ) + + +_RESULT_TO_OUTCOME = { + SigninResult.SUCCESS.value: SigninResult.SUCCESS, + SigninResult.DENIED.value: SigninResult.DENIED, +} + + +def _outcome_for(response: Optional[JsonRpcResponse]) -> SigninResult: + if response is None or response.error is not None: + return SigninResult.ERROR + if not isinstance(response.result, str): + return SigninResult.ERROR + return _RESULT_TO_OUTCOME.get(response.result, SigninResult.ERROR) + + +def _build_signin_requests( + toolset_ids: Sequence[str], +) -> List[JsonRpcRequest]: + return [ + JsonRpcRequest( + method=_SIGNIN_METHOD, + params={"toolsetId": tid}, + id=str(idx), + ) + for idx, tid in enumerate(toolset_ids, start=1) + ] + + +def _map_signin_results( + toolset_ids: Sequence[str], + responses: Sequence[JsonRpcResponse], +) -> "dict[str, SigninResult]": + by_id = {str(r.id): r for r in responses if r.id is not None} + return { + tid: _outcome_for(by_id.get(str(idx))) + for idx, tid in enumerate(toolset_ids, start=1) + } + + +class ClientChannel(Resource): + def signin_toolsets( + self, + *, + channel_id: str, + toolset_ids: Sequence[str], + timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN, + ) -> "dict[str, SigninResult]": + """Request interactive sign-in for one or more toolsets on the given + client channel and return the per-toolset outcome. + + ``toolset_ids`` are typically DIAL toolset ids (e.g. + ``"toolsets/public/my-toolset"``). The returned dict has one entry + per input id; toolsets for which the server does not produce a + response are mapped to :class:`SigninResult.ERROR`. Iteration order + of the returned dict matches the order of ``toolset_ids``. + + Raises :class:`InvalidRequestError` if ``toolset_ids`` is a plain + string or contains duplicates. Raises :class:`DialException` on HTTP + errors, transport failures, server-level JSON-RPC errors (e.g. parse + error returned with ``id=null``), or if the SSE stream closes + without a response event. + """ + ids = _normalize_toolset_ids(toolset_ids) + if not ids: + return {} + responses = self._interact( + channel_id=channel_id, + requests=_build_signin_requests(ids), + timeout=timeout, + ) + _raise_if_batch_error(responses) + return _map_signin_results(ids, responses) + + def _interact( + self, + *, + channel_id: str, + requests: Sequence[JsonRpcRequest], + timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN, + ) -> List[JsonRpcResponse]: + with self.http_client.stream_sse( + method="POST", + url=_INTERACT_URL, + json_data=_serialize_requests(requests), + headers={_CLIENT_CHANNEL_HEADER: channel_id}, + timeout=timeout, + ) as response: + for payload in iter_data_events(response.iter_lines()): + return _parse_responses(payload) + raise _no_data_error() + + +class AsyncClientChannel(AsyncResource): + async def signin_toolsets( + self, + *, + channel_id: str, + toolset_ids: Sequence[str], + timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN, + ) -> "dict[str, SigninResult]": + ids = _normalize_toolset_ids(toolset_ids) + if not ids: + return {} + responses = await self._interact( + channel_id=channel_id, + requests=_build_signin_requests(ids), + timeout=timeout, + ) + _raise_if_batch_error(responses) + return _map_signin_results(ids, responses) + + async def _interact( + self, + *, + channel_id: str, + requests: Sequence[JsonRpcRequest], + timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN, + ) -> List[JsonRpcResponse]: + async with self.http_client.stream_sse( + method="POST", + url=_INTERACT_URL, + json_data=_serialize_requests(requests), + headers={_CLIENT_CHANNEL_HEADER: channel_id}, + timeout=timeout, + ) as response: + async for payload in aiter_data_events(response.aiter_lines()): + return _parse_responses(payload) + raise _no_data_error() diff --git a/aidial_client/types/client_channel.py b/aidial_client/types/client_channel.py new file mode 100644 index 0000000..fa322db --- /dev/null +++ b/aidial_client/types/client_channel.py @@ -0,0 +1,9 @@ +from enum import Enum + + +class SigninResult(str, Enum): + """Outcome of an interactive sign-in request for a single toolset.""" + + SUCCESS = "success" + DENIED = "denied" + ERROR = "error" diff --git a/tests/resources/test_client_channel.py b/tests/resources/test_client_channel.py new file mode 100644 index 0000000..d85d911 --- /dev/null +++ b/tests/resources/test_client_channel.py @@ -0,0 +1,441 @@ +import json +import logging +from http import HTTPStatus +from typing import Any, List + +import httpx +import pytest + +from aidial_client import Dial, SigninResult +from aidial_client._client import AsyncDial +from aidial_client._exception import ( + DialException, + InvalidRequestError, + ParsingDataError, +) +from aidial_client._internal_types._json_rpc import JsonRpcRequest +from tests.client_mock import ( + MockStreamIterator, + get_async_client_mock, + get_client_mock, +) + + +def _sse_chunks(*lines: str) -> List[bytes]: + """Encode a sequence of SSE lines as one byte stream chunk.""" + return [("\n".join(lines) + "\n").encode()] + + +def _data(payload: Any) -> str: + return f"data: {json.dumps(payload)}" + + +def _single_event(payload: Any) -> List[bytes]: + return _sse_chunks(_data(payload), "") + + +def _signin_response(id_: str, result: str) -> dict: + return {"jsonrpc": "2.0", "id": id_, "result": result} + + +def _signin_error_response(id_: str, message: str = "boom") -> dict: + return { + "jsonrpc": "2.0", + "id": id_, + "error": {"code": -32000, "message": message}, + } + + +# ---------------------------------------------------------------------------- +# signin_toolsets — happy paths +# ---------------------------------------------------------------------------- + + +def test_signin_single_toolset_success_sync(): + client = get_client_mock( + status_code=200, + stream_chunks_mock=_single_event(_signin_response("1", "success")), + ) + out = client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["toolsets/public/a"] + ) + assert out == {"toolsets/public/a": SigninResult.SUCCESS} + + +@pytest.mark.asyncio +async def test_signin_single_toolset_success_async(): + client = get_async_client_mock( + status_code=200, + stream_chunks_mock=_single_event(_signin_response("1", "success")), + ) + out = await client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["toolsets/public/a"] + ) + assert out == {"toolsets/public/a": SigninResult.SUCCESS} + + +def test_signin_batch_mixed_outcomes_sync(): + payload = [ + _signin_response("1", "success"), + _signin_response("2", "denied"), + _signin_error_response("3"), + ] + client = get_client_mock( + status_code=200, stream_chunks_mock=_single_event(payload) + ) + out = client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a", "b", "c"] + ) + assert out == { + "a": SigninResult.SUCCESS, + "b": SigninResult.DENIED, + "c": SigninResult.ERROR, + } + + +@pytest.mark.asyncio +async def test_signin_batch_mixed_outcomes_async(): + payload = [ + _signin_response("1", "success"), + _signin_response("2", "denied"), + ] + client = get_async_client_mock( + status_code=200, stream_chunks_mock=_single_event(payload) + ) + out = await client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a", "b"] + ) + assert out == {"a": SigninResult.SUCCESS, "b": SigninResult.DENIED} + + +def test_signin_out_of_order_responses_matched_by_id_sync(): + # Server returns responses in arrival order, NOT request order. + payload = [ + _signin_response("2", "denied"), + _signin_response("1", "success"), + ] + client = get_client_mock( + status_code=200, stream_chunks_mock=_single_event(payload) + ) + out = client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a", "b"] + ) + assert out == {"a": SigninResult.SUCCESS, "b": SigninResult.DENIED} + + +def test_signin_missing_response_for_toolset_maps_to_error_sync(): + payload = [_signin_response("1", "success")] + client = get_client_mock( + status_code=200, stream_chunks_mock=_single_event(payload) + ) + out = client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a", "b"] + ) + assert out == {"a": SigninResult.SUCCESS, "b": SigninResult.ERROR} + + +def test_signin_unknown_result_string_maps_to_error_sync(): + payload = [{"jsonrpc": "2.0", "id": "1", "result": "weird-value"}] + client = get_client_mock( + status_code=200, stream_chunks_mock=_single_event(payload) + ) + out = client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a"] + ) + assert out == {"a": SigninResult.ERROR} + + +def test_signin_empty_toolset_list_returns_empty_dict_sync(): + client = get_client_mock(status_code=200, stream_chunks_mock=[b""]) + out = client.client_channel.signin_toolsets(channel_id="ch", toolset_ids=[]) + assert out == {} + + +def test_signin_rejects_single_string_as_toolset_ids_sync(): + # A plain str satisfies Sequence[str] at runtime; reject explicitly so + # it doesn't iterate the string and send one request per character. + client = get_client_mock(status_code=200, stream_chunks_mock=[b""]) + with pytest.raises(InvalidRequestError): + client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids="toolsets/public/x" + ) + + +def test_signin_rejects_duplicate_toolset_ids_sync(): + client = get_client_mock(status_code=200, stream_chunks_mock=[b""]) + with pytest.raises(InvalidRequestError): + client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a", "a"] + ) + + +def test_signin_accepts_iterator_as_toolset_ids_sync(): + # A one-shot iterable would silently produce {} without materialization; + # the wrapper must list() the input before using it twice. + payload = [ + _signin_response("1", "success"), + _signin_response("2", "denied"), + ] + client = get_client_mock( + status_code=200, stream_chunks_mock=_single_event(payload) + ) + out = client.client_channel.signin_toolsets( + channel_id="ch", + toolset_ids=iter(["a", "b"]), # type: ignore[arg-type] + ) + assert out == {"a": SigninResult.SUCCESS, "b": SigninResult.DENIED} + + +def test_signin_batch_level_error_raises_dial_exception_sync(): + # Server-level JSON-RPC error: id=null with an error object (e.g. parse + # error -32700). Must raise instead of silently mapping all toolsets + # to SigninResult.ERROR. + chunks = _single_event( + { + "jsonrpc": "2.0", + "id": None, + "error": {"code": -32700, "message": "Parse error"}, + } + ) + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with pytest.raises(DialException) as exc_info: + client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a"] + ) + assert "Parse error" in exc_info.value.message + assert "-32700" in exc_info.value.message + + +# ---------------------------------------------------------------------------- +# signin_toolsets — transport errors +# ---------------------------------------------------------------------------- + + +def test_signin_no_data_event_raises_sync(): + chunks = _sse_chunks(": heartbeat", "", ": heartbeat", "") + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with pytest.raises(DialException) as exc_info: + client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a"] + ) + assert exc_info.value.status_code == HTTPStatus.GATEWAY_TIMEOUT + + +def test_signin_http_401_raises_with_message_sync(): + body = json.dumps({"error": {"message": "Unauthorized"}}).encode() + client = get_client_mock(status_code=401, stream_chunks_mock=[body]) + with pytest.raises(DialException) as exc_info: + client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a"] + ) + assert exc_info.value.status_code == 401 + assert exc_info.value.message == "Unauthorized" + + +@pytest.mark.asyncio +async def test_signin_http_401_raises_async(): + body = json.dumps({"error": {"message": "Unauthorized"}}).encode() + client = get_async_client_mock(status_code=401, stream_chunks_mock=[body]) + with pytest.raises(DialException) as exc_info: + await client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a"] + ) + assert exc_info.value.status_code == 401 + assert exc_info.value.message == "Unauthorized" + + +def test_signin_unknown_transport_error_wrapped_sync(): + client = get_client_mock( + status_code=200, exception_mock=httpx.ConnectError("boom") + ) + with pytest.raises(DialException) as exc_info: + client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a"] + ) + assert "boom" in exc_info.value.message + assert "Request failed" in exc_info.value.message + + +def test_signin_timeout_wrapped_sync(): + client = get_client_mock( + status_code=200, exception_mock=httpx.ReadTimeout("slow") + ) + with pytest.raises(DialException) as exc_info: + client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a"] + ) + assert exc_info.value.status_code == HTTPStatus.REQUEST_TIMEOUT + + +# ---------------------------------------------------------------------------- +# Wire-format checks +# ---------------------------------------------------------------------------- + + +def test_signin_sends_channel_header_and_jsonrpc_body_sync(): + captured: dict = {} + + def send_mock(request: httpx.Request, **_kwargs): + captured["request"] = request + return httpx.Response( + status_code=200, + request=request, + stream=MockStreamIterator( + mock_chunks=_single_event(_signin_response("1", "success")) + ), + ) + + client = Dial(api_key="dummy", base_url="http://dial.core") + client._http_client._internal_http_client.send = send_mock + + client.client_channel.signin_toolsets( + channel_id="my-channel", toolset_ids=["toolsets/public/x"] + ) + + request = captured["request"] + assert request.headers["X-DIAL-CLIENT-CHANNEL-ID"] == "my-channel" + assert request.headers["api-key"] == "dummy" + assert request.url.path == "/v1/ops/client-channel/interact" + body = json.loads(request.content) + # Wire body is always an array, even for a single request. + assert body == [ + { + "jsonrpc": "2.0", + "method": "toolset/signin", + "params": {"toolsetId": "toolsets/public/x"}, + "id": "1", + } + ] + + +@pytest.mark.asyncio +async def test_signin_batch_body_serialized_as_array_async(): + captured: dict = {} + + async def send_mock(request: httpx.Request, **_kwargs): + captured["request"] = request + return httpx.Response( + status_code=200, + request=request, + stream=MockStreamIterator( + mock_chunks=_single_event( + [ + _signin_response("1", "success"), + _signin_response("2", "success"), + ] + ) + ), + ) + + client = AsyncDial(api_key="dummy", base_url="http://dial.core") + client._http_client._internal_http_client.send = send_mock + + await client.client_channel.signin_toolsets( + channel_id="ch", toolset_ids=["a", "b"] + ) + + body = json.loads(captured["request"].content) + assert isinstance(body, list) and len(body) == 2 + assert body[0]["params"] == {"toolsetId": "a"} + assert body[1]["params"] == {"toolsetId": "b"} + + +# ---------------------------------------------------------------------------- +# Internal _interact — protocol-level coverage +# ---------------------------------------------------------------------------- + + +def test_interact_result_null_is_valid_sync(): + # {result: null} is a successful response per JSON-RPC spec — must not + # raise ParsingDataError as the old code did. + chunks = _single_event({"jsonrpc": "2.0", "id": "1", "result": None}) + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + responses = client.client_channel._interact( + channel_id="ch", + requests=[JsonRpcRequest(jsonrpc="2.0", method="m", id="1")], + ) + assert len(responses) == 1 + assert responses[0].result is None + assert responses[0].error is None + + +def test_interact_response_missing_id_raises_parsing_error_sync(): + chunks = _single_event({"jsonrpc": "2.0", "result": "ok"}) + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with pytest.raises(ParsingDataError): + client.client_channel._interact( + channel_id="ch", + requests=[JsonRpcRequest(jsonrpc="2.0", method="m", id="1")], + ) + + +def test_interact_response_with_both_result_and_error_raises_sync(): + chunks = _single_event( + { + "jsonrpc": "2.0", + "id": "1", + "result": "ok", + "error": {"code": -1, "message": "x"}, + } + ) + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with pytest.raises(ParsingDataError): + client.client_channel._interact( + channel_id="ch", + requests=[JsonRpcRequest(jsonrpc="2.0", method="m", id="1")], + ) + + +def test_interact_response_with_neither_result_nor_error_raises_sync(): + chunks = _single_event({"jsonrpc": "2.0", "id": "1"}) + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with pytest.raises(ParsingDataError): + client.client_channel._interact( + channel_id="ch", + requests=[JsonRpcRequest(jsonrpc="2.0", method="m", id="1")], + ) + + +def test_interact_malformed_json_raises_sync(): + chunks = _sse_chunks("data: not-json", "") + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with pytest.raises(ParsingDataError): + client.client_channel._interact( + channel_id="ch", + requests=[JsonRpcRequest(jsonrpc="2.0", method="m", id="1")], + ) + + +def test_interact_heartbeats_skipped_sync(): + chunks = _sse_chunks( + ": heartbeat", + "", + ": heartbeat", + "", + _data(_signin_response("1", "success")), + "", + ) + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + responses = client.client_channel._interact( + channel_id="ch", + requests=[JsonRpcRequest(jsonrpc="2.0", method="m", id="1")], + ) + assert responses[0].result == "success" + + +def test_interact_truncated_stream_warns_and_no_phantom_event_sync(caplog): + # No trailing blank line — incomplete event must NOT be flushed, and a + # warning must be emitted by the SSE parser. + chunks = _sse_chunks('data: {"jsonrpc":"2.0","resu') + client = get_client_mock(status_code=200, stream_chunks_mock=chunks) + with caplog.at_level(logging.WARNING, logger="aidial_client"): + with pytest.raises(DialException) as exc_info: + client.client_channel._interact( + channel_id="ch", + requests=[JsonRpcRequest(jsonrpc="2.0", method="m", id="1")], + ) + assert exc_info.value.status_code == HTTPStatus.GATEWAY_TIMEOUT + assert any( + "Uncommitted data chunks in SSE stream" in rec.message + for rec in caplog.records + )