diff --git a/.release-please-manifest.json b/.release-please-manifest.json index b5db7ce1..c373724d 100644 --- a/.release-please-manifest.json +++ b/.release-please-manifest.json @@ -1,3 +1,3 @@ { - ".": "0.1.0-alpha.7" + ".": "0.1.0-alpha.8" } \ No newline at end of file diff --git a/.stats.yml b/.stats.yml index 79a36ab0..8eb4144d 100644 --- a/.stats.yml +++ b/.stats.yml @@ -1,4 +1,4 @@ configured_endpoints: 76 openapi_spec_url: https://storage.googleapis.com/stainless-sdk-openapi-specs/digitalocean%2Fgradientai-e8b3cbc80e18e4f7f277010349f25e1319156704f359911dc464cc21a0d077a6.yml openapi_spec_hash: c773d792724f5647ae25a5ae4ccec208 -config_hash: f0976fbc552ea878bb527447b5e663c9 +config_hash: e1b3d85ba9ae21d729a914c789422ba7 diff --git a/CHANGELOG.md b/CHANGELOG.md index 15fec91a..b26d4058 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,13 @@ # Changelog +## 0.1.0-alpha.8 (2025-06-27) + +Full Changelog: [v0.1.0-alpha.7...v0.1.0-alpha.8](https://github.com/digitalocean/gradientai-python/compare/v0.1.0-alpha.7...v0.1.0-alpha.8) + +### Features + +* **client:** setup streaming ([3fd6e57](https://github.com/digitalocean/gradientai-python/commit/3fd6e575f6f5952860e42d8c1fa22ccb0b10c623)) + ## 0.1.0-alpha.7 (2025-06-27) Full Changelog: [v0.1.0-alpha.6...v0.1.0-alpha.7](https://github.com/digitalocean/gradientai-python/compare/v0.1.0-alpha.6...v0.1.0-alpha.7) diff --git a/api.md b/api.md index dc48f7b3..b1ac8b43 100644 --- a/api.md +++ b/api.md @@ -65,7 +65,7 @@ Methods: Types: ```python -from gradientai.types.agents.chat import CompletionCreateResponse +from gradientai.types.agents.chat import ChatCompletionChunk, CompletionCreateResponse ``` Methods: diff --git a/pyproject.toml b/pyproject.toml index 29531941..60a58f89 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "c63a5cfe-b235-4fbe-8bbb-82a9e02a482a-python" -version = "0.1.0-alpha.7" +version = "0.1.0-alpha.8" description = "The official Python library for GradientAI" dynamic = ["readme"] license = "Apache-2.0" diff --git a/src/gradientai/_client.py b/src/gradientai/_client.py index 327273c9..939d8c6f 100644 --- a/src/gradientai/_client.py +++ b/src/gradientai/_client.py @@ -117,6 +117,8 @@ def __init__( _strict_response_validation=_strict_response_validation, ) + self._default_stream_cls = Stream + @cached_property def agents(self) -> AgentsResource: from .resources.agents import AgentsResource @@ -355,6 +357,8 @@ def __init__( _strict_response_validation=_strict_response_validation, ) + self._default_stream_cls = AsyncStream + @cached_property def agents(self) -> AsyncAgentsResource: from .resources.agents import AsyncAgentsResource diff --git a/src/gradientai/_streaming.py b/src/gradientai/_streaming.py index bab5eb80..69a805ad 100644 --- a/src/gradientai/_streaming.py +++ b/src/gradientai/_streaming.py @@ -9,7 +9,8 @@ import httpx -from ._utils import extract_type_var_from_base +from ._utils import is_mapping, extract_type_var_from_base +from ._exceptions import APIError if TYPE_CHECKING: from ._client import GradientAI, AsyncGradientAI @@ -55,7 +56,25 @@ def __stream__(self) -> Iterator[_T]: iterator = self._iter_events() for sse in iterator: - yield process_data(data=sse.json(), cast_to=cast_to, response=response) + if sse.data.startswith("[DONE]"): + break + + data = sse.json() + if is_mapping(data) and data.get("error"): + message = None + error = data.get("error") + if is_mapping(error): + message = error.get("message") + if not message or not isinstance(message, str): + message = "An error 
occurred during streaming" + + raise APIError( + message=message, + request=self.response.request, + body=data["error"], + ) + + yield process_data(data=data, cast_to=cast_to, response=response) # Ensure the entire stream is consumed for _sse in iterator: @@ -119,7 +138,25 @@ async def __stream__(self) -> AsyncIterator[_T]: iterator = self._iter_events() async for sse in iterator: - yield process_data(data=sse.json(), cast_to=cast_to, response=response) + if sse.data.startswith("[DONE]"): + break + + data = sse.json() + if is_mapping(data) and data.get("error"): + message = None + error = data.get("error") + if is_mapping(error): + message = error.get("message") + if not message or not isinstance(message, str): + message = "An error occurred during streaming" + + raise APIError( + message=message, + request=self.response.request, + body=data["error"], + ) + + yield process_data(data=data, cast_to=cast_to, response=response) # Ensure the entire stream is consumed async for _sse in iterator: diff --git a/src/gradientai/_version.py b/src/gradientai/_version.py index d4e6dde6..8c8f2b63 100644 --- a/src/gradientai/_version.py +++ b/src/gradientai/_version.py @@ -1,4 +1,4 @@ # File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. __title__ = "gradientai" -__version__ = "0.1.0-alpha.7" # x-release-please-version +__version__ = "0.1.0-alpha.8" # x-release-please-version diff --git a/src/gradientai/resources/agents/chat/completions.py b/src/gradientai/resources/agents/chat/completions.py index a213bf05..92431cdf 100644 --- a/src/gradientai/resources/agents/chat/completions.py +++ b/src/gradientai/resources/agents/chat/completions.py @@ -3,11 +3,12 @@ from __future__ import annotations from typing import Dict, List, Union, Iterable, Optional +from typing_extensions import Literal, overload import httpx from ...._types import NOT_GIVEN, Body, Query, Headers, NotGiven -from ...._utils import maybe_transform, async_maybe_transform +from ...._utils import required_args, maybe_transform, async_maybe_transform from ...._compat import cached_property from ...._resource import SyncAPIResource, AsyncAPIResource from ...._response import ( @@ -16,8 +17,10 @@ async_to_raw_response_wrapper, async_to_streamed_response_wrapper, ) +from ...._streaming import Stream, AsyncStream from ...._base_client import make_request_options from ....types.agents.chat import completion_create_params +from ....types.agents.chat.chat_completion_chunk import ChatCompletionChunk from ....types.agents.chat.completion_create_response import CompletionCreateResponse __all__ = ["CompletionsResource", "AsyncCompletionsResource"] @@ -43,6 +46,7 @@ def with_streaming_response(self) -> CompletionsResourceWithStreamingResponse: """ return CompletionsResourceWithStreamingResponse(self) + @overload def create( self, *, @@ -57,7 +61,7 @@ def create( n: Optional[int] | NotGiven = NOT_GIVEN, presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: Optional[bool] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN, temperature: Optional[float] | NotGiven = NOT_GIVEN, top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, @@ -153,6 +157,262 @@ def create( timeout: Override the client-level default timeout for this request, in seconds """ + ... 
+ + @overload + def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + stream: Literal[True], + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + max_completion_tokens: Optional[int] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + metadata: Optional[Dict[str, str]] | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> Stream[ChatCompletionChunk]: + """ + Creates a model response for the given chat conversation. + + Args: + messages: A list of messages comprising the conversation so far. + + model: Model ID used to generate the response. + + stream: If set to true, the model response data will be streamed to the client as it is + generated using server-sent events. + + frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their + existing frequency in the text so far, decreasing the model's likelihood to + repeat the same line verbatim. + + logit_bias: Modify the likelihood of specified tokens appearing in the completion. + + Accepts a JSON object that maps tokens (specified by their token ID in the + tokenizer) to an associated bias value from -100 to 100. Mathematically, the + bias is added to the logits generated by the model prior to sampling. The exact + effect will vary per model, but values between -1 and 1 should decrease or + increase likelihood of selection; values like -100 or 100 should result in a ban + or exclusive selection of the relevant token. + + logprobs: Whether to return log probabilities of the output tokens or not. If true, + returns the log probabilities of each output token returned in the `content` of + `message`. + + max_completion_tokens: The maximum number of completion tokens that may be used over the course of the + run. The run will make a best effort to use only the number of completion tokens + specified, across multiple turns of the run. + + max_tokens: The maximum number of tokens that can be generated in the completion. + + The token count of your prompt plus `max_tokens` cannot exceed the model's + context length. + + metadata: Set of 16 key-value pairs that can be attached to an object. This can be useful + for storing additional information about the object in a structured format, and + querying for objects via API or the dashboard. + + Keys are strings with a maximum length of 64 characters. Values are strings with + a maximum length of 512 characters. + + n: How many chat completion choices to generate for each input message. 
Note that + you will be charged based on the number of generated tokens across all of the + choices. Keep `n` as `1` to minimize costs. + + presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on + whether they appear in the text so far, increasing the model's likelihood to + talk about new topics. + + stop: Up to 4 sequences where the API will stop generating further tokens. The + returned text will not contain the stop sequence. + + stream_options: Options for streaming response. Only set this when you set `stream: true`. + + temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will + make the output more random, while lower values like 0.2 will make it more + focused and deterministic. We generally recommend altering this or `top_p` but + not both. + + top_logprobs: An integer between 0 and 20 specifying the number of most likely tokens to + return at each token position, each with an associated log probability. + `logprobs` must be set to `true` if this parameter is used. + + top_p: An alternative to sampling with temperature, called nucleus sampling, where the + model considers the results of the tokens with top_p probability mass. So 0.1 + means only the tokens comprising the top 10% probability mass are considered. + + We generally recommend altering this or `temperature` but not both. + + user: A unique identifier representing your end-user, which can help DigitalOcean to + monitor and detect abuse. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + stream: bool, + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + max_completion_tokens: Optional[int] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + metadata: Optional[Dict[str, str]] | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> CompletionCreateResponse | Stream[ChatCompletionChunk]: + """ + Creates a model response for the given chat conversation. + + Args: + messages: A list of messages comprising the conversation so far. + + model: Model ID used to generate the response. + + stream: If set to true, the model response data will be streamed to the client as it is + generated using server-sent events. + + frequency_penalty: Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their + existing frequency in the text so far, decreasing the model's likelihood to + repeat the same line verbatim. + + logit_bias: Modify the likelihood of specified tokens appearing in the completion. + + Accepts a JSON object that maps tokens (specified by their token ID in the + tokenizer) to an associated bias value from -100 to 100. Mathematically, the + bias is added to the logits generated by the model prior to sampling. The exact + effect will vary per model, but values between -1 and 1 should decrease or + increase likelihood of selection; values like -100 or 100 should result in a ban + or exclusive selection of the relevant token. + + logprobs: Whether to return log probabilities of the output tokens or not. If true, + returns the log probabilities of each output token returned in the `content` of + `message`. + + max_completion_tokens: The maximum number of completion tokens that may be used over the course of the + run. The run will make a best effort to use only the number of completion tokens + specified, across multiple turns of the run. + + max_tokens: The maximum number of tokens that can be generated in the completion. + + The token count of your prompt plus `max_tokens` cannot exceed the model's + context length. + + metadata: Set of 16 key-value pairs that can be attached to an object. This can be useful + for storing additional information about the object in a structured format, and + querying for objects via API or the dashboard. + + Keys are strings with a maximum length of 64 characters. Values are strings with + a maximum length of 512 characters. + + n: How many chat completion choices to generate for each input message. Note that + you will be charged based on the number of generated tokens across all of the + choices. Keep `n` as `1` to minimize costs. + + presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on + whether they appear in the text so far, increasing the model's likelihood to + talk about new topics. + + stop: Up to 4 sequences where the API will stop generating further tokens. The + returned text will not contain the stop sequence. + + stream_options: Options for streaming response. Only set this when you set `stream: true`. + + temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will + make the output more random, while lower values like 0.2 will make it more + focused and deterministic. We generally recommend altering this or `top_p` but + not both. + + top_logprobs: An integer between 0 and 20 specifying the number of most likely tokens to + return at each token position, each with an associated log probability. + `logprobs` must be set to `true` if this parameter is used. + + top_p: An alternative to sampling with temperature, called nucleus sampling, where the + model considers the results of the tokens with top_p probability mass. So 0.1 + means only the tokens comprising the top 10% probability mass are considered. + + We generally recommend altering this or `temperature` but not both. + + user: A unique identifier representing your end-user, which can help DigitalOcean to + monitor and detect abuse. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... 
+ + @required_args(["messages", "model"], ["messages", "model", "stream"]) + def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + max_completion_tokens: Optional[int] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + metadata: Optional[Dict[str, str]] | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> CompletionCreateResponse | Stream[ChatCompletionChunk]: return self._post( "/chat/completions" if self._client._base_url_overridden @@ -177,12 +437,16 @@ def create( "top_p": top_p, "user": user, }, - completion_create_params.CompletionCreateParams, + completion_create_params.CompletionCreateParamsStreaming + if stream + else completion_create_params.CompletionCreateParamsNonStreaming, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=CompletionCreateResponse, + stream=stream or False, + stream_cls=Stream[ChatCompletionChunk], ) @@ -206,6 +470,7 @@ def with_streaming_response(self) -> AsyncCompletionsResourceWithStreamingRespon """ return AsyncCompletionsResourceWithStreamingResponse(self) + @overload async def create( self, *, @@ -220,7 +485,7 @@ async def create( n: Optional[int] | NotGiven = NOT_GIVEN, presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, - stream: Optional[bool] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | NotGiven = NOT_GIVEN, stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN, temperature: Optional[float] | NotGiven = NOT_GIVEN, top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, @@ -316,6 +581,262 @@ async def create( timeout: Override the client-level default timeout for this request, in seconds """ + ... 
+ + @overload + async def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + stream: Literal[True], + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + max_completion_tokens: Optional[int] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + metadata: Optional[Dict[str, str]] | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> AsyncStream[ChatCompletionChunk]: + """ + Creates a model response for the given chat conversation. + + Args: + messages: A list of messages comprising the conversation so far. + + model: Model ID used to generate the response. + + stream: If set to true, the model response data will be streamed to the client as it is + generated using server-sent events. + + frequency_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on their + existing frequency in the text so far, decreasing the model's likelihood to + repeat the same line verbatim. + + logit_bias: Modify the likelihood of specified tokens appearing in the completion. + + Accepts a JSON object that maps tokens (specified by their token ID in the + tokenizer) to an associated bias value from -100 to 100. Mathematically, the + bias is added to the logits generated by the model prior to sampling. The exact + effect will vary per model, but values between -1 and 1 should decrease or + increase likelihood of selection; values like -100 or 100 should result in a ban + or exclusive selection of the relevant token. + + logprobs: Whether to return log probabilities of the output tokens or not. If true, + returns the log probabilities of each output token returned in the `content` of + `message`. + + max_completion_tokens: The maximum number of completion tokens that may be used over the course of the + run. The run will make a best effort to use only the number of completion tokens + specified, across multiple turns of the run. + + max_tokens: The maximum number of tokens that can be generated in the completion. + + The token count of your prompt plus `max_tokens` cannot exceed the model's + context length. + + metadata: Set of 16 key-value pairs that can be attached to an object. This can be useful + for storing additional information about the object in a structured format, and + querying for objects via API or the dashboard. + + Keys are strings with a maximum length of 64 characters. Values are strings with + a maximum length of 512 characters. + + n: How many chat completion choices to generate for each input message. 
Note that + you will be charged based on the number of generated tokens across all of the + choices. Keep `n` as `1` to minimize costs. + + presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on + whether they appear in the text so far, increasing the model's likelihood to + talk about new topics. + + stop: Up to 4 sequences where the API will stop generating further tokens. The + returned text will not contain the stop sequence. + + stream_options: Options for streaming response. Only set this when you set `stream: true`. + + temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will + make the output more random, while lower values like 0.2 will make it more + focused and deterministic. We generally recommend altering this or `top_p` but + not both. + + top_logprobs: An integer between 0 and 20 specifying the number of most likely tokens to + return at each token position, each with an associated log probability. + `logprobs` must be set to `true` if this parameter is used. + + top_p: An alternative to sampling with temperature, called nucleus sampling, where the + model considers the results of the tokens with top_p probability mass. So 0.1 + means only the tokens comprising the top 10% probability mass are considered. + + We generally recommend altering this or `temperature` but not both. + + user: A unique identifier representing your end-user, which can help DigitalOcean to + monitor and detect abuse. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... + + @overload + async def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + stream: bool, + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + max_completion_tokens: Optional[int] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + metadata: Optional[Dict[str, str]] | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> CompletionCreateResponse | AsyncStream[ChatCompletionChunk]: + """ + Creates a model response for the given chat conversation. + + Args: + messages: A list of messages comprising the conversation so far. + + model: Model ID used to generate the response. + + stream: If set to true, the model response data will be streamed to the client as it is + generated using server-sent events. + + frequency_penalty: Number between -2.0 and 2.0. 
Positive values penalize new tokens based on their + existing frequency in the text so far, decreasing the model's likelihood to + repeat the same line verbatim. + + logit_bias: Modify the likelihood of specified tokens appearing in the completion. + + Accepts a JSON object that maps tokens (specified by their token ID in the + tokenizer) to an associated bias value from -100 to 100. Mathematically, the + bias is added to the logits generated by the model prior to sampling. The exact + effect will vary per model, but values between -1 and 1 should decrease or + increase likelihood of selection; values like -100 or 100 should result in a ban + or exclusive selection of the relevant token. + + logprobs: Whether to return log probabilities of the output tokens or not. If true, + returns the log probabilities of each output token returned in the `content` of + `message`. + + max_completion_tokens: The maximum number of completion tokens that may be used over the course of the + run. The run will make a best effort to use only the number of completion tokens + specified, across multiple turns of the run. + + max_tokens: The maximum number of tokens that can be generated in the completion. + + The token count of your prompt plus `max_tokens` cannot exceed the model's + context length. + + metadata: Set of 16 key-value pairs that can be attached to an object. This can be useful + for storing additional information about the object in a structured format, and + querying for objects via API or the dashboard. + + Keys are strings with a maximum length of 64 characters. Values are strings with + a maximum length of 512 characters. + + n: How many chat completion choices to generate for each input message. Note that + you will be charged based on the number of generated tokens across all of the + choices. Keep `n` as `1` to minimize costs. + + presence_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens based on + whether they appear in the text so far, increasing the model's likelihood to + talk about new topics. + + stop: Up to 4 sequences where the API will stop generating further tokens. The + returned text will not contain the stop sequence. + + stream_options: Options for streaming response. Only set this when you set `stream: true`. + + temperature: What sampling temperature to use, between 0 and 2. Higher values like 0.8 will + make the output more random, while lower values like 0.2 will make it more + focused and deterministic. We generally recommend altering this or `top_p` but + not both. + + top_logprobs: An integer between 0 and 20 specifying the number of most likely tokens to + return at each token position, each with an associated log probability. + `logprobs` must be set to `true` if this parameter is used. + + top_p: An alternative to sampling with temperature, called nucleus sampling, where the + model considers the results of the tokens with top_p probability mass. So 0.1 + means only the tokens comprising the top 10% probability mass are considered. + + We generally recommend altering this or `temperature` but not both. + + user: A unique identifier representing your end-user, which can help DigitalOcean to + monitor and detect abuse. + + extra_headers: Send extra headers + + extra_query: Add additional query parameters to the request + + extra_body: Add additional JSON properties to the request + + timeout: Override the client-level default timeout for this request, in seconds + """ + ... 
+ + @required_args(["messages", "model"], ["messages", "model", "stream"]) + async def create( + self, + *, + messages: Iterable[completion_create_params.Message], + model: str, + frequency_penalty: Optional[float] | NotGiven = NOT_GIVEN, + logit_bias: Optional[Dict[str, int]] | NotGiven = NOT_GIVEN, + logprobs: Optional[bool] | NotGiven = NOT_GIVEN, + max_completion_tokens: Optional[int] | NotGiven = NOT_GIVEN, + max_tokens: Optional[int] | NotGiven = NOT_GIVEN, + metadata: Optional[Dict[str, str]] | NotGiven = NOT_GIVEN, + n: Optional[int] | NotGiven = NOT_GIVEN, + presence_penalty: Optional[float] | NotGiven = NOT_GIVEN, + stop: Union[Optional[str], List[str], None] | NotGiven = NOT_GIVEN, + stream: Optional[Literal[False]] | Literal[True] | NotGiven = NOT_GIVEN, + stream_options: Optional[completion_create_params.StreamOptions] | NotGiven = NOT_GIVEN, + temperature: Optional[float] | NotGiven = NOT_GIVEN, + top_logprobs: Optional[int] | NotGiven = NOT_GIVEN, + top_p: Optional[float] | NotGiven = NOT_GIVEN, + user: str | NotGiven = NOT_GIVEN, + # Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs. + # The extra values given here take precedence over values defined on the client or passed to this method. + extra_headers: Headers | None = None, + extra_query: Query | None = None, + extra_body: Body | None = None, + timeout: float | httpx.Timeout | None | NotGiven = NOT_GIVEN, + ) -> CompletionCreateResponse | AsyncStream[ChatCompletionChunk]: return await self._post( "/chat/completions" if self._client._base_url_overridden @@ -340,12 +861,16 @@ async def create( "top_p": top_p, "user": user, }, - completion_create_params.CompletionCreateParams, + completion_create_params.CompletionCreateParamsStreaming + if stream + else completion_create_params.CompletionCreateParamsNonStreaming, ), options=make_request_options( extra_headers=extra_headers, extra_query=extra_query, extra_body=extra_body, timeout=timeout ), cast_to=CompletionCreateResponse, + stream=stream or False, + stream_cls=AsyncStream[ChatCompletionChunk], ) diff --git a/src/gradientai/types/agents/chat/__init__.py b/src/gradientai/types/agents/chat/__init__.py index 9384ac14..f0243162 100644 --- a/src/gradientai/types/agents/chat/__init__.py +++ b/src/gradientai/types/agents/chat/__init__.py @@ -2,5 +2,6 @@ from __future__ import annotations +from .chat_completion_chunk import ChatCompletionChunk as ChatCompletionChunk from .completion_create_params import CompletionCreateParams as CompletionCreateParams from .completion_create_response import CompletionCreateResponse as CompletionCreateResponse diff --git a/src/gradientai/types/agents/chat/chat_completion_chunk.py b/src/gradientai/types/agents/chat/chat_completion_chunk.py new file mode 100644 index 00000000..b81aef72 --- /dev/null +++ b/src/gradientai/types/agents/chat/chat_completion_chunk.py @@ -0,0 +1,93 @@ +# File generated from our OpenAPI spec by Stainless. See CONTRIBUTING.md for details. 
+ +from typing import List, Optional +from typing_extensions import Literal + +from ...._models import BaseModel +from ...shared.chat_completion_token_logprob import ChatCompletionTokenLogprob + +__all__ = ["ChatCompletionChunk", "Choice", "ChoiceDelta", "ChoiceLogprobs", "Usage"] + + +class ChoiceDelta(BaseModel): + content: Optional[str] = None + """The contents of the chunk message.""" + + refusal: Optional[str] = None + """The refusal message generated by the model.""" + + role: Optional[Literal["developer", "user", "assistant"]] = None + """The role of the author of this message.""" + + +class ChoiceLogprobs(BaseModel): + content: Optional[List[ChatCompletionTokenLogprob]] = None + """A list of message content tokens with log probability information.""" + + refusal: Optional[List[ChatCompletionTokenLogprob]] = None + """A list of message refusal tokens with log probability information.""" + + +class Choice(BaseModel): + delta: ChoiceDelta + """A chat completion delta generated by streamed model responses.""" + + finish_reason: Optional[Literal["stop", "length"]] = None + """The reason the model stopped generating tokens. + + This will be `stop` if the model hit a natural stop point or a provided stop + sequence, or `length` if the maximum number of tokens specified in the request + was reached + """ + + index: int + """The index of the choice in the list of choices.""" + + logprobs: Optional[ChoiceLogprobs] = None + """Log probability information for the choice.""" + + +class Usage(BaseModel): + completion_tokens: int + """Number of tokens in the generated completion.""" + + prompt_tokens: int + """Number of tokens in the prompt.""" + + total_tokens: int + """Total number of tokens used in the request (prompt + completion).""" + + +class ChatCompletionChunk(BaseModel): + id: str + """A unique identifier for the chat completion. Each chunk has the same ID.""" + + choices: List[Choice] + """A list of chat completion choices. + + Can contain more than one elements if `n` is greater than 1. Can also be empty + for the last chunk if you set `stream_options: {"include_usage": true}`. + """ + + created: int + """The Unix timestamp (in seconds) of when the chat completion was created. + + Each chunk has the same timestamp. + """ + + model: str + """The model to generate the completion.""" + + object: Literal["chat.completion.chunk"] + """The object type, which is always `chat.completion.chunk`.""" + + usage: Optional[Usage] = None + """ + An optional field that will only be present when you set + `stream_options: {"include_usage": true}` in your request. When present, it + contains a null value **except for the last chunk** which contains the token + usage statistics for the entire request. + + **NOTE:** If the stream is interrupted or cancelled, you may not receive the + final usage chunk which contains the total token usage for the request. 
+ """ diff --git a/src/gradientai/types/agents/chat/completion_create_params.py b/src/gradientai/types/agents/chat/completion_create_params.py index 11d032ff..ec5c6b70 100644 --- a/src/gradientai/types/agents/chat/completion_create_params.py +++ b/src/gradientai/types/agents/chat/completion_create_params.py @@ -6,17 +6,19 @@ from typing_extensions import Literal, Required, TypeAlias, TypedDict __all__ = [ - "CompletionCreateParams", + "CompletionCreateParamsBase", "Message", "MessageChatCompletionRequestSystemMessage", "MessageChatCompletionRequestDeveloperMessage", "MessageChatCompletionRequestUserMessage", "MessageChatCompletionRequestAssistantMessage", "StreamOptions", + "CompletionCreateParamsNonStreaming", + "CompletionCreateParamsStreaming", ] -class CompletionCreateParams(TypedDict, total=False): +class CompletionCreateParamsBase(TypedDict, total=False): messages: Required[Iterable[Message]] """A list of messages comprising the conversation so far.""" @@ -92,12 +94,6 @@ class CompletionCreateParams(TypedDict, total=False): The returned text will not contain the stop sequence. """ - stream: Optional[bool] - """ - If set to true, the model response data will be streamed to the client as it is - generated using server-sent events. - """ - stream_options: Optional[StreamOptions] """Options for streaming response. Only set this when you set `stream: true`.""" @@ -183,3 +179,22 @@ class StreamOptions(TypedDict, total=False): **NOTE:** If the stream is interrupted, you may not receive the final usage chunk which contains the total token usage for the request. """ + + +class CompletionCreateParamsNonStreaming(CompletionCreateParamsBase, total=False): + stream: Optional[Literal[False]] + """ + If set to true, the model response data will be streamed to the client as it is + generated using server-sent events. + """ + + +class CompletionCreateParamsStreaming(CompletionCreateParamsBase): + stream: Required[Literal[True]] + """ + If set to true, the model response data will be streamed to the client as it is + generated using server-sent events. 
+ """ + + +CompletionCreateParams = Union[CompletionCreateParamsNonStreaming, CompletionCreateParamsStreaming] diff --git a/tests/api_resources/agents/chat/test_completions.py b/tests/api_resources/agents/chat/test_completions.py index 89d531a5..4630adfc 100644 --- a/tests/api_resources/agents/chat/test_completions.py +++ b/tests/api_resources/agents/chat/test_completions.py @@ -19,7 +19,7 @@ class TestCompletions: @pytest.mark.skip() @parametrize - def test_method_create(self, client: GradientAI) -> None: + def test_method_create_overload_1(self, client: GradientAI) -> None: completion = client.agents.chat.completions.create( messages=[ { @@ -33,7 +33,7 @@ def test_method_create(self, client: GradientAI) -> None: @pytest.mark.skip() @parametrize - def test_method_create_with_all_params(self, client: GradientAI) -> None: + def test_method_create_with_all_params_overload_1(self, client: GradientAI) -> None: completion = client.agents.chat.completions.create( messages=[ { @@ -51,7 +51,7 @@ def test_method_create_with_all_params(self, client: GradientAI) -> None: n=1, presence_penalty=-2, stop="\n", - stream=True, + stream=False, stream_options={"include_usage": True}, temperature=1, top_logprobs=0, @@ -62,7 +62,7 @@ def test_method_create_with_all_params(self, client: GradientAI) -> None: @pytest.mark.skip() @parametrize - def test_raw_response_create(self, client: GradientAI) -> None: + def test_raw_response_create_overload_1(self, client: GradientAI) -> None: response = client.agents.chat.completions.with_raw_response.create( messages=[ { @@ -80,7 +80,7 @@ def test_raw_response_create(self, client: GradientAI) -> None: @pytest.mark.skip() @parametrize - def test_streaming_response_create(self, client: GradientAI) -> None: + def test_streaming_response_create_overload_1(self, client: GradientAI) -> None: with client.agents.chat.completions.with_streaming_response.create( messages=[ { @@ -98,6 +98,89 @@ def test_streaming_response_create(self, client: GradientAI) -> None: assert cast(Any, response.is_closed) is True + @pytest.mark.skip() + @parametrize + def test_method_create_overload_2(self, client: GradientAI) -> None: + completion_stream = client.agents.chat.completions.create( + messages=[ + { + "content": "string", + "role": "system", + } + ], + model="llama3-8b-instruct", + stream=True, + ) + completion_stream.response.close() + + @pytest.mark.skip() + @parametrize + def test_method_create_with_all_params_overload_2(self, client: GradientAI) -> None: + completion_stream = client.agents.chat.completions.create( + messages=[ + { + "content": "string", + "role": "system", + } + ], + model="llama3-8b-instruct", + stream=True, + frequency_penalty=-2, + logit_bias={"foo": 0}, + logprobs=True, + max_completion_tokens=256, + max_tokens=0, + metadata={"foo": "string"}, + n=1, + presence_penalty=-2, + stop="\n", + stream_options={"include_usage": True}, + temperature=1, + top_logprobs=0, + top_p=1, + user="user-1234", + ) + completion_stream.response.close() + + @pytest.mark.skip() + @parametrize + def test_raw_response_create_overload_2(self, client: GradientAI) -> None: + response = client.agents.chat.completions.with_raw_response.create( + messages=[ + { + "content": "string", + "role": "system", + } + ], + model="llama3-8b-instruct", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = response.parse() + stream.close() + + @pytest.mark.skip() + @parametrize + def test_streaming_response_create_overload_2(self, client: GradientAI) -> 
None: + with client.agents.chat.completions.with_streaming_response.create( + messages=[ + { + "content": "string", + "role": "system", + } + ], + model="llama3-8b-instruct", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = response.parse() + stream.close() + + assert cast(Any, response.is_closed) is True + class TestAsyncCompletions: parametrize = pytest.mark.parametrize( @@ -106,7 +189,7 @@ class TestAsyncCompletions: @pytest.mark.skip() @parametrize - async def test_method_create(self, async_client: AsyncGradientAI) -> None: + async def test_method_create_overload_1(self, async_client: AsyncGradientAI) -> None: completion = await async_client.agents.chat.completions.create( messages=[ { @@ -120,7 +203,7 @@ async def test_method_create(self, async_client: AsyncGradientAI) -> None: @pytest.mark.skip() @parametrize - async def test_method_create_with_all_params(self, async_client: AsyncGradientAI) -> None: + async def test_method_create_with_all_params_overload_1(self, async_client: AsyncGradientAI) -> None: completion = await async_client.agents.chat.completions.create( messages=[ { @@ -138,7 +221,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncGradientAI n=1, presence_penalty=-2, stop="\n", - stream=True, + stream=False, stream_options={"include_usage": True}, temperature=1, top_logprobs=0, @@ -149,7 +232,7 @@ async def test_method_create_with_all_params(self, async_client: AsyncGradientAI @pytest.mark.skip() @parametrize - async def test_raw_response_create(self, async_client: AsyncGradientAI) -> None: + async def test_raw_response_create_overload_1(self, async_client: AsyncGradientAI) -> None: response = await async_client.agents.chat.completions.with_raw_response.create( messages=[ { @@ -167,7 +250,7 @@ async def test_raw_response_create(self, async_client: AsyncGradientAI) -> None: @pytest.mark.skip() @parametrize - async def test_streaming_response_create(self, async_client: AsyncGradientAI) -> None: + async def test_streaming_response_create_overload_1(self, async_client: AsyncGradientAI) -> None: async with async_client.agents.chat.completions.with_streaming_response.create( messages=[ { @@ -184,3 +267,86 @@ async def test_streaming_response_create(self, async_client: AsyncGradientAI) -> assert_matches_type(CompletionCreateResponse, completion, path=["response"]) assert cast(Any, response.is_closed) is True + + @pytest.mark.skip() + @parametrize + async def test_method_create_overload_2(self, async_client: AsyncGradientAI) -> None: + completion_stream = await async_client.agents.chat.completions.create( + messages=[ + { + "content": "string", + "role": "system", + } + ], + model="llama3-8b-instruct", + stream=True, + ) + await completion_stream.response.aclose() + + @pytest.mark.skip() + @parametrize + async def test_method_create_with_all_params_overload_2(self, async_client: AsyncGradientAI) -> None: + completion_stream = await async_client.agents.chat.completions.create( + messages=[ + { + "content": "string", + "role": "system", + } + ], + model="llama3-8b-instruct", + stream=True, + frequency_penalty=-2, + logit_bias={"foo": 0}, + logprobs=True, + max_completion_tokens=256, + max_tokens=0, + metadata={"foo": "string"}, + n=1, + presence_penalty=-2, + stop="\n", + stream_options={"include_usage": True}, + temperature=1, + top_logprobs=0, + top_p=1, + user="user-1234", + ) + await completion_stream.response.aclose() + + @pytest.mark.skip() + 
@parametrize + async def test_raw_response_create_overload_2(self, async_client: AsyncGradientAI) -> None: + response = await async_client.agents.chat.completions.with_raw_response.create( + messages=[ + { + "content": "string", + "role": "system", + } + ], + model="llama3-8b-instruct", + stream=True, + ) + + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + stream = await response.parse() + await stream.close() + + @pytest.mark.skip() + @parametrize + async def test_streaming_response_create_overload_2(self, async_client: AsyncGradientAI) -> None: + async with async_client.agents.chat.completions.with_streaming_response.create( + messages=[ + { + "content": "string", + "role": "system", + } + ], + model="llama3-8b-instruct", + stream=True, + ) as response: + assert not response.is_closed + assert response.http_request.headers.get("X-Stainless-Lang") == "python" + + stream = await response.parse() + await stream.close() + + assert cast(Any, response.is_closed) is True diff --git a/tests/test_client.py b/tests/test_client.py index fc2c1325..137fabed 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -24,6 +24,7 @@ from gradientai import GradientAI, AsyncGradientAI, APIResponseValidationError from gradientai._types import Omit from gradientai._models import BaseModel, FinalRequestOptions +from gradientai._streaming import Stream, AsyncStream from gradientai._exceptions import APIStatusError, APITimeoutError, APIResponseValidationError from gradientai._base_client import ( DEFAULT_TIMEOUT, @@ -751,6 +752,17 @@ def test_client_max_retries_validation(self) -> None: max_retries=cast(Any, None), ) + @pytest.mark.respx(base_url=base_url) + def test_default_stream_cls(self, respx_mock: MockRouter) -> None: + class Model(BaseModel): + name: str + + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + stream = self.client.post("/foo", cast_to=Model, stream=True, stream_cls=Stream[Model]) + assert isinstance(stream, Stream) + stream.response.close() + @pytest.mark.respx(base_url=base_url) def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None: class Model(BaseModel): @@ -1650,6 +1662,18 @@ async def test_client_max_retries_validation(self) -> None: max_retries=cast(Any, None), ) + @pytest.mark.respx(base_url=base_url) + @pytest.mark.asyncio + async def test_default_stream_cls(self, respx_mock: MockRouter) -> None: + class Model(BaseModel): + name: str + + respx_mock.post("/foo").mock(return_value=httpx.Response(200, json={"foo": "bar"})) + + stream = await self.client.post("/foo", cast_to=Model, stream=True, stream_cls=AsyncStream[Model]) + assert isinstance(stream, AsyncStream) + await stream.response.aclose() + @pytest.mark.respx(base_url=base_url) @pytest.mark.asyncio async def test_received_text_for_expected_json(self, respx_mock: MockRouter) -> None:
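
The streaming overloads added to CompletionsResource.create above return Stream[ChatCompletionChunk] when stream=True is passed. A minimal usage sketch follows; the client construction, credentials, and agent configuration are assumptions rather than part of this diff, while the model name and message shape are taken from the tests.

from gradientai import GradientAI

client = GradientAI()  # assumes credentials / agent endpoint are supplied via constructor args or environment

stream = client.agents.chat.completions.create(
    messages=[{"role": "user", "content": "Hello"}],
    model="llama3-8b-instruct",
    stream=True,  # selects the Literal[True] overload, so the return type is Stream[ChatCompletionChunk]
)
for chunk in stream:
    # Each chunk carries a delta; content may be None for role-only or final chunks.
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)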
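
AsyncCompletionsResource mirrors this with AsyncStream[ChatCompletionChunk]; under the same assumptions, an async sketch would be:

import asyncio

from gradientai import AsyncGradientAI


async def main() -> None:
    client = AsyncGradientAI()  # same assumption about credentials / agent configuration
    stream = await client.agents.chat.completions.create(
        messages=[{"role": "user", "content": "Hello"}],
        model="llama3-8b-instruct",
        stream=True,  # returns AsyncStream[ChatCompletionChunk]
    )
    async for chunk in stream:
        if chunk.choices and chunk.choices[0].delta.content:
            print(chunk.choices[0].delta.content, end="", flush=True)


asyncio.run(main())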