Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
51 changes: 51 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,8 @@
- [Get Toolset by Id](#get-toolset-by-id)
- [Resource Permissions](#resource-permissions)
- [Grant Permissions](#grant-permissions)
- [Client Channel](#client-channel)
- [Sign In to Toolsets](#sign-in-to-toolsets)
- [Client Pool](#client-pool)
- [Synchronous Client Pool](#synchronous-client-pool)
- [Asynchronous Client Pool](#asynchronous-client-pool)
Expand Down Expand Up @@ -854,6 +856,55 @@ await async_client.resource_permissions.grant(

The method returns `None` on success and raises `DialException` on HTTP error.

### Client Channel

DIAL Core's [client channel API](https://dialx.ai/universal_chat_api.yaml) lets a deployment ask an interactive client (e.g. the chat UI) to take some action and report the result back. The channel id is propagated to the deployment via the `X-DIAL-CLIENT-CHANNEL-ID` forwarded header on the inbound request.

#### Sign In to Toolsets

Use `client_channel.signin_toolsets()` to request interactive sign-in for one or more toolsets on the active client channel. The method returns a `dict[str, SigninResult]` mapping each input toolset id to its outcome — responses are correlated by the client, so the caller never has to deal with the underlying JSON-RPC ids.

```python
from aidial_client import SigninResult

# Sync
results = client.client_channel.signin_toolsets(
channel_id="<channel-id-from-X-DIAL-CLIENT-CHANNEL-ID>",
toolset_ids=[
"toolsets/public/toolset-a",
"toolsets/public/toolset-b",
],
timeout=120.0,
)

# Async
results = await async_client.client_channel.signin_toolsets(
channel_id="<channel-id>",
toolset_ids=["toolsets/public/my-toolset"],
)
```

Each value is a `SigninResult` enum:

```python
{
"toolsets/public/toolset-a": SigninResult.SUCCESS,
"toolsets/public/toolset-b": SigninResult.DENIED,
}
```

- `SigninResult.SUCCESS` — the user signed in.
- `SigninResult.DENIED` — the user declined.
- `SigninResult.ERROR` — the server returned a JSON-RPC error, or the response was missing/unrecognized.

Arguments:

- `channel_id` — required; the channel id received via the `X-DIAL-CLIENT-CHANNEL-ID` header on the inbound request.
- `toolset_ids` — sequence of toolset ids to request sign-in for; an empty sequence returns `{}` without contacting the server.
- `timeout` — optional `float` seconds or `httpx.Timeout`; defaults to the client-wide timeout. Useful for interactive flows where the user may take a while to respond.

Raises `DialException` on HTTP errors (e.g. unauthorized, missing channel), transport failures (timeouts, network errors), or if the SSE stream closes without a response event.

### Client Pool

When you need to create multiple DIAL clients and wish to enhance performance by reusing the HTTP connection for the same DIAL instance, consider using synchronous and asynchronous **client pools**.
Expand Down
2 changes: 2 additions & 0 deletions aidial_client/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
ParsingDataError,
ResourceNotFoundError,
)
from aidial_client.types.client_channel import SigninResult
from aidial_client.types.model import ModelInfo, ModelLimits, ModelPricing
from aidial_client.types.toolset import ToolsetInfo

Expand All @@ -30,4 +31,5 @@
"ModelInfo",
"ModelPricing",
"ModelLimits",
"SigninResult",
]
6 changes: 6 additions & 0 deletions aidial_client/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,6 +119,9 @@ def _init_resources(self) -> None:
self.resource_permissions = resources.ResourcePermissions(
http_client=self._http_client
)
self.client_channel = resources.ClientChannel(
http_client=self._http_client
)

def _create_http_client(self) -> SyncHTTPClient:
return SyncHTTPClient(
Expand Down Expand Up @@ -207,6 +210,9 @@ def _init_resources(self) -> None:
self.resource_permissions = resources.AsyncResourcePermissions(
http_client=self._http_client
)
self.client_channel = resources.AsyncClientChannel(
http_client=self._http_client
)

def _create_http_client(self) -> AsyncHTTPClient:
return AsyncHTTPClient(
Expand Down
64 changes: 63 additions & 1 deletion aidial_client/_http_client/_async.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,23 @@
import asyncio
from contextlib import asynccontextmanager
from http import HTTPStatus
from typing import Callable, Dict, Optional, Type
from typing import (
Any,
AsyncIterator,
Callable,
Dict,
Mapping,
Optional,
Type,
Union,
)

import httpx

from aidial_client._auth import AsyncAuthValue, aget_combined_auth_headers
from aidial_client._exception import DialException
from aidial_client._http_client._base import BaseHTTPClient
from aidial_client._internal_types._defaults import NOT_GIVEN, NotGiven
from aidial_client._internal_types._generic import ResponseT
from aidial_client._internal_types._http_request import FinalRequestOptions
from aidial_client._log import logger
Expand Down Expand Up @@ -108,3 +119,54 @@ async def request(
raise raised_error from err

return process_block_response(cast_to=cast_to, response=response)

@asynccontextmanager
async def stream_sse(
self,
*,
method: str,
url: str,
json_data: Any,
headers: Optional[Mapping[str, str]] = None,
timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN,
) -> AsyncIterator[httpx.Response]:
"""Open an SSE streaming response. Yields the open httpx.Response.

Auth headers are merged in. On non-2xx, reads the body and raises
a DialException; transport errors (timeouts, network failures) are
also wrapped so the caller always sees DialException. Retries are
not performed for streaming requests.

``timeout`` defaults to the client-wide timeout; pass an explicit
``None`` (or ``httpx.Timeout(None)``) for no timeout.
"""
merged_headers = {**(await self.auth_headers()), **(headers or {})}
effective_timeout = (
self._timeout if isinstance(timeout, NotGiven) else timeout
)
try:
async with self._internal_http_client.stream(
method=method,
url=self._prepare_url(url),
headers=merged_headers,
json=json_data,
timeout=effective_timeout,
) as response:
try:
response.raise_for_status()
except httpx.HTTPStatusError as err:
try:
await response.aread()
except httpx.HTTPError:
pass
raise self._make_dial_error_from_response(
err.response
) from err
yield response
except httpx.TimeoutException as err:
Comment thread
adubovik marked this conversation as resolved.
raise DialException(
message="Request timed out",
status_code=HTTPStatus.REQUEST_TIMEOUT,
) from err
except httpx.HTTPError as err:
raise DialException(message=f"Request failed: {err}") from err
47 changes: 47 additions & 0 deletions aidial_client/_http_client/_sse.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
from typing import AsyncIterator, Iterator, List

from aidial_client._log import logger

_UNCOMMITTED_BUFFER_WARNING = (
"Uncommitted data chunks in SSE stream "
"(stream ended without a terminating blank line); discarding."
)


def _strip_field(line: str, prefix: str) -> str:
"""Strip a single leading U+0020 SPACE after the field colon, per the SSE spec."""
value = line[len(prefix) :]
return value[1:] if value.startswith(" ") else value


def iter_data_events(lines: Iterator[str]) -> Iterator[str]:
"""Yield the payload of each complete ``data:`` event from an SSE line stream.

An event is complete when a blank line follows the ``data:`` line(s). Per
the SSE dispatch rule, a buffer that has not been terminated by a blank
line is discarded (we do NOT flush partial events at end of stream).
Comment lines (``:``) and other field names are ignored.
"""
buffer: List[str] = []
for line in lines:
if line == "":
if buffer:
yield "\n".join(buffer)
buffer = []
elif line.startswith("data:"):
buffer.append(_strip_field(line, "data:"))
Comment thread
adubovik marked this conversation as resolved.
if buffer:
logger.warning(_UNCOMMITTED_BUFFER_WARNING)


async def aiter_data_events(lines: AsyncIterator[str]) -> AsyncIterator[str]:
buffer: List[str] = []
async for line in lines:
if line == "":
if buffer:
yield "\n".join(buffer)
buffer = []
elif line.startswith("data:"):
buffer.append(_strip_field(line, "data:"))
if buffer:
logger.warning(_UNCOMMITTED_BUFFER_WARNING)
55 changes: 54 additions & 1 deletion aidial_client/_http_client/_sync.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,14 @@
import time
from contextlib import contextmanager
from http import HTTPStatus
from typing import Callable, Dict, Optional, Type
from typing import Any, Callable, Dict, Iterator, Mapping, Optional, Type, Union

import httpx

from aidial_client._auth import SyncAuthValue, get_combined_auth_headers
from aidial_client._exception import DialException
from aidial_client._http_client._base import BaseHTTPClient
from aidial_client._internal_types._defaults import NOT_GIVEN, NotGiven
from aidial_client._internal_types._generic import ResponseT
from aidial_client._internal_types._http_request import FinalRequestOptions
from aidial_client._log import logger
Expand Down Expand Up @@ -108,3 +110,54 @@ def request(
raise raised_error from err

return process_block_response(cast_to=cast_to, response=response)

@contextmanager
def stream_sse(
self,
*,
method: str,
url: str,
json_data: Any,
headers: Optional[Mapping[str, str]] = None,
timeout: Union[float, httpx.Timeout, None, NotGiven] = NOT_GIVEN,
) -> Iterator[httpx.Response]:
"""Open an SSE streaming response. Yields the open httpx.Response.

Auth headers are merged in. On non-2xx, reads the body and raises
a DialException; transport errors (timeouts, network failures) are
also wrapped so the caller always sees DialException. Retries are
not performed for streaming requests.

``timeout`` defaults to the client-wide timeout; pass an explicit
``None`` (or ``httpx.Timeout(None)``) for no timeout.
"""
merged_headers = {**self.auth_headers(), **(headers or {})}
effective_timeout = (
self._timeout if isinstance(timeout, NotGiven) else timeout
)
try:
with self._internal_http_client.stream(
method=method,
url=self._prepare_url(url),
headers=merged_headers,
json=json_data,
timeout=effective_timeout,
) as response:
try:
response.raise_for_status()
except httpx.HTTPStatusError as err:
try:
response.read()
except httpx.HTTPError:
pass
raise self._make_dial_error_from_response(
err.response
) from err
yield response
except httpx.TimeoutException as err:
raise DialException(
message="Request timed out",
status_code=HTTPStatus.REQUEST_TIMEOUT,
) from err
except httpx.HTTPError as err:
raise DialException(message=f"Request failed: {err}") from err
75 changes: 75 additions & 0 deletions aidial_client/_internal_types/_json_rpc.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,75 @@
from typing import Any, Dict, List, Literal, Optional, Union

from aidial_client._compatibility.pydantic_v1 import (
BaseModel,
Extra,
Field,
root_validator,
)


class JsonRpcError(BaseModel):
code: int
message: str
data: Optional[Any] = None

class Config:
extra = Extra.allow


class JsonRpcRequest(BaseModel):
jsonrpc: Literal["2.0"] = "2.0"
method: str
params: Optional[Union[List[Any], Dict[str, Any]]] = None
id: Optional[Union[int, str]] = None

class Config:
smart_union = True


class JsonRpcResponse(BaseModel):
jsonrpc: Literal["2.0"]
result: Optional[Any] = None
error: Optional[JsonRpcError] = None
id: Optional[Union[int, str]] = Field(...)

class Config:
smart_union = True
extra = Extra.allow

@root_validator(pre=True)
def _validate_result_xor_error(cls, values):
"""Per JSON-RPC 2.0 (https://www.jsonrpc.org/specification#response_object),
either ``result`` or ``error`` MUST be included (presence-wise — ``null``
is a valid result value), and both MUST NOT be included.
"""
if not isinstance(values, dict):
return values
has_result = "result" in values
has_error = "error" in values
if has_result and has_error:
raise ValueError(
"JSON-RPC response must not contain both 'result' and 'error'"
)
if not has_result and not has_error:
raise ValueError(
"JSON-RPC response must contain either 'result' or 'error'"
)
return values


class JsonRpcResponses(BaseModel):
"""Pydantic root model that accepts a single JSON-RPC response object or
a batch array, normalizing both to a list via the ``responses`` property.
"""

__root__: Union[JsonRpcResponse, List[JsonRpcResponse]]

class Config:
smart_union = True

@property
def responses(self) -> List[JsonRpcResponse]:
if isinstance(self.__root__, list):
return self.__root__
return [self.__root__]
6 changes: 6 additions & 0 deletions aidial_client/resources/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from aidial_client.resources.client_channel import (
AsyncClientChannel,
ClientChannel,
)
from aidial_client.resources.deployments import AsyncDeployments, Deployments
from aidial_client.resources.metadata import AsyncMetadata, Metadata
from aidial_client.resources.model import AsyncModel, Model
Expand Down Expand Up @@ -34,4 +38,6 @@
"AsyncModel",
"ResourcePermissions",
"AsyncResourcePermissions",
"ClientChannel",
"AsyncClientChannel",
]
Loading
Loading